pysunny committed
Commit baeb61b
1 Parent(s): f245392

Updated files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50)
  1. README.md +1 -1
  2. sd/stable-diffusion-webui/CODEOWNERS +12 -12
  3. sd/stable-diffusion-webui/LICENSE.txt +663 -663
  4. sd/stable-diffusion-webui/README.md +162 -0
  5. sd/stable-diffusion-webui/extensions-builtin/LDSR/preload.py +6 -6
  6. sd/stable-diffusion-webui/extensions-builtin/Lora/extra_networks_lora.py +26 -26
  7. sd/stable-diffusion-webui/extensions-builtin/Lora/lora.py +207 -207
  8. sd/stable-diffusion-webui/extensions-builtin/Lora/preload.py +6 -6
  9. sd/stable-diffusion-webui/extensions-builtin/Lora/scripts/lora_script.py +38 -38
  10. sd/stable-diffusion-webui/extensions-builtin/Lora/ui_extra_networks_lora.py +30 -37
  11. sd/stable-diffusion-webui/extensions-builtin/ScuNET/preload.py +6 -6
  12. sd/stable-diffusion-webui/extensions-builtin/SwinIR/preload.py +6 -6
  13. sd/stable-diffusion-webui/extensions-builtin/SwinIR/swinir_model_arch_v2.py +1016 -1016
  14. sd/stable-diffusion-webui/html/extra-networks-card.html +1 -0
  15. sd/stable-diffusion-webui/html/footer.html +13 -13
  16. sd/stable-diffusion-webui/html/licenses.html +638 -419
  17. sd/stable-diffusion-webui/javascript/aspectRatioOverlay.js +113 -113
  18. sd/stable-diffusion-webui/javascript/contextMenus.js +177 -177
  19. sd/stable-diffusion-webui/javascript/edit-attention.js +95 -95
  20. sd/stable-diffusion-webui/javascript/extensions.js +49 -49
  21. sd/stable-diffusion-webui/javascript/extraNetworks.js +106 -106
  22. sd/stable-diffusion-webui/javascript/hints.js +1 -0
  23. sd/stable-diffusion-webui/javascript/hires_fix.js +22 -22
  24. sd/stable-diffusion-webui/javascript/localization.js +165 -165
  25. sd/stable-diffusion-webui/javascript/notification.js +1 -1
  26. sd/stable-diffusion-webui/javascript/progressbar.js +1 -1
  27. sd/stable-diffusion-webui/javascript/textualInversion.js +17 -17
  28. sd/stable-diffusion-webui/launch.py +375 -361
  29. sd/stable-diffusion-webui/modules/api/api.py +28 -17
  30. sd/stable-diffusion-webui/modules/api/models.py +24 -4
  31. sd/stable-diffusion-webui/modules/call_queue.py +109 -109
  32. sd/stable-diffusion-webui/modules/codeformer_model.py +143 -143
  33. sd/stable-diffusion-webui/modules/deepbooru_model.py +678 -678
  34. sd/stable-diffusion-webui/modules/errors.py +43 -43
  35. sd/stable-diffusion-webui/modules/esrgan_model.py +233 -233
  36. sd/stable-diffusion-webui/modules/esrgan_model_arch.py +464 -464
  37. sd/stable-diffusion-webui/modules/extensions.py +107 -107
  38. sd/stable-diffusion-webui/modules/extra_networks.py +147 -147
  39. sd/stable-diffusion-webui/modules/extra_networks_hypernet.py +27 -27
  40. sd/stable-diffusion-webui/modules/extras.py +258 -258
  41. sd/stable-diffusion-webui/modules/face_restoration.py +19 -19
  42. sd/stable-diffusion-webui/modules/generation_parameters_copypaste.py +408 -402
  43. sd/stable-diffusion-webui/modules/gfpgan_model.py +116 -116
  44. sd/stable-diffusion-webui/modules/hashes.py +91 -91
  45. sd/stable-diffusion-webui/modules/hypernetworks/hypernetwork.py +811 -811
  46. sd/stable-diffusion-webui/modules/hypernetworks/ui.py +40 -40
  47. sd/stable-diffusion-webui/modules/images.py +669 -669
  48. sd/stable-diffusion-webui/modules/img2img.py +184 -184
  49. sd/stable-diffusion-webui/modules/interrogate.py +227 -227
  50. sd/stable-diffusion-webui/modules/localization.py +37 -37
README.md CHANGED
@@ -3,7 +3,7 @@ license: apache-2.0
  title: Automatic Stable Diffusion
  sdk: gradio
  sdk_version: 3.16.2
- app_file: sd/stable-diffusion-webui/webui.py
+ app_file: sd/stable-diffusion-webui/webui.py --api
  emoji: 🚀
  colorFrom: indigo
  colorTo: purple
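
For readability, here is a sketch of the Space's README front matter as it reads after this hunk. It is reconstructed only from the lines visible above (the license field and the "---" delimiters are assumed from the hunk header and the usual Space README conventions); fields outside the hunk are assumed unchanged, and the trailing --api on app_file is reproduced exactly as committed, even though app_file ordinarily holds just the path to the entry-point script.

---
# Hugging Face Space metadata (YAML front matter); the "---" delimiters and the
# license line outside this hunk are assumptions based on the hunk header.
license: apache-2.0
title: Automatic Stable Diffusion
sdk: gradio
sdk_version: 3.16.2
app_file: sd/stable-diffusion-webui/webui.py --api   # path plus the flag appended in this commit
emoji: 🚀
colorFrom: indigo
colorTo: purple
---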
sd/stable-diffusion-webui/CODEOWNERS CHANGED
@@ -1,12 +1,12 @@
- * @AUTOMATIC1111
-
- # if you were managing a localization and were removed from this file, this is because
- # the intended way to do localizations now is via extensions. See:
- # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
- # Make a repo with your localization and since you are still listed as a collaborator
- # you can add it to the wiki page yourself. This change is because some people complained
- # the git commit log is cluttered with things unrelated to almost everyone and
- # because I believe this is the best overall for the project to handle localizations almost
- # entirely without my oversight.
-
-
+ * @AUTOMATIC1111
+
+ # if you were managing a localization and were removed from this file, this is because
+ # the intended way to do localizations now is via extensions. See:
+ # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
+ # Make a repo with your localization and since you are still listed as a collaborator
+ # you can add it to the wiki page yourself. This change is because some people complained
+ # the git commit log is cluttered with things unrelated to almost everyone and
+ # because I believe this is the best overall for the project to handle localizations almost
+ # entirely without my oversight.
+
+
sd/stable-diffusion-webui/LICENSE.txt CHANGED
@@ -1,663 +1,663 @@
(The hunk removes and re-adds all 663 lines of the GNU AFFERO GENERAL PUBLIC LICENSE, Version 3, 19 November 2007 — including the "Copyright (c) 2023 AUTOMATIC1111" notice — with identical visible text; the whole-file rewrite appears to be a whitespace or line-ending change only, with no change to the license terms.)
315
+ if neither you nor any third party retains the ability to install
316
+ modified object code on the User Product (for example, the work has
317
+ been installed in ROM).
318
+
319
+ The requirement to provide Installation Information does not include a
320
+ requirement to continue to provide support service, warranty, or updates
321
+ for a work that has been modified or installed by the recipient, or for
322
+ the User Product in which it has been modified or installed. Access to a
323
+ network may be denied when the modification itself materially and
324
+ adversely affects the operation of the network or violates the rules and
325
+ protocols for communication across the network.
326
+
327
+ Corresponding Source conveyed, and Installation Information provided,
328
+ in accord with this section must be in a format that is publicly
329
+ documented (and with an implementation available to the public in
330
+ source code form), and must require no special password or key for
331
+ unpacking, reading or copying.
332
+
333
+ 7. Additional Terms.
334
+
335
+ "Additional permissions" are terms that supplement the terms of this
336
+ License by making exceptions from one or more of its conditions.
337
+ Additional permissions that are applicable to the entire Program shall
338
+ be treated as though they were included in this License, to the extent
339
+ that they are valid under applicable law. If additional permissions
340
+ apply only to part of the Program, that part may be used separately
341
+ under those permissions, but the entire Program remains governed by
342
+ this License without regard to the additional permissions.
343
+
344
+ When you convey a copy of a covered work, you may at your option
345
+ remove any additional permissions from that copy, or from any part of
346
+ it. (Additional permissions may be written to require their own
347
+ removal in certain cases when you modify the work.) You may place
348
+ additional permissions on material, added by you to a covered work,
349
+ for which you have or can give appropriate copyright permission.
350
+
351
+ Notwithstanding any other provision of this License, for material you
352
+ add to a covered work, you may (if authorized by the copyright holders of
353
+ that material) supplement the terms of this License with terms:
354
+
355
+ a) Disclaiming warranty or limiting liability differently from the
356
+ terms of sections 15 and 16 of this License; or
357
+
358
+ b) Requiring preservation of specified reasonable legal notices or
359
+ author attributions in that material or in the Appropriate Legal
360
+ Notices displayed by works containing it; or
361
+
362
+ c) Prohibiting misrepresentation of the origin of that material, or
363
+ requiring that modified versions of such material be marked in
364
+ reasonable ways as different from the original version; or
365
+
366
+ d) Limiting the use for publicity purposes of names of licensors or
367
+ authors of the material; or
368
+
369
+ e) Declining to grant rights under trademark law for use of some
370
+ trade names, trademarks, or service marks; or
371
+
372
+ f) Requiring indemnification of licensors and authors of that
373
+ material by anyone who conveys the material (or modified versions of
374
+ it) with contractual assumptions of liability to the recipient, for
375
+ any liability that these contractual assumptions directly impose on
376
+ those licensors and authors.
377
+
378
+ All other non-permissive additional terms are considered "further
379
+ restrictions" within the meaning of section 10. If the Program as you
380
+ received it, or any part of it, contains a notice stating that it is
381
+ governed by this License along with a term that is a further
382
+ restriction, you may remove that term. If a license document contains
383
+ a further restriction but permits relicensing or conveying under this
384
+ License, you may add to a covered work material governed by the terms
385
+ of that license document, provided that the further restriction does
386
+ not survive such relicensing or conveying.
387
+
388
+ If you add terms to a covered work in accord with this section, you
389
+ must place, in the relevant source files, a statement of the
390
+ additional terms that apply to those files, or a notice indicating
391
+ where to find the applicable terms.
392
+
393
+ Additional terms, permissive or non-permissive, may be stated in the
394
+ form of a separately written license, or stated as exceptions;
395
+ the above requirements apply either way.
396
+
397
+ 8. Termination.
398
+
399
+ You may not propagate or modify a covered work except as expressly
400
+ provided under this License. Any attempt otherwise to propagate or
401
+ modify it is void, and will automatically terminate your rights under
402
+ this License (including any patent licenses granted under the third
403
+ paragraph of section 11).
404
+
405
+ However, if you cease all violation of this License, then your
406
+ license from a particular copyright holder is reinstated (a)
407
+ provisionally, unless and until the copyright holder explicitly and
408
+ finally terminates your license, and (b) permanently, if the copyright
409
+ holder fails to notify you of the violation by some reasonable means
410
+ prior to 60 days after the cessation.
411
+
412
+ Moreover, your license from a particular copyright holder is
413
+ reinstated permanently if the copyright holder notifies you of the
414
+ violation by some reasonable means, this is the first time you have
415
+ received notice of violation of this License (for any work) from that
416
+ copyright holder, and you cure the violation prior to 30 days after
417
+ your receipt of the notice.
418
+
419
+ Termination of your rights under this section does not terminate the
420
+ licenses of parties who have received copies or rights from you under
421
+ this License. If your rights have been terminated and not permanently
422
+ reinstated, you do not qualify to receive new licenses for the same
423
+ material under section 10.
424
+
425
+ 9. Acceptance Not Required for Having Copies.
426
+
427
+ You are not required to accept this License in order to receive or
428
+ run a copy of the Program. Ancillary propagation of a covered work
429
+ occurring solely as a consequence of using peer-to-peer transmission
430
+ to receive a copy likewise does not require acceptance. However,
431
+ nothing other than this License grants you permission to propagate or
432
+ modify any covered work. These actions infringe copyright if you do
433
+ not accept this License. Therefore, by modifying or propagating a
434
+ covered work, you indicate your acceptance of this License to do so.
435
+
436
+ 10. Automatic Licensing of Downstream Recipients.
437
+
438
+ Each time you convey a covered work, the recipient automatically
439
+ receives a license from the original licensors, to run, modify and
440
+ propagate that work, subject to this License. You are not responsible
441
+ for enforcing compliance by third parties with this License.
442
+
443
+ An "entity transaction" is a transaction transferring control of an
444
+ organization, or substantially all assets of one, or subdividing an
445
+ organization, or merging organizations. If propagation of a covered
446
+ work results from an entity transaction, each party to that
447
+ transaction who receives a copy of the work also receives whatever
448
+ licenses to the work the party's predecessor in interest had or could
449
+ give under the previous paragraph, plus a right to possession of the
450
+ Corresponding Source of the work from the predecessor in interest, if
451
+ the predecessor has it or can get it with reasonable efforts.
452
+
453
+ You may not impose any further restrictions on the exercise of the
454
+ rights granted or affirmed under this License. For example, you may
455
+ not impose a license fee, royalty, or other charge for exercise of
456
+ rights granted under this License, and you may not initiate litigation
457
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
458
+ any patent claim is infringed by making, using, selling, offering for
459
+ sale, or importing the Program or any portion of it.
460
+
461
+ 11. Patents.
462
+
463
+ A "contributor" is a copyright holder who authorizes use under this
464
+ License of the Program or a work on which the Program is based. The
465
+ work thus licensed is called the contributor's "contributor version".
466
+
467
+ A contributor's "essential patent claims" are all patent claims
468
+ owned or controlled by the contributor, whether already acquired or
469
+ hereafter acquired, that would be infringed by some manner, permitted
470
+ by this License, of making, using, or selling its contributor version,
471
+ but do not include claims that would be infringed only as a
472
+ consequence of further modification of the contributor version. For
473
+ purposes of this definition, "control" includes the right to grant
474
+ patent sublicenses in a manner consistent with the requirements of
475
+ this License.
476
+
477
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
478
+ patent license under the contributor's essential patent claims, to
479
+ make, use, sell, offer for sale, import and otherwise run, modify and
480
+ propagate the contents of its contributor version.
481
+
482
+ In the following three paragraphs, a "patent license" is any express
483
+ agreement or commitment, however denominated, not to enforce a patent
484
+ (such as an express permission to practice a patent or covenant not to
485
+ sue for patent infringement). To "grant" such a patent license to a
486
+ party means to make such an agreement or commitment not to enforce a
487
+ patent against the party.
488
+
489
+ If you convey a covered work, knowingly relying on a patent license,
490
+ and the Corresponding Source of the work is not available for anyone
491
+ to copy, free of charge and under the terms of this License, through a
492
+ publicly available network server or other readily accessible means,
493
+ then you must either (1) cause the Corresponding Source to be so
494
+ available, or (2) arrange to deprive yourself of the benefit of the
495
+ patent license for this particular work, or (3) arrange, in a manner
496
+ consistent with the requirements of this License, to extend the patent
497
+ license to downstream recipients. "Knowingly relying" means you have
498
+ actual knowledge that, but for the patent license, your conveying the
499
+ covered work in a country, or your recipient's use of the covered work
500
+ in a country, would infringe one or more identifiable patents in that
501
+ country that you have reason to believe are valid.
502
+
503
+ If, pursuant to or in connection with a single transaction or
504
+ arrangement, you convey, or propagate by procuring conveyance of, a
505
+ covered work, and grant a patent license to some of the parties
506
+ receiving the covered work authorizing them to use, propagate, modify
507
+ or convey a specific copy of the covered work, then the patent license
508
+ you grant is automatically extended to all recipients of the covered
509
+ work and works based on it.
510
+
511
+ A patent license is "discriminatory" if it does not include within
512
+ the scope of its coverage, prohibits the exercise of, or is
513
+ conditioned on the non-exercise of one or more of the rights that are
514
+ specifically granted under this License. You may not convey a covered
515
+ work if you are a party to an arrangement with a third party that is
516
+ in the business of distributing software, under which you make payment
517
+ to the third party based on the extent of your activity of conveying
518
+ the work, and under which the third party grants, to any of the
519
+ parties who would receive the covered work from you, a discriminatory
520
+ patent license (a) in connection with copies of the covered work
521
+ conveyed by you (or copies made from those copies), or (b) primarily
522
+ for and in connection with specific products or compilations that
523
+ contain the covered work, unless you entered into that arrangement,
524
+ or that patent license was granted, prior to 28 March 2007.
525
+
526
+ Nothing in this License shall be construed as excluding or limiting
527
+ any implied license or other defenses to infringement that may
528
+ otherwise be available to you under applicable patent law.
529
+
530
+ 12. No Surrender of Others' Freedom.
531
+
532
+ If conditions are imposed on you (whether by court order, agreement or
533
+ otherwise) that contradict the conditions of this License, they do not
534
+ excuse you from the conditions of this License. If you cannot convey a
535
+ covered work so as to satisfy simultaneously your obligations under this
536
+ License and any other pertinent obligations, then as a consequence you may
537
+ not convey it at all. For example, if you agree to terms that obligate you
538
+ to collect a royalty for further conveying from those to whom you convey
539
+ the Program, the only way you could satisfy both those terms and this
540
+ License would be to refrain entirely from conveying the Program.
541
+
542
+ 13. Remote Network Interaction; Use with the GNU General Public License.
543
+
544
+ Notwithstanding any other provision of this License, if you modify the
545
+ Program, your modified version must prominently offer all users
546
+ interacting with it remotely through a computer network (if your version
547
+ supports such interaction) an opportunity to receive the Corresponding
548
+ Source of your version by providing access to the Corresponding Source
549
+ from a network server at no charge, through some standard or customary
550
+ means of facilitating copying of software. This Corresponding Source
551
+ shall include the Corresponding Source for any work covered by version 3
552
+ of the GNU General Public License that is incorporated pursuant to the
553
+ following paragraph.
554
+
555
+ Notwithstanding any other provision of this License, you have
556
+ permission to link or combine any covered work with a work licensed
557
+ under version 3 of the GNU General Public License into a single
558
+ combined work, and to convey the resulting work. The terms of this
559
+ License will continue to apply to the part which is the covered work,
560
+ but the work with which it is combined will remain governed by version
561
+ 3 of the GNU General Public License.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU Affero General Public License from time to time. Such new versions
567
+ will be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU Affero General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU Affero General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU Affero General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU Affero General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU Affero General Public License for more details.
646
+
647
+ You should have received a copy of the GNU Affero General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If your software can interact with users remotely through a computer
653
+ network, you should also make sure that it provides a way for users to
654
+ get its source. For example, if your program is a web application, its
655
+ interface could display a "Source" link that leads users to an archive
656
+ of the code. There are many ways you could offer source, and different
657
+ solutions will be better for different programs; see section 13 for the
658
+ specific requirements.
659
+
660
+ You should also get your employer (if you work as a programmer) or school,
661
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
662
+ For more information on this, and how to apply and follow the GNU AGPL, see
663
+ <https://www.gnu.org/licenses/>.
sd/stable-diffusion-webui/README.md ADDED
@@ -0,0 +1,162 @@
1
+ # Stable Diffusion web UI
2
+ A browser interface based on Gradio library for Stable Diffusion.
3
+
4
+ ![](screenshot.png)
5
+
6
+ ## Features
7
+ [Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
8
+ - Original txt2img and img2img modes
9
+ - One click install and run script (but you still must install python and git)
10
+ - Outpainting
11
+ - Inpainting
12
+ - Color Sketch
13
+ - Prompt Matrix
14
+ - Stable Diffusion Upscale
15
+ - Attention, specify parts of text that the model should pay more attention to
16
+ - a man in a ((tuxedo)) - will pay more attention to tuxedo
17
+ - a man in a (tuxedo:1.21) - alternative syntax (see the weighting sketch after this feature list)
18
+ - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
19
+ - Loopback, run img2img processing multiple times
20
+ - X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
21
+ - Textual Inversion
22
+ - have as many embeddings as you want and use any names you like for them
23
+ - use multiple embeddings with different numbers of vectors per token
24
+ - works with half precision floating point numbers
25
+ - train embeddings on 8GB (also reports of 6GB working)
26
+ - Extras tab with:
27
+ - GFPGAN, neural network that fixes faces
28
+ - CodeFormer, face restoration tool as an alternative to GFPGAN
29
+ - RealESRGAN, neural network upscaler
30
+ - ESRGAN, neural network upscaler with a lot of third party models
31
+ - SwinIR and Swin2SR ([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
32
+ - LDSR, Latent diffusion super resolution upscaling
33
+ - Resizing aspect ratio options
34
+ - Sampling method selection
35
+ - Adjust sampler eta values (noise multiplier)
36
+ - More advanced noise setting options
37
+ - Interrupt processing at any time
38
+ - 4GB video card support (also reports of 2GB working)
39
+ - Correct seeds for batches
40
+ - Live prompt token length validation
41
+ - Generation parameters
42
+ - parameters you used to generate images are saved with that image
43
+ - in PNG chunks for PNG, in EXIF for JPEG
44
+ - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
45
+ - can be disabled in settings
46
+ - drag and drop an image/text-parameters to promptbox
47
+ - Read Generation Parameters Button, loads parameters in promptbox to UI
48
+ - Settings page
49
+ - Running arbitrary python code from UI (must run with --allow-code to enable)
50
+ - Mouseover hints for most UI elements
51
+ - Possible to change defaults/min/max/step values for UI elements via text config
52
+ - Tiling support, a checkbox to create images that can be tiled like textures
53
+ - Progress bar and live image generation preview
54
+ - Can use a separate neural network to produce previews with almost no VRAM or compute requirements
55
+ - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
56
+ - Styles, a way to save part of prompt and easily apply them via dropdown later
57
+ - Variations, a way to generate same image but with tiny differences
58
+ - Seed resizing, a way to generate same image but at slightly different resolution
59
+ - CLIP interrogator, a button that tries to guess prompt from an image
60
+ - Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
61
+ - Batch Processing, process a group of files using img2img
62
+ - Img2img Alternative, reverse Euler method of cross attention control
63
+ - Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
64
+ - Reloading checkpoints on the fly
65
+ - Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
66
+ - [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
67
+ - [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
68
+ - separate prompts using uppercase `AND`
69
+ - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
70
+ - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
71
+ - DeepDanbooru integration, creates danbooru style tags for anime prompts
72
+ - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards (add --xformers to commandline args)
73
+ - via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
74
+ - Generate forever option
75
+ - Training tab
76
+ - hypernetworks and embeddings options
77
+ - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
78
+ - Clip skip
79
+ - Hypernetworks
80
+ - Loras (same as Hypernetworks but more pretty)
81
+ - A separate UI where you can choose, with a preview, which embeddings, hypernetworks or Loras to add to your prompt.
82
+ - Can select to load a different VAE from settings screen
83
+ - Estimated completion time in progress bar
84
+ - API
85
+ - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
86
+ - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
87
+ - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
88
+ - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
89
+ - Now without any bad letters!
90
+ - Load checkpoints in safetensors format
91
+ - Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
92
+ - Now with a license!
93
+ - Reorder elements in the UI from settings screen
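+
+ As a rough illustration of the attention syntax near the top of this list (a minimal sketch of the assumed weighting rule, not the web UI's actual prompt parser): each pair of parentheses multiplies a token's weight by 1.1, and `(text:w)` sets it explicitly, which is why `((tuxedo))` and `(tuxedo:1.21)` behave the same.
+
+ ```python
+ from typing import Optional
+
+ # Hypothetical helper, not part of the web UI code: computes the emphasis
+ # weight implied by the prompt syntax described in the feature list.
+ def emphasis_weight(paren_depth: int = 0, explicit: Optional[float] = None) -> float:
+     if explicit is not None:
+         return explicit                      # (text:w) sets the weight directly
+     return round(1.1 ** paren_depth, 4)      # each ( ) layer multiplies by 1.1
+
+ print(emphasis_weight(paren_depth=2))   # ((tuxedo)) -> 1.21
+ print(emphasis_weight(explicit=1.21))   # (tuxedo:1.21) -> 1.21
+ ```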
94
95
+
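+ The generation-parameters feature above can also be read back outside the UI. A minimal sketch, assuming the text is stored under a PNG text-chunk key named "parameters" and using a hypothetical output filename:
+
+ ```python
+ # Read back the generation parameters the web UI embeds in a saved PNG.
+ from PIL import Image
+
+ image = Image.open("00001-12345.png")    # hypothetical file saved by the web UI
+ print(image.info.get("parameters", "<no parameters chunk found>"))
+ ```
+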
96
+ ## Installation and Running
97
+ Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
98
+
99
+ Alternatively, use online services (like Google Colab):
100
+
101
+ - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
102
+
103
+ ### Automatic Installation on Windows
104
+ 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
105
+ 2. Install [git](https://git-scm.com/download/win).
106
+ 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
107
+ 4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
108
+
109
+ ### Automatic Installation on Linux
110
+ 1. Install the dependencies:
111
+ ```bash
112
+ # Debian-based:
113
+ sudo apt install wget git python3 python3-venv
114
+ # Red Hat-based:
115
+ sudo dnf install wget git python3
116
+ # Arch-based:
117
+ sudo pacman -S wget git python3
118
+ ```
119
+ 2. To install in `/home/$(whoami)/stable-diffusion-webui/`, run:
120
+ ```bash
121
+ bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
122
+ ```
123
+ 3. Run `webui.sh`.
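+ Optionally, launch flags (such as the `--xformers` flag mentioned in the feature list) can be made persistent. A minimal sketch, assuming the default `webui-user.sh` that ships with the repository:
+ ```bash
+ # webui-user.sh -- sourced by webui.sh on every launch (assumed default layout)
+ export COMMANDLINE_ARGS="--xformers"
+ ```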
124
+ ### Installation on Apple Silicon
125
+
126
+ Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
127
+
128
+ ## Contributing
129
+ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
130
+
131
+ ## Documentation
132
+ The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
133
+
134
+ ## Credits
135
+ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
136
+
137
+ - Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
138
+ - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
139
+ - GFPGAN - https://github.com/TencentARC/GFPGAN.git
140
+ - CodeFormer - https://github.com/sczhou/CodeFormer
141
+ - ESRGAN - https://github.com/xinntao/ESRGAN
142
+ - SwinIR - https://github.com/JingyunLiang/SwinIR
143
+ - Swin2SR - https://github.com/mv-lab/swin2sr
144
+ - LDSR - https://github.com/Hafiidz/latent-diffusion
145
+ - MiDaS - https://github.com/isl-org/MiDaS
146
+ - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
147
+ - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
148
+ - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
149
+ - Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
150
+ - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
151
+ - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
152
+ - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
153
+ - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
154
+ - Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
155
+ - xformers - https://github.com/facebookresearch/xformers
156
+ - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
157
+ - Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
158
+ - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
159
+ - Security advice - RyotaK
160
+ - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
161
+ - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
162
+ - (You)
sd/stable-diffusion-webui/extensions-builtin/LDSR/preload.py CHANGED
@@ -1,6 +1,6 @@
1
- import os
2
- from modules import paths
3
-
4
-
5
- def preload(parser):
6
- parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
 
1
+ import os
2
+ from modules import paths
3
+
4
+
5
+ def preload(parser):
6
+ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
sd/stable-diffusion-webui/extensions-builtin/Lora/extra_networks_lora.py CHANGED
@@ -1,26 +1,26 @@
1
- from modules import extra_networks, shared
2
- import lora
3
-
4
- class ExtraNetworkLora(extra_networks.ExtraNetwork):
5
- def __init__(self):
6
- super().__init__('lora')
7
-
8
- def activate(self, p, params_list):
9
- additional = shared.opts.sd_lora
10
-
11
- if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
12
- p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
13
- params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
14
-
15
- names = []
16
- multipliers = []
17
- for params in params_list:
18
- assert len(params.items) > 0
19
-
20
- names.append(params.items[0])
21
- multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
22
-
23
- lora.load_loras(names, multipliers)
24
-
25
- def deactivate(self, p):
26
- pass
 
1
+ from modules import extra_networks, shared
2
+ import lora
3
+
4
+ class ExtraNetworkLora(extra_networks.ExtraNetwork):
5
+ def __init__(self):
6
+ super().__init__('lora')
7
+
8
+ def activate(self, p, params_list):
9
+ additional = shared.opts.sd_lora
10
+
11
+ if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
12
+ p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
13
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
14
+
15
+ names = []
16
+ multipliers = []
17
+ for params in params_list:
18
+ assert len(params.items) > 0
19
+
20
+ names.append(params.items[0])
21
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
22
+
23
+ lora.load_loras(names, multipliers)
24
+
25
+ def deactivate(self, p):
26
+ pass
sd/stable-diffusion-webui/extensions-builtin/Lora/lora.py CHANGED
@@ -1,207 +1,207 @@
1
- import glob
2
- import os
3
- import re
4
- import torch
5
-
6
- from modules import shared, devices, sd_models
7
-
8
- re_digits = re.compile(r"\d+")
9
- re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
10
- re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
11
- re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
12
- re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
13
-
14
-
15
- def convert_diffusers_name_to_compvis(key):
16
- def match(match_list, regex):
17
- r = re.match(regex, key)
18
- if not r:
19
- return False
20
-
21
- match_list.clear()
22
- match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
23
- return True
24
-
25
- m = []
26
-
27
- if match(m, re_unet_down_blocks):
28
- return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
29
-
30
- if match(m, re_unet_mid_blocks):
31
- return f"diffusion_model_middle_block_1_{m[1]}"
32
-
33
- if match(m, re_unet_up_blocks):
34
- return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
35
-
36
- if match(m, re_text_block):
37
- return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
38
-
39
- return key
40
-
41
-
42
- class LoraOnDisk:
43
- def __init__(self, name, filename):
44
- self.name = name
45
- self.filename = filename
46
-
47
-
48
- class LoraModule:
49
- def __init__(self, name):
50
- self.name = name
51
- self.multiplier = 1.0
52
- self.modules = {}
53
- self.mtime = None
54
-
55
-
56
- class LoraUpDownModule:
57
- def __init__(self):
58
- self.up = None
59
- self.down = None
60
- self.alpha = None
61
-
62
-
63
- def assign_lora_names_to_compvis_modules(sd_model):
64
- lora_layer_mapping = {}
65
-
66
- for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
67
- lora_name = name.replace(".", "_")
68
- lora_layer_mapping[lora_name] = module
69
- module.lora_layer_name = lora_name
70
-
71
- for name, module in shared.sd_model.model.named_modules():
72
- lora_name = name.replace(".", "_")
73
- lora_layer_mapping[lora_name] = module
74
- module.lora_layer_name = lora_name
75
-
76
- sd_model.lora_layer_mapping = lora_layer_mapping
77
-
78
-
79
- def load_lora(name, filename):
80
- lora = LoraModule(name)
81
- lora.mtime = os.path.getmtime(filename)
82
-
83
- sd = sd_models.read_state_dict(filename)
84
-
85
- keys_failed_to_match = []
86
-
87
- for key_diffusers, weight in sd.items():
88
- fullkey = convert_diffusers_name_to_compvis(key_diffusers)
89
- key, lora_key = fullkey.split(".", 1)
90
-
91
- sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
92
- if sd_module is None:
93
- keys_failed_to_match.append(key_diffusers)
94
- continue
95
-
96
- lora_module = lora.modules.get(key, None)
97
- if lora_module is None:
98
- lora_module = LoraUpDownModule()
99
- lora.modules[key] = lora_module
100
-
101
- if lora_key == "alpha":
102
- lora_module.alpha = weight.item()
103
- continue
104
-
105
- if type(sd_module) == torch.nn.Linear:
106
- module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
107
- elif type(sd_module) == torch.nn.Conv2d:
108
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
109
- else:
110
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
111
-
112
- with torch.no_grad():
113
- module.weight.copy_(weight)
114
-
115
- module.to(device=devices.device, dtype=devices.dtype)
116
-
117
- if lora_key == "lora_up.weight":
118
- lora_module.up = module
119
- elif lora_key == "lora_down.weight":
120
- lora_module.down = module
121
- else:
122
- assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
123
-
124
- if len(keys_failed_to_match) > 0:
125
- print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
126
-
127
- return lora
128
-
129
-
130
- def load_loras(names, multipliers=None):
131
- already_loaded = {}
132
-
133
- for lora in loaded_loras:
134
- if lora.name in names:
135
- already_loaded[lora.name] = lora
136
-
137
- loaded_loras.clear()
138
-
139
- loras_on_disk = [available_loras.get(name, None) for name in names]
140
- if any([x is None for x in loras_on_disk]):
141
- list_available_loras()
142
-
143
- loras_on_disk = [available_loras.get(name, None) for name in names]
144
-
145
- for i, name in enumerate(names):
146
- lora = already_loaded.get(name, None)
147
-
148
- lora_on_disk = loras_on_disk[i]
149
- if lora_on_disk is not None:
150
- if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
151
- lora = load_lora(name, lora_on_disk.filename)
152
-
153
- if lora is None:
154
- print(f"Couldn't find Lora with name {name}")
155
- continue
156
-
157
- lora.multiplier = multipliers[i] if multipliers else 1.0
158
- loaded_loras.append(lora)
159
-
160
-
161
- def lora_forward(module, input, res):
162
- if len(loaded_loras) == 0:
163
- return res
164
-
165
- lora_layer_name = getattr(module, 'lora_layer_name', None)
166
- for lora in loaded_loras:
167
- module = lora.modules.get(lora_layer_name, None)
168
- if module is not None:
169
- if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
170
- res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
171
- else:
172
- res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
173
-
174
- return res
175
-
176
-
177
- def lora_Linear_forward(self, input):
178
- return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
179
-
180
-
181
- def lora_Conv2d_forward(self, input):
182
- return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
183
-
184
-
185
- def list_available_loras():
186
- available_loras.clear()
187
-
188
- os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
189
-
190
- candidates = \
191
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
192
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
193
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
194
-
195
- for filename in sorted(candidates):
196
- if os.path.isdir(filename):
197
- continue
198
-
199
- name = os.path.splitext(os.path.basename(filename))[0]
200
-
201
- available_loras[name] = LoraOnDisk(name, filename)
202
-
203
-
204
- available_loras = {}
205
- loaded_loras = []
206
-
207
- list_available_loras()
 
1
+ import glob
2
+ import os
3
+ import re
4
+ import torch
5
+
6
+ from modules import shared, devices, sd_models
7
+
8
+ re_digits = re.compile(r"\d+")
9
+ re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
10
+ re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
11
+ re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
12
+ re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
13
+
14
+
15
+ def convert_diffusers_name_to_compvis(key):
16
+ def match(match_list, regex):
17
+ r = re.match(regex, key)
18
+ if not r:
19
+ return False
20
+
21
+ match_list.clear()
22
+ match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
23
+ return True
24
+
25
+ m = []
26
+
27
+ if match(m, re_unet_down_blocks):
28
+ return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
29
+
30
+ if match(m, re_unet_mid_blocks):
31
+ return f"diffusion_model_middle_block_1_{m[1]}"
32
+
33
+ if match(m, re_unet_up_blocks):
34
+ return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
35
+
36
+ if match(m, re_text_block):
37
+ return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
38
+
39
+ return key
40
+
41
+
42
+ class LoraOnDisk:
43
+ def __init__(self, name, filename):
44
+ self.name = name
45
+ self.filename = filename
46
+
47
+
48
+ class LoraModule:
49
+ def __init__(self, name):
50
+ self.name = name
51
+ self.multiplier = 1.0
52
+ self.modules = {}
53
+ self.mtime = None
54
+
55
+
56
+ class LoraUpDownModule:
57
+ def __init__(self):
58
+ self.up = None
59
+ self.down = None
60
+ self.alpha = None
61
+
62
+
63
+ def assign_lora_names_to_compvis_modules(sd_model):
64
+ lora_layer_mapping = {}
65
+
66
+ for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
67
+ lora_name = name.replace(".", "_")
68
+ lora_layer_mapping[lora_name] = module
69
+ module.lora_layer_name = lora_name
70
+
71
+ for name, module in shared.sd_model.model.named_modules():
72
+ lora_name = name.replace(".", "_")
73
+ lora_layer_mapping[lora_name] = module
74
+ module.lora_layer_name = lora_name
75
+
76
+ sd_model.lora_layer_mapping = lora_layer_mapping
77
+
78
+
79
+ def load_lora(name, filename):
80
+ lora = LoraModule(name)
81
+ lora.mtime = os.path.getmtime(filename)
82
+
83
+ sd = sd_models.read_state_dict(filename)
84
+
85
+ keys_failed_to_match = []
86
+
87
+ for key_diffusers, weight in sd.items():
88
+ fullkey = convert_diffusers_name_to_compvis(key_diffusers)
89
+ key, lora_key = fullkey.split(".", 1)
90
+
91
+ sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
92
+ if sd_module is None:
93
+ keys_failed_to_match.append(key_diffusers)
94
+ continue
95
+
96
+ lora_module = lora.modules.get(key, None)
97
+ if lora_module is None:
98
+ lora_module = LoraUpDownModule()
99
+ lora.modules[key] = lora_module
100
+
101
+ if lora_key == "alpha":
102
+ lora_module.alpha = weight.item()
103
+ continue
104
+
105
+ if type(sd_module) == torch.nn.Linear:
106
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
107
+ elif type(sd_module) == torch.nn.Conv2d:
108
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
109
+ else:
110
+ assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
111
+
112
+ with torch.no_grad():
113
+ module.weight.copy_(weight)
114
+
115
+ module.to(device=devices.device, dtype=devices.dtype)
116
+
117
+ if lora_key == "lora_up.weight":
118
+ lora_module.up = module
119
+ elif lora_key == "lora_down.weight":
120
+ lora_module.down = module
121
+ else:
122
+ assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
123
+
124
+ if len(keys_failed_to_match) > 0:
125
+ print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
126
+
127
+ return lora
128
+
129
+
130
+ def load_loras(names, multipliers=None):
131
+ already_loaded = {}
132
+
133
+ for lora in loaded_loras:
134
+ if lora.name in names:
135
+ already_loaded[lora.name] = lora
136
+
137
+ loaded_loras.clear()
138
+
139
+ loras_on_disk = [available_loras.get(name, None) for name in names]
140
+ if any([x is None for x in loras_on_disk]):
141
+ list_available_loras()
142
+
143
+ loras_on_disk = [available_loras.get(name, None) for name in names]
144
+
145
+ for i, name in enumerate(names):
146
+ lora = already_loaded.get(name, None)
147
+
148
+ lora_on_disk = loras_on_disk[i]
149
+ if lora_on_disk is not None:
150
+ if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
151
+ lora = load_lora(name, lora_on_disk.filename)
152
+
153
+ if lora is None:
154
+ print(f"Couldn't find Lora with name {name}")
155
+ continue
156
+
157
+ lora.multiplier = multipliers[i] if multipliers else 1.0
158
+ loaded_loras.append(lora)
159
+
160
+
161
+ def lora_forward(module, input, res):
162
+ if len(loaded_loras) == 0:
163
+ return res
164
+
165
+ lora_layer_name = getattr(module, 'lora_layer_name', None)
166
+ for lora in loaded_loras:
167
+ module = lora.modules.get(lora_layer_name, None)
168
+ if module is not None:
169
+ if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
170
+ res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
171
+ else:
172
+ res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
173
+
174
+ return res
175
+
176
+
177
+ def lora_Linear_forward(self, input):
178
+ return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
179
+
180
+
181
+ def lora_Conv2d_forward(self, input):
182
+ return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
183
+
184
+
185
+ def list_available_loras():
186
+ available_loras.clear()
187
+
188
+ os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
189
+
190
+ candidates = \
191
+ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
192
+ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
193
+ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
194
+
195
+ for filename in sorted(candidates):
196
+ if os.path.isdir(filename):
197
+ continue
198
+
199
+ name = os.path.splitext(os.path.basename(filename))[0]
200
+
201
+ available_loras[name] = LoraOnDisk(name, filename)
202
+
203
+
204
+ available_loras = {}
205
+ loaded_loras = []
206
+
207
+ list_available_loras()
sd/stable-diffusion-webui/extensions-builtin/Lora/preload.py CHANGED
@@ -1,6 +1,6 @@
1
- import os
2
- from modules import paths
3
-
4
-
5
- def preload(parser):
6
- parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
 
1
+ import os
2
+ from modules import paths
3
+
4
+
5
+ def preload(parser):
6
+ parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
sd/stable-diffusion-webui/extensions-builtin/Lora/scripts/lora_script.py CHANGED
@@ -1,38 +1,38 @@
1
- import torch
2
- import gradio as gr
3
-
4
- import lora
5
- import extra_networks_lora
6
- import ui_extra_networks_lora
7
- from modules import script_callbacks, ui_extra_networks, extra_networks, shared
8
-
9
-
10
- def unload():
11
- torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
12
- torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
13
-
14
-
15
- def before_ui():
16
- ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
17
- extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
18
-
19
-
20
- if not hasattr(torch.nn, 'Linear_forward_before_lora'):
21
- torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
22
-
23
- if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
24
- torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
25
-
26
- torch.nn.Linear.forward = lora.lora_Linear_forward
27
- torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
28
-
29
- script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
30
- script_callbacks.on_script_unloaded(unload)
31
- script_callbacks.on_before_ui(before_ui)
32
-
33
-
34
- shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
35
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
36
- "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
37
-
38
- }))
 
1
+ import torch
2
+ import gradio as gr
3
+
4
+ import lora
5
+ import extra_networks_lora
6
+ import ui_extra_networks_lora
7
+ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
8
+
9
+
10
+ def unload():
11
+ torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
12
+ torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
13
+
14
+
15
+ def before_ui():
16
+ ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
17
+ extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
18
+
19
+
20
+ if not hasattr(torch.nn, 'Linear_forward_before_lora'):
21
+ torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
22
+
23
+ if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
24
+ torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
25
+
26
+ torch.nn.Linear.forward = lora.lora_Linear_forward
27
+ torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
28
+
29
+ script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
30
+ script_callbacks.on_script_unloaded(unload)
31
+ script_callbacks.on_before_ui(before_ui)
32
+
33
+
34
+ shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
35
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
36
+ "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
37
+
38
+ }))
sd/stable-diffusion-webui/extensions-builtin/Lora/ui_extra_networks_lora.py CHANGED
@@ -1,37 +1,30 @@
- import json
- import os
- import lora
-
- from modules import shared, ui_extra_networks
-
-
- class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
-     def __init__(self):
-         super().__init__('Lora')
-
-     def refresh(self):
-         lora.list_available_loras()
-
-     def list_items(self):
-         for name, lora_on_disk in lora.available_loras.items():
-             path, ext = os.path.splitext(lora_on_disk.filename)
-             previews = [path + ".png", path + ".preview.png"]
-
-             preview = None
-             for file in previews:
-                 if os.path.isfile(file):
-                     preview = self.link_preview(file)
-                     break
-
-             yield {
-                 "name": name,
-                 "filename": path,
-                 "preview": preview,
-                 "search_term": self.search_terms_from_path(lora_on_disk.filename),
-                 "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
-                 "local_preview": path + ".png",
-             }
-
-     def allowed_directories_for_previews(self):
-         return [shared.cmd_opts.lora_dir]
-

+ import json
+ import os
+ import lora
+
+ from modules import shared, ui_extra_networks
+
+
+ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
+     def __init__(self):
+         super().__init__('Lora')
+
+     def refresh(self):
+         lora.list_available_loras()
+
+     def list_items(self):
+         for name, lora_on_disk in lora.available_loras.items():
+             path, ext = os.path.splitext(lora_on_disk.filename)
+             yield {
+                 "name": name,
+                 "filename": path,
+                 "preview": self._find_preview(path),
+                 "description": self._find_description(path),
+                 "search_term": self.search_terms_from_path(lora_on_disk.filename),
+                 "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
+                 "local_preview": path + ".png",
+             }
+
+     def allowed_directories_for_previews(self):
+         return [shared.cmd_opts.lora_dir]
+

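A detail that is easy to misread in `list_items()` above: the `"prompt"` value is not literal prompt text but a small JavaScript expression, assembled from JSON-encoded fragments, which is meant to be evaluated by the extra-networks UI when a card is clicked (the multiplier coming from the `extra_networks_default_multiplier` setting). A quick illustration of what the constructed string actually contains, with `"myLora"` as a stand-in name:

```python
import json

name = "myLora"  # example network name, for illustration only
prompt_expr = json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">")
print(prompt_expr)
# "<lora:myLora:" + opts.extra_networks_default_multiplier + ">"
```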
sd/stable-diffusion-webui/extensions-builtin/ScuNET/preload.py CHANGED
@@ -1,6 +1,6 @@
- import os
- from modules import paths
-
-
- def preload(parser):
-     parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))

+ import os
+ from modules import paths
+
+
+ def preload(parser):
+     parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
sd/stable-diffusion-webui/extensions-builtin/SwinIR/preload.py CHANGED
@@ -1,6 +1,6 @@
- import os
- from modules import paths
-
-
- def preload(parser):
-     parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))

+ import os
+ from modules import paths
+
+
+ def preload(parser):
+     parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
sd/stable-diffusion-webui/extensions-builtin/SwinIR/swinir_model_arch_v2.py CHANGED
@@ -1,1017 +1,1017 @@
1
- # -----------------------------------------------------------------------------------
2
- # Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
3
- # Written by Conde and Choi et al.
4
- # -----------------------------------------------------------------------------------
5
-
6
- import math
7
- import numpy as np
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- import torch.utils.checkpoint as checkpoint
12
- from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13
-
14
-
15
- class Mlp(nn.Module):
16
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
17
- super().__init__()
18
- out_features = out_features or in_features
19
- hidden_features = hidden_features or in_features
20
- self.fc1 = nn.Linear(in_features, hidden_features)
21
- self.act = act_layer()
22
- self.fc2 = nn.Linear(hidden_features, out_features)
23
- self.drop = nn.Dropout(drop)
24
-
25
- def forward(self, x):
26
- x = self.fc1(x)
27
- x = self.act(x)
28
- x = self.drop(x)
29
- x = self.fc2(x)
30
- x = self.drop(x)
31
- return x
32
-
33
-
34
- def window_partition(x, window_size):
35
- """
36
- Args:
37
- x: (B, H, W, C)
38
- window_size (int): window size
39
- Returns:
40
- windows: (num_windows*B, window_size, window_size, C)
41
- """
42
- B, H, W, C = x.shape
43
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
44
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
45
- return windows
46
-
47
-
48
- def window_reverse(windows, window_size, H, W):
49
- """
50
- Args:
51
- windows: (num_windows*B, window_size, window_size, C)
52
- window_size (int): Window size
53
- H (int): Height of image
54
- W (int): Width of image
55
- Returns:
56
- x: (B, H, W, C)
57
- """
58
- B = int(windows.shape[0] / (H * W / window_size / window_size))
59
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
60
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
61
- return x
62
-
63
- class WindowAttention(nn.Module):
64
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
65
- It supports both shifted and non-shifted windows.
66
- Args:
67
- dim (int): Number of input channels.
68
- window_size (tuple[int]): The height and width of the window.
69
- num_heads (int): Number of attention heads.
70
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
71
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
72
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
73
- pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
74
- """
75
-
76
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
77
- pretrained_window_size=[0, 0]):
78
-
79
- super().__init__()
80
- self.dim = dim
81
- self.window_size = window_size # Wh, Ww
82
- self.pretrained_window_size = pretrained_window_size
83
- self.num_heads = num_heads
84
-
85
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
86
-
87
- # mlp to generate continuous relative position bias
88
- self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
89
- nn.ReLU(inplace=True),
90
- nn.Linear(512, num_heads, bias=False))
91
-
92
- # get relative_coords_table
93
- relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
94
- relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
95
- relative_coords_table = torch.stack(
96
- torch.meshgrid([relative_coords_h,
97
- relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
98
- if pretrained_window_size[0] > 0:
99
- relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
100
- relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
101
- else:
102
- relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
103
- relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
104
- relative_coords_table *= 8 # normalize to -8, 8
105
- relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
106
- torch.abs(relative_coords_table) + 1.0) / np.log2(8)
107
-
108
- self.register_buffer("relative_coords_table", relative_coords_table)
109
-
110
- # get pair-wise relative position index for each token inside the window
111
- coords_h = torch.arange(self.window_size[0])
112
- coords_w = torch.arange(self.window_size[1])
113
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
114
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
115
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
116
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
117
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
118
- relative_coords[:, :, 1] += self.window_size[1] - 1
119
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
120
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
121
- self.register_buffer("relative_position_index", relative_position_index)
122
-
123
- self.qkv = nn.Linear(dim, dim * 3, bias=False)
124
- if qkv_bias:
125
- self.q_bias = nn.Parameter(torch.zeros(dim))
126
- self.v_bias = nn.Parameter(torch.zeros(dim))
127
- else:
128
- self.q_bias = None
129
- self.v_bias = None
130
- self.attn_drop = nn.Dropout(attn_drop)
131
- self.proj = nn.Linear(dim, dim)
132
- self.proj_drop = nn.Dropout(proj_drop)
133
- self.softmax = nn.Softmax(dim=-1)
134
-
135
- def forward(self, x, mask=None):
136
- """
137
- Args:
138
- x: input features with shape of (num_windows*B, N, C)
139
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
140
- """
141
- B_, N, C = x.shape
142
- qkv_bias = None
143
- if self.q_bias is not None:
144
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
145
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
146
- qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
147
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
148
-
149
- # cosine attention
150
- attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
151
- logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
152
- attn = attn * logit_scale
153
-
154
- relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
155
- relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
156
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
157
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
158
- relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
159
- attn = attn + relative_position_bias.unsqueeze(0)
160
-
161
- if mask is not None:
162
- nW = mask.shape[0]
163
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
164
- attn = attn.view(-1, self.num_heads, N, N)
165
- attn = self.softmax(attn)
166
- else:
167
- attn = self.softmax(attn)
168
-
169
- attn = self.attn_drop(attn)
170
-
171
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
172
- x = self.proj(x)
173
- x = self.proj_drop(x)
174
- return x
175
-
176
- def extra_repr(self) -> str:
177
- return f'dim={self.dim}, window_size={self.window_size}, ' \
178
- f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
179
-
180
- def flops(self, N):
181
- # calculate flops for 1 window with token length of N
182
- flops = 0
183
- # qkv = self.qkv(x)
184
- flops += N * self.dim * 3 * self.dim
185
- # attn = (q @ k.transpose(-2, -1))
186
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
187
- # x = (attn @ v)
188
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
189
- # x = self.proj(x)
190
- flops += N * self.dim * self.dim
191
- return flops
192
-
193
- class SwinTransformerBlock(nn.Module):
194
- r""" Swin Transformer Block.
195
- Args:
196
- dim (int): Number of input channels.
197
- input_resolution (tuple[int]): Input resolution.
198
- num_heads (int): Number of attention heads.
199
- window_size (int): Window size.
200
- shift_size (int): Shift size for SW-MSA.
201
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
202
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
203
- drop (float, optional): Dropout rate. Default: 0.0
204
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
205
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
206
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
207
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
208
- pretrained_window_size (int): Window size in pre-training.
209
- """
210
-
211
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
212
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
213
- act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
214
- super().__init__()
215
- self.dim = dim
216
- self.input_resolution = input_resolution
217
- self.num_heads = num_heads
218
- self.window_size = window_size
219
- self.shift_size = shift_size
220
- self.mlp_ratio = mlp_ratio
221
- if min(self.input_resolution) <= self.window_size:
222
- # if window size is larger than input resolution, we don't partition windows
223
- self.shift_size = 0
224
- self.window_size = min(self.input_resolution)
225
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
226
-
227
- self.norm1 = norm_layer(dim)
228
- self.attn = WindowAttention(
229
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
230
- qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
231
- pretrained_window_size=to_2tuple(pretrained_window_size))
232
-
233
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
234
- self.norm2 = norm_layer(dim)
235
- mlp_hidden_dim = int(dim * mlp_ratio)
236
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
237
-
238
- if self.shift_size > 0:
239
- attn_mask = self.calculate_mask(self.input_resolution)
240
- else:
241
- attn_mask = None
242
-
243
- self.register_buffer("attn_mask", attn_mask)
244
-
245
- def calculate_mask(self, x_size):
246
- # calculate attention mask for SW-MSA
247
- H, W = x_size
248
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
249
- h_slices = (slice(0, -self.window_size),
250
- slice(-self.window_size, -self.shift_size),
251
- slice(-self.shift_size, None))
252
- w_slices = (slice(0, -self.window_size),
253
- slice(-self.window_size, -self.shift_size),
254
- slice(-self.shift_size, None))
255
- cnt = 0
256
- for h in h_slices:
257
- for w in w_slices:
258
- img_mask[:, h, w, :] = cnt
259
- cnt += 1
260
-
261
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
262
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
263
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
264
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
265
-
266
- return attn_mask
267
-
268
- def forward(self, x, x_size):
269
- H, W = x_size
270
- B, L, C = x.shape
271
- #assert L == H * W, "input feature has wrong size"
272
-
273
- shortcut = x
274
- x = x.view(B, H, W, C)
275
-
276
- # cyclic shift
277
- if self.shift_size > 0:
278
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
279
- else:
280
- shifted_x = x
281
-
282
- # partition windows
283
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
284
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
285
-
286
- # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
287
- if self.input_resolution == x_size:
288
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
289
- else:
290
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
291
-
292
- # merge windows
293
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
294
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
295
-
296
- # reverse cyclic shift
297
- if self.shift_size > 0:
298
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
299
- else:
300
- x = shifted_x
301
- x = x.view(B, H * W, C)
302
- x = shortcut + self.drop_path(self.norm1(x))
303
-
304
- # FFN
305
- x = x + self.drop_path(self.norm2(self.mlp(x)))
306
-
307
- return x
308
-
309
- def extra_repr(self) -> str:
310
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
311
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
312
-
313
- def flops(self):
314
- flops = 0
315
- H, W = self.input_resolution
316
- # norm1
317
- flops += self.dim * H * W
318
- # W-MSA/SW-MSA
319
- nW = H * W / self.window_size / self.window_size
320
- flops += nW * self.attn.flops(self.window_size * self.window_size)
321
- # mlp
322
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
323
- # norm2
324
- flops += self.dim * H * W
325
- return flops
326
-
327
- class PatchMerging(nn.Module):
328
- r""" Patch Merging Layer.
329
- Args:
330
- input_resolution (tuple[int]): Resolution of input feature.
331
- dim (int): Number of input channels.
332
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
333
- """
334
-
335
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
336
- super().__init__()
337
- self.input_resolution = input_resolution
338
- self.dim = dim
339
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
340
- self.norm = norm_layer(2 * dim)
341
-
342
- def forward(self, x):
343
- """
344
- x: B, H*W, C
345
- """
346
- H, W = self.input_resolution
347
- B, L, C = x.shape
348
- assert L == H * W, "input feature has wrong size"
349
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
350
-
351
- x = x.view(B, H, W, C)
352
-
353
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
354
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
355
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
356
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
357
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
358
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
359
-
360
- x = self.reduction(x)
361
- x = self.norm(x)
362
-
363
- return x
364
-
365
- def extra_repr(self) -> str:
366
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
367
-
368
- def flops(self):
369
- H, W = self.input_resolution
370
- flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
371
- flops += H * W * self.dim // 2
372
- return flops
373
-
374
- class BasicLayer(nn.Module):
375
- """ A basic Swin Transformer layer for one stage.
376
- Args:
377
- dim (int): Number of input channels.
378
- input_resolution (tuple[int]): Input resolution.
379
- depth (int): Number of blocks.
380
- num_heads (int): Number of attention heads.
381
- window_size (int): Local window size.
382
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
383
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
384
- drop (float, optional): Dropout rate. Default: 0.0
385
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
386
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
387
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
388
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
389
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
390
- pretrained_window_size (int): Local window size in pre-training.
391
- """
392
-
393
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
394
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
395
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
396
- pretrained_window_size=0):
397
-
398
- super().__init__()
399
- self.dim = dim
400
- self.input_resolution = input_resolution
401
- self.depth = depth
402
- self.use_checkpoint = use_checkpoint
403
-
404
- # build blocks
405
- self.blocks = nn.ModuleList([
406
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
407
- num_heads=num_heads, window_size=window_size,
408
- shift_size=0 if (i % 2 == 0) else window_size // 2,
409
- mlp_ratio=mlp_ratio,
410
- qkv_bias=qkv_bias,
411
- drop=drop, attn_drop=attn_drop,
412
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
413
- norm_layer=norm_layer,
414
- pretrained_window_size=pretrained_window_size)
415
- for i in range(depth)])
416
-
417
- # patch merging layer
418
- if downsample is not None:
419
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
420
- else:
421
- self.downsample = None
422
-
423
- def forward(self, x, x_size):
424
- for blk in self.blocks:
425
- if self.use_checkpoint:
426
- x = checkpoint.checkpoint(blk, x, x_size)
427
- else:
428
- x = blk(x, x_size)
429
- if self.downsample is not None:
430
- x = self.downsample(x)
431
- return x
432
-
433
- def extra_repr(self) -> str:
434
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
435
-
436
- def flops(self):
437
- flops = 0
438
- for blk in self.blocks:
439
- flops += blk.flops()
440
- if self.downsample is not None:
441
- flops += self.downsample.flops()
442
- return flops
443
-
444
- def _init_respostnorm(self):
445
- for blk in self.blocks:
446
- nn.init.constant_(blk.norm1.bias, 0)
447
- nn.init.constant_(blk.norm1.weight, 0)
448
- nn.init.constant_(blk.norm2.bias, 0)
449
- nn.init.constant_(blk.norm2.weight, 0)
450
-
451
- class PatchEmbed(nn.Module):
452
- r""" Image to Patch Embedding
453
- Args:
454
- img_size (int): Image size. Default: 224.
455
- patch_size (int): Patch token size. Default: 4.
456
- in_chans (int): Number of input image channels. Default: 3.
457
- embed_dim (int): Number of linear projection output channels. Default: 96.
458
- norm_layer (nn.Module, optional): Normalization layer. Default: None
459
- """
460
-
461
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
462
- super().__init__()
463
- img_size = to_2tuple(img_size)
464
- patch_size = to_2tuple(patch_size)
465
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
466
- self.img_size = img_size
467
- self.patch_size = patch_size
468
- self.patches_resolution = patches_resolution
469
- self.num_patches = patches_resolution[0] * patches_resolution[1]
470
-
471
- self.in_chans = in_chans
472
- self.embed_dim = embed_dim
473
-
474
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
475
- if norm_layer is not None:
476
- self.norm = norm_layer(embed_dim)
477
- else:
478
- self.norm = None
479
-
480
- def forward(self, x):
481
- B, C, H, W = x.shape
482
- # FIXME look at relaxing size constraints
483
- # assert H == self.img_size[0] and W == self.img_size[1],
484
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
485
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
486
- if self.norm is not None:
487
- x = self.norm(x)
488
- return x
489
-
490
- def flops(self):
491
- Ho, Wo = self.patches_resolution
492
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
493
- if self.norm is not None:
494
- flops += Ho * Wo * self.embed_dim
495
- return flops
496
-
497
- class RSTB(nn.Module):
498
- """Residual Swin Transformer Block (RSTB).
499
-
500
- Args:
501
- dim (int): Number of input channels.
502
- input_resolution (tuple[int]): Input resolution.
503
- depth (int): Number of blocks.
504
- num_heads (int): Number of attention heads.
505
- window_size (int): Local window size.
506
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
507
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
508
- drop (float, optional): Dropout rate. Default: 0.0
509
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
510
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
511
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
512
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
513
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
514
- img_size: Input image size.
515
- patch_size: Patch size.
516
- resi_connection: The convolutional block before residual connection.
517
- """
518
-
519
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
520
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
521
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
522
- img_size=224, patch_size=4, resi_connection='1conv'):
523
- super(RSTB, self).__init__()
524
-
525
- self.dim = dim
526
- self.input_resolution = input_resolution
527
-
528
- self.residual_group = BasicLayer(dim=dim,
529
- input_resolution=input_resolution,
530
- depth=depth,
531
- num_heads=num_heads,
532
- window_size=window_size,
533
- mlp_ratio=mlp_ratio,
534
- qkv_bias=qkv_bias,
535
- drop=drop, attn_drop=attn_drop,
536
- drop_path=drop_path,
537
- norm_layer=norm_layer,
538
- downsample=downsample,
539
- use_checkpoint=use_checkpoint)
540
-
541
- if resi_connection == '1conv':
542
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
543
- elif resi_connection == '3conv':
544
- # to save parameters and memory
545
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
546
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
547
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
548
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
549
-
550
- self.patch_embed = PatchEmbed(
551
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
552
- norm_layer=None)
553
-
554
- self.patch_unembed = PatchUnEmbed(
555
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
556
- norm_layer=None)
557
-
558
- def forward(self, x, x_size):
559
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
560
-
561
- def flops(self):
562
- flops = 0
563
- flops += self.residual_group.flops()
564
- H, W = self.input_resolution
565
- flops += H * W * self.dim * self.dim * 9
566
- flops += self.patch_embed.flops()
567
- flops += self.patch_unembed.flops()
568
-
569
- return flops
570
-
571
- class PatchUnEmbed(nn.Module):
572
- r""" Image to Patch Unembedding
573
-
574
- Args:
575
- img_size (int): Image size. Default: 224.
576
- patch_size (int): Patch token size. Default: 4.
577
- in_chans (int): Number of input image channels. Default: 3.
578
- embed_dim (int): Number of linear projection output channels. Default: 96.
579
- norm_layer (nn.Module, optional): Normalization layer. Default: None
580
- """
581
-
582
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
583
- super().__init__()
584
- img_size = to_2tuple(img_size)
585
- patch_size = to_2tuple(patch_size)
586
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
587
- self.img_size = img_size
588
- self.patch_size = patch_size
589
- self.patches_resolution = patches_resolution
590
- self.num_patches = patches_resolution[0] * patches_resolution[1]
591
-
592
- self.in_chans = in_chans
593
- self.embed_dim = embed_dim
594
-
595
- def forward(self, x, x_size):
596
- B, HW, C = x.shape
597
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
598
- return x
599
-
600
- def flops(self):
601
- flops = 0
602
- return flops
603
-
604
-
605
- class Upsample(nn.Sequential):
606
- """Upsample module.
607
-
608
- Args:
609
- scale (int): Scale factor. Supported scales: 2^n and 3.
610
- num_feat (int): Channel number of intermediate features.
611
- """
612
-
613
- def __init__(self, scale, num_feat):
614
- m = []
615
- if (scale & (scale - 1)) == 0: # scale = 2^n
616
- for _ in range(int(math.log(scale, 2))):
617
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
618
- m.append(nn.PixelShuffle(2))
619
- elif scale == 3:
620
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
621
- m.append(nn.PixelShuffle(3))
622
- else:
623
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
624
- super(Upsample, self).__init__(*m)
625
-
626
- class Upsample_hf(nn.Sequential):
627
- """Upsample module.
628
-
629
- Args:
630
- scale (int): Scale factor. Supported scales: 2^n and 3.
631
- num_feat (int): Channel number of intermediate features.
632
- """
633
-
634
- def __init__(self, scale, num_feat):
635
- m = []
636
- if (scale & (scale - 1)) == 0: # scale = 2^n
637
- for _ in range(int(math.log(scale, 2))):
638
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
639
- m.append(nn.PixelShuffle(2))
640
- elif scale == 3:
641
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
642
- m.append(nn.PixelShuffle(3))
643
- else:
644
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
645
- super(Upsample_hf, self).__init__(*m)
646
-
647
-
648
- class UpsampleOneStep(nn.Sequential):
649
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
650
- Used in lightweight SR to save parameters.
651
-
652
- Args:
653
- scale (int): Scale factor. Supported scales: 2^n and 3.
654
- num_feat (int): Channel number of intermediate features.
655
-
656
- """
657
-
658
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
659
- self.num_feat = num_feat
660
- self.input_resolution = input_resolution
661
- m = []
662
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
663
- m.append(nn.PixelShuffle(scale))
664
- super(UpsampleOneStep, self).__init__(*m)
665
-
666
- def flops(self):
667
- H, W = self.input_resolution
668
- flops = H * W * self.num_feat * 3 * 9
669
- return flops
670
-
671
-
672
-
673
- class Swin2SR(nn.Module):
674
- r""" Swin2SR
675
- A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
676
-
677
- Args:
678
- img_size (int | tuple(int)): Input image size. Default 64
679
- patch_size (int | tuple(int)): Patch size. Default: 1
680
- in_chans (int): Number of input image channels. Default: 3
681
- embed_dim (int): Patch embedding dimension. Default: 96
682
- depths (tuple(int)): Depth of each Swin Transformer layer.
683
- num_heads (tuple(int)): Number of attention heads in different layers.
684
- window_size (int): Window size. Default: 7
685
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
686
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
687
- drop_rate (float): Dropout rate. Default: 0
688
- attn_drop_rate (float): Attention dropout rate. Default: 0
689
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
690
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
691
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
692
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
693
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
694
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
695
- img_range: Image range. 1. or 255.
696
- upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
697
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
698
- """
699
-
700
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
701
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
702
- window_size=7, mlp_ratio=4., qkv_bias=True,
703
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
704
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
705
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
706
- **kwargs):
707
- super(Swin2SR, self).__init__()
708
- num_in_ch = in_chans
709
- num_out_ch = in_chans
710
- num_feat = 64
711
- self.img_range = img_range
712
- if in_chans == 3:
713
- rgb_mean = (0.4488, 0.4371, 0.4040)
714
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
715
- else:
716
- self.mean = torch.zeros(1, 1, 1, 1)
717
- self.upscale = upscale
718
- self.upsampler = upsampler
719
- self.window_size = window_size
720
-
721
- #####################################################################################################
722
- ################################### 1, shallow feature extraction ###################################
723
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
724
-
725
- #####################################################################################################
726
- ################################### 2, deep feature extraction ######################################
727
- self.num_layers = len(depths)
728
- self.embed_dim = embed_dim
729
- self.ape = ape
730
- self.patch_norm = patch_norm
731
- self.num_features = embed_dim
732
- self.mlp_ratio = mlp_ratio
733
-
734
- # split image into non-overlapping patches
735
- self.patch_embed = PatchEmbed(
736
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
737
- norm_layer=norm_layer if self.patch_norm else None)
738
- num_patches = self.patch_embed.num_patches
739
- patches_resolution = self.patch_embed.patches_resolution
740
- self.patches_resolution = patches_resolution
741
-
742
- # merge non-overlapping patches into image
743
- self.patch_unembed = PatchUnEmbed(
744
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
745
- norm_layer=norm_layer if self.patch_norm else None)
746
-
747
- # absolute position embedding
748
- if self.ape:
749
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
750
- trunc_normal_(self.absolute_pos_embed, std=.02)
751
-
752
- self.pos_drop = nn.Dropout(p=drop_rate)
753
-
754
- # stochastic depth
755
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
756
-
757
- # build Residual Swin Transformer blocks (RSTB)
758
- self.layers = nn.ModuleList()
759
- for i_layer in range(self.num_layers):
760
- layer = RSTB(dim=embed_dim,
761
- input_resolution=(patches_resolution[0],
762
- patches_resolution[1]),
763
- depth=depths[i_layer],
764
- num_heads=num_heads[i_layer],
765
- window_size=window_size,
766
- mlp_ratio=self.mlp_ratio,
767
- qkv_bias=qkv_bias,
768
- drop=drop_rate, attn_drop=attn_drop_rate,
769
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
770
- norm_layer=norm_layer,
771
- downsample=None,
772
- use_checkpoint=use_checkpoint,
773
- img_size=img_size,
774
- patch_size=patch_size,
775
- resi_connection=resi_connection
776
-
777
- )
778
- self.layers.append(layer)
779
-
780
- if self.upsampler == 'pixelshuffle_hf':
781
- self.layers_hf = nn.ModuleList()
782
- for i_layer in range(self.num_layers):
783
- layer = RSTB(dim=embed_dim,
784
- input_resolution=(patches_resolution[0],
785
- patches_resolution[1]),
786
- depth=depths[i_layer],
787
- num_heads=num_heads[i_layer],
788
- window_size=window_size,
789
- mlp_ratio=self.mlp_ratio,
790
- qkv_bias=qkv_bias,
791
- drop=drop_rate, attn_drop=attn_drop_rate,
792
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
793
- norm_layer=norm_layer,
794
- downsample=None,
795
- use_checkpoint=use_checkpoint,
796
- img_size=img_size,
797
- patch_size=patch_size,
798
- resi_connection=resi_connection
799
-
800
- )
801
- self.layers_hf.append(layer)
802
-
803
- self.norm = norm_layer(self.num_features)
804
-
805
- # build the last conv layer in deep feature extraction
806
- if resi_connection == '1conv':
807
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
808
- elif resi_connection == '3conv':
809
- # to save parameters and memory
810
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
811
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
812
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
813
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
814
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
815
-
816
- #####################################################################################################
817
- ################################ 3, high quality image reconstruction ################################
818
- if self.upsampler == 'pixelshuffle':
819
- # for classical SR
820
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
821
- nn.LeakyReLU(inplace=True))
822
- self.upsample = Upsample(upscale, num_feat)
823
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
824
- elif self.upsampler == 'pixelshuffle_aux':
825
- self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
826
- self.conv_before_upsample = nn.Sequential(
827
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
828
- nn.LeakyReLU(inplace=True))
829
- self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
830
- self.conv_after_aux = nn.Sequential(
831
- nn.Conv2d(3, num_feat, 3, 1, 1),
832
- nn.LeakyReLU(inplace=True))
833
- self.upsample = Upsample(upscale, num_feat)
834
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
835
-
836
- elif self.upsampler == 'pixelshuffle_hf':
837
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
838
- nn.LeakyReLU(inplace=True))
839
- self.upsample = Upsample(upscale, num_feat)
840
- self.upsample_hf = Upsample_hf(upscale, num_feat)
841
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
842
- self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
843
- nn.LeakyReLU(inplace=True))
844
- self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
845
- self.conv_before_upsample_hf = nn.Sequential(
846
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
847
- nn.LeakyReLU(inplace=True))
848
- self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
849
-
850
- elif self.upsampler == 'pixelshuffledirect':
851
- # for lightweight SR (to save parameters)
852
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
853
- (patches_resolution[0], patches_resolution[1]))
854
- elif self.upsampler == 'nearest+conv':
855
- # for real-world SR (less artifacts)
856
- assert self.upscale == 4, 'only support x4 now.'
857
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
858
- nn.LeakyReLU(inplace=True))
859
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
860
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
861
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
862
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
863
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
864
- else:
865
- # for image denoising and JPEG compression artifact reduction
866
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
867
-
868
- self.apply(self._init_weights)
869
-
870
- def _init_weights(self, m):
871
- if isinstance(m, nn.Linear):
872
- trunc_normal_(m.weight, std=.02)
873
- if isinstance(m, nn.Linear) and m.bias is not None:
874
- nn.init.constant_(m.bias, 0)
875
- elif isinstance(m, nn.LayerNorm):
876
- nn.init.constant_(m.bias, 0)
877
- nn.init.constant_(m.weight, 1.0)
878
-
879
- @torch.jit.ignore
880
- def no_weight_decay(self):
881
- return {'absolute_pos_embed'}
882
-
883
- @torch.jit.ignore
884
- def no_weight_decay_keywords(self):
885
- return {'relative_position_bias_table'}
886
-
887
- def check_image_size(self, x):
888
- _, _, h, w = x.size()
889
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
890
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
891
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
892
- return x
893
-
894
- def forward_features(self, x):
895
- x_size = (x.shape[2], x.shape[3])
896
- x = self.patch_embed(x)
897
- if self.ape:
898
- x = x + self.absolute_pos_embed
899
- x = self.pos_drop(x)
900
-
901
- for layer in self.layers:
902
- x = layer(x, x_size)
903
-
904
- x = self.norm(x) # B L C
905
- x = self.patch_unembed(x, x_size)
906
-
907
- return x
908
-
909
- def forward_features_hf(self, x):
910
- x_size = (x.shape[2], x.shape[3])
911
- x = self.patch_embed(x)
912
- if self.ape:
913
- x = x + self.absolute_pos_embed
914
- x = self.pos_drop(x)
915
-
916
- for layer in self.layers_hf:
917
- x = layer(x, x_size)
918
-
919
- x = self.norm(x) # B L C
920
- x = self.patch_unembed(x, x_size)
921
-
922
- return x
923
-
924
- def forward(self, x):
925
- H, W = x.shape[2:]
926
- x = self.check_image_size(x)
927
-
928
- self.mean = self.mean.type_as(x)
929
- x = (x - self.mean) * self.img_range
930
-
931
- if self.upsampler == 'pixelshuffle':
932
- # for classical SR
933
- x = self.conv_first(x)
934
- x = self.conv_after_body(self.forward_features(x)) + x
935
- x = self.conv_before_upsample(x)
936
- x = self.conv_last(self.upsample(x))
937
- elif self.upsampler == 'pixelshuffle_aux':
938
- bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
939
- bicubic = self.conv_bicubic(bicubic)
940
- x = self.conv_first(x)
941
- x = self.conv_after_body(self.forward_features(x)) + x
942
- x = self.conv_before_upsample(x)
943
- aux = self.conv_aux(x) # b, 3, LR_H, LR_W
944
- x = self.conv_after_aux(aux)
945
- x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
946
- x = self.conv_last(x)
947
- aux = aux / self.img_range + self.mean
948
- elif self.upsampler == 'pixelshuffle_hf':
949
- # for classical SR with HF
950
- x = self.conv_first(x)
951
- x = self.conv_after_body(self.forward_features(x)) + x
952
- x_before = self.conv_before_upsample(x)
953
- x_out = self.conv_last(self.upsample(x_before))
954
-
955
- x_hf = self.conv_first_hf(x_before)
956
- x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
957
- x_hf = self.conv_before_upsample_hf(x_hf)
958
- x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
959
- x = x_out + x_hf
960
- x_hf = x_hf / self.img_range + self.mean
961
-
962
- elif self.upsampler == 'pixelshuffledirect':
963
- # for lightweight SR
964
- x = self.conv_first(x)
965
- x = self.conv_after_body(self.forward_features(x)) + x
966
- x = self.upsample(x)
967
- elif self.upsampler == 'nearest+conv':
968
- # for real-world SR
969
- x = self.conv_first(x)
970
- x = self.conv_after_body(self.forward_features(x)) + x
971
- x = self.conv_before_upsample(x)
972
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
973
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
974
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
975
- else:
976
- # for image denoising and JPEG compression artifact reduction
977
- x_first = self.conv_first(x)
978
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
979
- x = x + self.conv_last(res)
980
-
981
- x = x / self.img_range + self.mean
982
- if self.upsampler == "pixelshuffle_aux":
983
- return x[:, :, :H*self.upscale, :W*self.upscale], aux
984
-
985
- elif self.upsampler == "pixelshuffle_hf":
986
- x_out = x_out / self.img_range + self.mean
987
- return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
988
-
989
- else:
990
- return x[:, :, :H*self.upscale, :W*self.upscale]
991
-
992
- def flops(self):
993
- flops = 0
994
- H, W = self.patches_resolution
995
- flops += H * W * 3 * self.embed_dim * 9
996
- flops += self.patch_embed.flops()
997
- for i, layer in enumerate(self.layers):
998
- flops += layer.flops()
999
- flops += H * W * 3 * self.embed_dim * self.embed_dim
1000
- flops += self.upsample.flops()
1001
- return flops
1002
-
1003
-
1004
- if __name__ == '__main__':
1005
- upscale = 4
1006
- window_size = 8
1007
- height = (1024 // upscale // window_size + 1) * window_size
1008
- width = (720 // upscale // window_size + 1) * window_size
1009
- model = Swin2SR(upscale=2, img_size=(height, width),
1010
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
1011
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
1012
- print(model)
1013
- print(height, width, model.flops() / 1e9)
1014
-
1015
- x = torch.randn((1, 3, height, width))
1016
- x = model(x)
1017
  print(x.shape)
 
412
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
413
+ norm_layer=norm_layer,
414
+ pretrained_window_size=pretrained_window_size)
415
+ for i in range(depth)])
416
+
417
+ # patch merging layer
418
+ if downsample is not None:
419
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
420
+ else:
421
+ self.downsample = None
422
+
423
+ def forward(self, x, x_size):
424
+ for blk in self.blocks:
425
+ if self.use_checkpoint:
426
+ x = checkpoint.checkpoint(blk, x, x_size)
427
+ else:
428
+ x = blk(x, x_size)
429
+ if self.downsample is not None:
430
+ x = self.downsample(x)
431
+ return x
432
+
433
+ def extra_repr(self) -> str:
434
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
435
+
436
+ def flops(self):
437
+ flops = 0
438
+ for blk in self.blocks:
439
+ flops += blk.flops()
440
+ if self.downsample is not None:
441
+ flops += self.downsample.flops()
442
+ return flops
443
+
444
+ def _init_respostnorm(self):
445
+ for blk in self.blocks:
446
+ nn.init.constant_(blk.norm1.bias, 0)
447
+ nn.init.constant_(blk.norm1.weight, 0)
448
+ nn.init.constant_(blk.norm2.bias, 0)
449
+ nn.init.constant_(blk.norm2.weight, 0)
450
+
451
+ class PatchEmbed(nn.Module):
452
+ r""" Image to Patch Embedding
453
+ Args:
454
+ img_size (int): Image size. Default: 224.
455
+ patch_size (int): Patch token size. Default: 4.
456
+ in_chans (int): Number of input image channels. Default: 3.
457
+ embed_dim (int): Number of linear projection output channels. Default: 96.
458
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
459
+ """
460
+
461
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
462
+ super().__init__()
463
+ img_size = to_2tuple(img_size)
464
+ patch_size = to_2tuple(patch_size)
465
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
466
+ self.img_size = img_size
467
+ self.patch_size = patch_size
468
+ self.patches_resolution = patches_resolution
469
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
470
+
471
+ self.in_chans = in_chans
472
+ self.embed_dim = embed_dim
473
+
474
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
475
+ if norm_layer is not None:
476
+ self.norm = norm_layer(embed_dim)
477
+ else:
478
+ self.norm = None
479
+
480
+ def forward(self, x):
481
+ B, C, H, W = x.shape
482
+ # FIXME look at relaxing size constraints
483
+ # assert H == self.img_size[0] and W == self.img_size[1],
484
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
485
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
486
+ if self.norm is not None:
487
+ x = self.norm(x)
488
+ return x
489
+
490
+ def flops(self):
491
+ Ho, Wo = self.patches_resolution
492
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
493
+ if self.norm is not None:
494
+ flops += Ho * Wo * self.embed_dim
495
+ return flops
496
+
497
+ class RSTB(nn.Module):
498
+ """Residual Swin Transformer Block (RSTB).
499
+
500
+ Args:
501
+ dim (int): Number of input channels.
502
+ input_resolution (tuple[int]): Input resolution.
503
+ depth (int): Number of blocks.
504
+ num_heads (int): Number of attention heads.
505
+ window_size (int): Local window size.
506
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
507
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
508
+ drop (float, optional): Dropout rate. Default: 0.0
509
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
510
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
511
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
512
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
513
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
514
+ img_size: Input image size.
515
+ patch_size: Patch size.
516
+ resi_connection: The convolutional block before residual connection.
517
+ """
518
+
519
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
520
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
521
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
522
+ img_size=224, patch_size=4, resi_connection='1conv'):
523
+ super(RSTB, self).__init__()
524
+
525
+ self.dim = dim
526
+ self.input_resolution = input_resolution
527
+
528
+ self.residual_group = BasicLayer(dim=dim,
529
+ input_resolution=input_resolution,
530
+ depth=depth,
531
+ num_heads=num_heads,
532
+ window_size=window_size,
533
+ mlp_ratio=mlp_ratio,
534
+ qkv_bias=qkv_bias,
535
+ drop=drop, attn_drop=attn_drop,
536
+ drop_path=drop_path,
537
+ norm_layer=norm_layer,
538
+ downsample=downsample,
539
+ use_checkpoint=use_checkpoint)
540
+
541
+ if resi_connection == '1conv':
542
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
543
+ elif resi_connection == '3conv':
544
+ # to save parameters and memory
545
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
546
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
547
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
548
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
549
+
550
+ self.patch_embed = PatchEmbed(
551
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
552
+ norm_layer=None)
553
+
554
+ self.patch_unembed = PatchUnEmbed(
555
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
556
+ norm_layer=None)
557
+
558
+ def forward(self, x, x_size):
559
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
560
+
561
+ def flops(self):
562
+ flops = 0
563
+ flops += self.residual_group.flops()
564
+ H, W = self.input_resolution
565
+ flops += H * W * self.dim * self.dim * 9
566
+ flops += self.patch_embed.flops()
567
+ flops += self.patch_unembed.flops()
568
+
569
+ return flops
570
+
571
+ class PatchUnEmbed(nn.Module):
572
+ r""" Image to Patch Unembedding
573
+
574
+ Args:
575
+ img_size (int): Image size. Default: 224.
576
+ patch_size (int): Patch token size. Default: 4.
577
+ in_chans (int): Number of input image channels. Default: 3.
578
+ embed_dim (int): Number of linear projection output channels. Default: 96.
579
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
580
+ """
581
+
582
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
583
+ super().__init__()
584
+ img_size = to_2tuple(img_size)
585
+ patch_size = to_2tuple(patch_size)
586
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
587
+ self.img_size = img_size
588
+ self.patch_size = patch_size
589
+ self.patches_resolution = patches_resolution
590
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
591
+
592
+ self.in_chans = in_chans
593
+ self.embed_dim = embed_dim
594
+
595
+ def forward(self, x, x_size):
596
+ B, HW, C = x.shape
597
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
598
+ return x
599
+
600
+ def flops(self):
601
+ flops = 0
602
+ return flops
603
+
604
+
605
+ class Upsample(nn.Sequential):
606
+ """Upsample module.
607
+
608
+ Args:
609
+ scale (int): Scale factor. Supported scales: 2^n and 3.
610
+ num_feat (int): Channel number of intermediate features.
611
+ """
612
+
613
+ def __init__(self, scale, num_feat):
614
+ m = []
615
+ if (scale & (scale - 1)) == 0: # scale = 2^n
616
+ for _ in range(int(math.log(scale, 2))):
617
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
618
+ m.append(nn.PixelShuffle(2))
619
+ elif scale == 3:
620
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
621
+ m.append(nn.PixelShuffle(3))
622
+ else:
623
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
624
+ super(Upsample, self).__init__(*m)
625
+
626
+ class Upsample_hf(nn.Sequential):
627
+ """Upsample module.
628
+
629
+ Args:
630
+ scale (int): Scale factor. Supported scales: 2^n and 3.
631
+ num_feat (int): Channel number of intermediate features.
632
+ """
633
+
634
+ def __init__(self, scale, num_feat):
635
+ m = []
636
+ if (scale & (scale - 1)) == 0: # scale = 2^n
637
+ for _ in range(int(math.log(scale, 2))):
638
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
639
+ m.append(nn.PixelShuffle(2))
640
+ elif scale == 3:
641
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
642
+ m.append(nn.PixelShuffle(3))
643
+ else:
644
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
645
+ super(Upsample_hf, self).__init__(*m)
646
+
647
+
648
+ class UpsampleOneStep(nn.Sequential):
649
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
650
+ Used in lightweight SR to save parameters.
651
+
652
+ Args:
653
+ scale (int): Scale factor. Supported scales: 2^n and 3.
654
+ num_feat (int): Channel number of intermediate features.
655
+
656
+ """
657
+
658
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
659
+ self.num_feat = num_feat
660
+ self.input_resolution = input_resolution
661
+ m = []
662
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
663
+ m.append(nn.PixelShuffle(scale))
664
+ super(UpsampleOneStep, self).__init__(*m)
665
+
666
+ def flops(self):
667
+ H, W = self.input_resolution
668
+ flops = H * W * self.num_feat * 3 * 9
669
+ return flops
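To make the "save parameters" note above concrete, here is a rough, illustrative comparison of the two upsampling modules for scale 4, assuming num_feat=64 and 3 output channels, and counting only the convolutions constructed above (the surrounding conv layers in the full network are not included).

num_feat, num_out_ch, scale = 64, 3, 4

# Upsample: two Conv2d(num_feat, 4*num_feat, 3) + PixelShuffle(2) stages for scale 4
upsample_params = 2 * (num_feat * 4 * num_feat * 3 * 3 + 4 * num_feat)                     # 295,424

# UpsampleOneStep: a single Conv2d(num_feat, scale**2 * num_out_ch, 3) + PixelShuffle(scale)
onestep_params = num_feat * (scale ** 2) * num_out_ch * 3 * 3 + (scale ** 2) * num_out_ch  # 27,696

print(upsample_params, onestep_params)  # roughly an order of magnitude fewer parameters in the one-step module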
670
+
671
+
672
+
673
+ class Swin2SR(nn.Module):
674
+ r""" Swin2SR
675
+ A PyTorch impl of: `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
676
+
677
+ Args:
678
+ img_size (int | tuple(int)): Input image size. Default 64
679
+ patch_size (int | tuple(int)): Patch size. Default: 1
680
+ in_chans (int): Number of input image channels. Default: 3
681
+ embed_dim (int): Patch embedding dimension. Default: 96
682
+ depths (tuple(int)): Depth of each Swin Transformer layer.
683
+ num_heads (tuple(int)): Number of attention heads in different layers.
684
+ window_size (int): Window size. Default: 7
685
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
686
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
687
+ drop_rate (float): Dropout rate. Default: 0
688
+ attn_drop_rate (float): Attention dropout rate. Default: 0
689
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
690
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
691
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
692
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
693
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
694
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
695
+ img_range: Image range. 1. or 255.
696
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
697
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
698
+ """
699
+
700
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
701
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
702
+ window_size=7, mlp_ratio=4., qkv_bias=True,
703
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
704
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
705
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
706
+ **kwargs):
707
+ super(Swin2SR, self).__init__()
708
+ num_in_ch = in_chans
709
+ num_out_ch = in_chans
710
+ num_feat = 64
711
+ self.img_range = img_range
712
+ if in_chans == 3:
713
+ rgb_mean = (0.4488, 0.4371, 0.4040)
714
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
715
+ else:
716
+ self.mean = torch.zeros(1, 1, 1, 1)
717
+ self.upscale = upscale
718
+ self.upsampler = upsampler
719
+ self.window_size = window_size
720
+
721
+ #####################################################################################################
722
+ ################################### 1, shallow feature extraction ###################################
723
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
724
+
725
+ #####################################################################################################
726
+ ################################### 2, deep feature extraction ######################################
727
+ self.num_layers = len(depths)
728
+ self.embed_dim = embed_dim
729
+ self.ape = ape
730
+ self.patch_norm = patch_norm
731
+ self.num_features = embed_dim
732
+ self.mlp_ratio = mlp_ratio
733
+
734
+ # split image into non-overlapping patches
735
+ self.patch_embed = PatchEmbed(
736
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
737
+ norm_layer=norm_layer if self.patch_norm else None)
738
+ num_patches = self.patch_embed.num_patches
739
+ patches_resolution = self.patch_embed.patches_resolution
740
+ self.patches_resolution = patches_resolution
741
+
742
+ # merge non-overlapping patches into image
743
+ self.patch_unembed = PatchUnEmbed(
744
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
745
+ norm_layer=norm_layer if self.patch_norm else None)
746
+
747
+ # absolute position embedding
748
+ if self.ape:
749
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
750
+ trunc_normal_(self.absolute_pos_embed, std=.02)
751
+
752
+ self.pos_drop = nn.Dropout(p=drop_rate)
753
+
754
+ # stochastic depth
755
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
756
+
757
+ # build Residual Swin Transformer blocks (RSTB)
758
+ self.layers = nn.ModuleList()
759
+ for i_layer in range(self.num_layers):
760
+ layer = RSTB(dim=embed_dim,
761
+ input_resolution=(patches_resolution[0],
762
+ patches_resolution[1]),
763
+ depth=depths[i_layer],
764
+ num_heads=num_heads[i_layer],
765
+ window_size=window_size,
766
+ mlp_ratio=self.mlp_ratio,
767
+ qkv_bias=qkv_bias,
768
+ drop=drop_rate, attn_drop=attn_drop_rate,
769
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
770
+ norm_layer=norm_layer,
771
+ downsample=None,
772
+ use_checkpoint=use_checkpoint,
773
+ img_size=img_size,
774
+ patch_size=patch_size,
775
+ resi_connection=resi_connection
776
+
777
+ )
778
+ self.layers.append(layer)
779
+
780
+ if self.upsampler == 'pixelshuffle_hf':
781
+ self.layers_hf = nn.ModuleList()
782
+ for i_layer in range(self.num_layers):
783
+ layer = RSTB(dim=embed_dim,
784
+ input_resolution=(patches_resolution[0],
785
+ patches_resolution[1]),
786
+ depth=depths[i_layer],
787
+ num_heads=num_heads[i_layer],
788
+ window_size=window_size,
789
+ mlp_ratio=self.mlp_ratio,
790
+ qkv_bias=qkv_bias,
791
+ drop=drop_rate, attn_drop=attn_drop_rate,
792
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
793
+ norm_layer=norm_layer,
794
+ downsample=None,
795
+ use_checkpoint=use_checkpoint,
796
+ img_size=img_size,
797
+ patch_size=patch_size,
798
+ resi_connection=resi_connection
799
+
800
+ )
801
+ self.layers_hf.append(layer)
802
+
803
+ self.norm = norm_layer(self.num_features)
804
+
805
+ # build the last conv layer in deep feature extraction
806
+ if resi_connection == '1conv':
807
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
808
+ elif resi_connection == '3conv':
809
+ # to save parameters and memory
810
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
811
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
812
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
813
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
814
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
815
+
816
+ #####################################################################################################
817
+ ################################ 3, high quality image reconstruction ################################
818
+ if self.upsampler == 'pixelshuffle':
819
+ # for classical SR
820
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
821
+ nn.LeakyReLU(inplace=True))
822
+ self.upsample = Upsample(upscale, num_feat)
823
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
824
+ elif self.upsampler == 'pixelshuffle_aux':
825
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
826
+ self.conv_before_upsample = nn.Sequential(
827
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
828
+ nn.LeakyReLU(inplace=True))
829
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
830
+ self.conv_after_aux = nn.Sequential(
831
+ nn.Conv2d(3, num_feat, 3, 1, 1),
832
+ nn.LeakyReLU(inplace=True))
833
+ self.upsample = Upsample(upscale, num_feat)
834
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
835
+
836
+ elif self.upsampler == 'pixelshuffle_hf':
837
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
838
+ nn.LeakyReLU(inplace=True))
839
+ self.upsample = Upsample(upscale, num_feat)
840
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
841
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
842
+ self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
843
+ nn.LeakyReLU(inplace=True))
844
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
845
+ self.conv_before_upsample_hf = nn.Sequential(
846
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
847
+ nn.LeakyReLU(inplace=True))
848
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
849
+
850
+ elif self.upsampler == 'pixelshuffledirect':
851
+ # for lightweight SR (to save parameters)
852
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
853
+ (patches_resolution[0], patches_resolution[1]))
854
+ elif self.upsampler == 'nearest+conv':
855
+ # for real-world SR (less artifacts)
856
+ assert self.upscale == 4, 'only support x4 now.'
857
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
858
+ nn.LeakyReLU(inplace=True))
859
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
860
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
861
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
862
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
863
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
864
+ else:
865
+ # for image denoising and JPEG compression artifact reduction
866
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
867
+
868
+ self.apply(self._init_weights)
869
+
870
+ def _init_weights(self, m):
871
+ if isinstance(m, nn.Linear):
872
+ trunc_normal_(m.weight, std=.02)
873
+ if isinstance(m, nn.Linear) and m.bias is not None:
874
+ nn.init.constant_(m.bias, 0)
875
+ elif isinstance(m, nn.LayerNorm):
876
+ nn.init.constant_(m.bias, 0)
877
+ nn.init.constant_(m.weight, 1.0)
878
+
879
+ @torch.jit.ignore
880
+ def no_weight_decay(self):
881
+ return {'absolute_pos_embed'}
882
+
883
+ @torch.jit.ignore
884
+ def no_weight_decay_keywords(self):
885
+ return {'relative_position_bias_table'}
886
+
887
+ def check_image_size(self, x):
888
+ _, _, h, w = x.size()
889
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
890
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
891
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
892
+ return x
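A quick worked example of the padding arithmetic above: for a 48x50 input with window_size=8, only the width needs reflection padding, up to the next multiple of the window size (the sizes here are illustrative).

window_size, h, w = 8, 48, 50
mod_pad_h = (window_size - h % window_size) % window_size   # 0
mod_pad_w = (window_size - w % window_size) % window_size   # 6
print(mod_pad_h, mod_pad_w)  # 0 6 -> the image is padded to 48 x 56 before windowed attention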
893
+
894
+ def forward_features(self, x):
895
+ x_size = (x.shape[2], x.shape[3])
896
+ x = self.patch_embed(x)
897
+ if self.ape:
898
+ x = x + self.absolute_pos_embed
899
+ x = self.pos_drop(x)
900
+
901
+ for layer in self.layers:
902
+ x = layer(x, x_size)
903
+
904
+ x = self.norm(x) # B L C
905
+ x = self.patch_unembed(x, x_size)
906
+
907
+ return x
908
+
909
+ def forward_features_hf(self, x):
910
+ x_size = (x.shape[2], x.shape[3])
911
+ x = self.patch_embed(x)
912
+ if self.ape:
913
+ x = x + self.absolute_pos_embed
914
+ x = self.pos_drop(x)
915
+
916
+ for layer in self.layers_hf:
917
+ x = layer(x, x_size)
918
+
919
+ x = self.norm(x) # B L C
920
+ x = self.patch_unembed(x, x_size)
921
+
922
+ return x
923
+
924
+ def forward(self, x):
925
+ H, W = x.shape[2:]
926
+ x = self.check_image_size(x)
927
+
928
+ self.mean = self.mean.type_as(x)
929
+ x = (x - self.mean) * self.img_range
930
+
931
+ if self.upsampler == 'pixelshuffle':
932
+ # for classical SR
933
+ x = self.conv_first(x)
934
+ x = self.conv_after_body(self.forward_features(x)) + x
935
+ x = self.conv_before_upsample(x)
936
+ x = self.conv_last(self.upsample(x))
937
+ elif self.upsampler == 'pixelshuffle_aux':
938
+ bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
939
+ bicubic = self.conv_bicubic(bicubic)
940
+ x = self.conv_first(x)
941
+ x = self.conv_after_body(self.forward_features(x)) + x
942
+ x = self.conv_before_upsample(x)
943
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
944
+ x = self.conv_after_aux(aux)
945
+ x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
946
+ x = self.conv_last(x)
947
+ aux = aux / self.img_range + self.mean
948
+ elif self.upsampler == 'pixelshuffle_hf':
949
+ # for classical SR with HF
950
+ x = self.conv_first(x)
951
+ x = self.conv_after_body(self.forward_features(x)) + x
952
+ x_before = self.conv_before_upsample(x)
953
+ x_out = self.conv_last(self.upsample(x_before))
954
+
955
+ x_hf = self.conv_first_hf(x_before)
956
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
957
+ x_hf = self.conv_before_upsample_hf(x_hf)
958
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
959
+ x = x_out + x_hf
960
+ x_hf = x_hf / self.img_range + self.mean
961
+
962
+ elif self.upsampler == 'pixelshuffledirect':
963
+ # for lightweight SR
964
+ x = self.conv_first(x)
965
+ x = self.conv_after_body(self.forward_features(x)) + x
966
+ x = self.upsample(x)
967
+ elif self.upsampler == 'nearest+conv':
968
+ # for real-world SR
969
+ x = self.conv_first(x)
970
+ x = self.conv_after_body(self.forward_features(x)) + x
971
+ x = self.conv_before_upsample(x)
972
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
973
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
974
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
975
+ else:
976
+ # for image denoising and JPEG compression artifact reduction
977
+ x_first = self.conv_first(x)
978
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
979
+ x = x + self.conv_last(res)
980
+
981
+ x = x / self.img_range + self.mean
982
+ if self.upsampler == "pixelshuffle_aux":
983
+ return x[:, :, :H*self.upscale, :W*self.upscale], aux
984
+
985
+ elif self.upsampler == "pixelshuffle_hf":
986
+ x_out = x_out / self.img_range + self.mean
987
+ return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
988
+
989
+ else:
990
+ return x[:, :, :H*self.upscale, :W*self.upscale]
991
+
992
+ def flops(self):
993
+ flops = 0
994
+ H, W = self.patches_resolution
995
+ flops += H * W * 3 * self.embed_dim * 9
996
+ flops += self.patch_embed.flops()
997
+ for i, layer in enumerate(self.layers):
998
+ flops += layer.flops()
999
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
1000
+ flops += self.upsample.flops()
1001
+ return flops
1002
+
1003
+
1004
+ if __name__ == '__main__':
1005
+ upscale = 4
1006
+ window_size = 8
1007
+ height = (1024 // upscale // window_size + 1) * window_size
1008
+ width = (720 // upscale // window_size + 1) * window_size
1009
+ model = Swin2SR(upscale=2, img_size=(height, width),
1010
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
1011
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
1012
+ print(model)
1013
+ print(height, width, model.flops() / 1e9)
1014
+
1015
+ x = torch.randn((1, 3, height, width))
1016
+ x = model(x)
1017
  print(x.shape)
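For reference, a minimal usage sketch of the classical-SR path (upsampler='pixelshuffle') with the Swin2SR class defined above; the tiny depths and embed_dim values are placeholders chosen only to keep the example fast, and in practice pretrained weights would be loaded with load_state_dict.

import torch

model = Swin2SR(upscale=4, in_chans=3, img_size=64, window_size=8,
                img_range=1., depths=[2, 2], embed_dim=60, num_heads=[6, 6],
                mlp_ratio=2, upsampler='pixelshuffle')
model.eval()

lr = torch.rand(1, 3, 48, 48)   # dummy low-resolution input
with torch.no_grad():
    sr = model(lr)              # check_image_size pads to a multiple of window_size if needed
print(sr.shape)                 # torch.Size([1, 3, 192, 192]): 4x the input resolution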
sd/stable-diffusion-webui/html/extra-networks-card.html CHANGED
@@ -7,6 +7,7 @@
7
  <span style="display:none" class='search_term'>{search_term}</span>
8
  </div>
9
  <span class='name'>{name}</span>
 
10
  </div>
11
  </div>
12
 
 
7
  <span style="display:none" class='search_term'>{search_term}</span>
8
  </div>
9
  <span class='name'>{name}</span>
10
+ <span class='description'>{description}</span>
11
  </div>
12
  </div>
13
 
sd/stable-diffusion-webui/html/footer.html CHANGED
@@ -1,13 +1,13 @@
1
- <div>
2
- <a href="/docs">API</a>
3
-  • 
4
- <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
5
-  • 
6
- <a href="https://gradio.app">Gradio</a>
7
-  • 
8
- <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
9
- </div>
10
- <br />
11
- <div class="versions">
12
- {versions}
13
- </div>
 
1
+ <div>
2
+ <a href="/docs">API</a>
3
+  • 
4
+ <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
5
+  • 
6
+ <a href="https://gradio.app">Gradio</a>
7
+  • 
8
+ <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
9
+ </div>
10
+ <br />
11
+ <div class="versions">
12
+ {versions}
13
+ </div>
sd/stable-diffusion-webui/html/licenses.html CHANGED
@@ -1,419 +1,638 @@
1
- <style>
2
- #licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
3
- #licenses small {font-size: 0.95em; opacity: 0.85;}
4
- #licenses pre { margin: 1em 0 2em 0;}
5
- </style>
6
-
7
- <h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
8
- <small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
9
- <pre>
10
- S-Lab License 1.0
11
-
12
- Copyright 2022 S-Lab
13
-
14
- Redistribution and use for non-commercial purpose in source and
15
- binary forms, with or without modification, are permitted provided
16
- that the following conditions are met:
17
-
18
- 1. Redistributions of source code must retain the above copyright
19
- notice, this list of conditions and the following disclaimer.
20
-
21
- 2. Redistributions in binary form must reproduce the above copyright
22
- notice, this list of conditions and the following disclaimer in
23
- the documentation and/or other materials provided with the
24
- distribution.
25
-
26
- 3. Neither the name of the copyright holder nor the names of its
27
- contributors may be used to endorse or promote products derived
28
- from this software without specific prior written permission.
29
-
30
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
31
- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
32
- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
33
- A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
34
- HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
35
- SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
36
- LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
37
- DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
38
- THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
39
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
40
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
41
-
42
- In the event that redistribution and/or use for commercial purpose in
43
- source or binary forms, with or without modification is required,
44
- please contact the contributor(s) of the work.
45
- </pre>
46
-
47
-
48
- <h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
49
- <small>Code for architecture and reading models copied.</small>
50
- <pre>
51
- MIT License
52
-
53
- Copyright (c) 2021 victorca25
54
-
55
- Permission is hereby granted, free of charge, to any person obtaining a copy
56
- of this software and associated documentation files (the "Software"), to deal
57
- in the Software without restriction, including without limitation the rights
58
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
59
- copies of the Software, and to permit persons to whom the Software is
60
- furnished to do so, subject to the following conditions:
61
-
62
- The above copyright notice and this permission notice shall be included in all
63
- copies or substantial portions of the Software.
64
-
65
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
66
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
67
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
68
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
69
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
70
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
71
- SOFTWARE.
72
- </pre>
73
-
74
- <h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
75
- <small>Some code is copied to support ESRGAN models.</small>
76
- <pre>
77
- BSD 3-Clause License
78
-
79
- Copyright (c) 2021, Xintao Wang
80
- All rights reserved.
81
-
82
- Redistribution and use in source and binary forms, with or without
83
- modification, are permitted provided that the following conditions are met:
84
-
85
- 1. Redistributions of source code must retain the above copyright notice, this
86
- list of conditions and the following disclaimer.
87
-
88
- 2. Redistributions in binary form must reproduce the above copyright notice,
89
- this list of conditions and the following disclaimer in the documentation
90
- and/or other materials provided with the distribution.
91
-
92
- 3. Neither the name of the copyright holder nor the names of its
93
- contributors may be used to endorse or promote products derived from
94
- this software without specific prior written permission.
95
-
96
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
97
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
98
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
99
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
100
- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
101
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
102
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
103
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
104
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
105
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
106
- </pre>
107
-
108
- <h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
109
- <small>Some code for compatibility with OSX is taken from lstein's repository.</small>
110
- <pre>
111
- MIT License
112
-
113
- Copyright (c) 2022 InvokeAI Team
114
-
115
- Permission is hereby granted, free of charge, to any person obtaining a copy
116
- of this software and associated documentation files (the "Software"), to deal
117
- in the Software without restriction, including without limitation the rights
118
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
119
- copies of the Software, and to permit persons to whom the Software is
120
- furnished to do so, subject to the following conditions:
121
-
122
- The above copyright notice and this permission notice shall be included in all
123
- copies or substantial portions of the Software.
124
-
125
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
126
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
127
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
128
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
129
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
130
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
131
- SOFTWARE.
132
- </pre>
133
-
134
- <h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
135
- <small>Code added by contributors, most likely copied from this repository.</small>
136
- <pre>
137
- MIT License
138
-
139
- Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
140
-
141
- Permission is hereby granted, free of charge, to any person obtaining a copy
142
- of this software and associated documentation files (the "Software"), to deal
143
- in the Software without restriction, including without limitation the rights
144
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
145
- copies of the Software, and to permit persons to whom the Software is
146
- furnished to do so, subject to the following conditions:
147
-
148
- The above copyright notice and this permission notice shall be included in all
149
- copies or substantial portions of the Software.
150
-
151
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
152
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
153
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
154
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
155
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
156
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
157
- SOFTWARE.
158
- </pre>
159
-
160
- <h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
161
- <small>Some small amounts of code borrowed and reworked.</small>
162
- <pre>
163
- MIT License
164
-
165
- Copyright (c) 2022 pharmapsychotic
166
-
167
- Permission is hereby granted, free of charge, to any person obtaining a copy
168
- of this software and associated documentation files (the "Software"), to deal
169
- in the Software without restriction, including without limitation the rights
170
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
171
- copies of the Software, and to permit persons to whom the Software is
172
- furnished to do so, subject to the following conditions:
173
-
174
- The above copyright notice and this permission notice shall be included in all
175
- copies or substantial portions of the Software.
176
-
177
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
178
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
179
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
180
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
181
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
182
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
183
- SOFTWARE.
184
- </pre>
185
-
186
- <h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
187
- <small>Code added by contributors, most likely copied from this repository.</small>
188
-
189
- <pre>
190
- Apache License
191
- Version 2.0, January 2004
192
- http://www.apache.org/licenses/
193
-
194
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
195
-
196
- 1. Definitions.
197
-
198
- "License" shall mean the terms and conditions for use, reproduction,
199
- and distribution as defined by Sections 1 through 9 of this document.
200
-
201
- "Licensor" shall mean the copyright owner or entity authorized by
202
- the copyright owner that is granting the License.
203
-
204
- "Legal Entity" shall mean the union of the acting entity and all
205
- other entities that control, are controlled by, or are under common
206
- control with that entity. For the purposes of this definition,
207
- "control" means (i) the power, direct or indirect, to cause the
208
- direction or management of such entity, whether by contract or
209
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
210
- outstanding shares, or (iii) beneficial ownership of such entity.
211
-
212
- "You" (or "Your") shall mean an individual or Legal Entity
213
- exercising permissions granted by this License.
214
-
215
- "Source" form shall mean the preferred form for making modifications,
216
- including but not limited to software source code, documentation
217
- source, and configuration files.
218
-
219
- "Object" form shall mean any form resulting from mechanical
220
- transformation or translation of a Source form, including but
221
- not limited to compiled object code, generated documentation,
222
- and conversions to other media types.
223
-
224
- "Work" shall mean the work of authorship, whether in Source or
225
- Object form, made available under the License, as indicated by a
226
- copyright notice that is included in or attached to the work
227
- (an example is provided in the Appendix below).
228
-
229
- "Derivative Works" shall mean any work, whether in Source or Object
230
- form, that is based on (or derived from) the Work and for which the
231
- editorial revisions, annotations, elaborations, or other modifications
232
- represent, as a whole, an original work of authorship. For the purposes
233
- of this License, Derivative Works shall not include works that remain
234
- separable from, or merely link (or bind by name) to the interfaces of,
235
- the Work and Derivative Works thereof.
236
-
237
- "Contribution" shall mean any work of authorship, including
238
- the original version of the Work and any modifications or additions
239
- to that Work or Derivative Works thereof, that is intentionally
240
- submitted to Licensor for inclusion in the Work by the copyright owner
241
- or by an individual or Legal Entity authorized to submit on behalf of
242
- the copyright owner. For the purposes of this definition, "submitted"
243
- means any form of electronic, verbal, or written communication sent
244
- to the Licensor or its representatives, including but not limited to
245
- communication on electronic mailing lists, source code control systems,
246
- and issue tracking systems that are managed by, or on behalf of, the
247
- Licensor for the purpose of discussing and improving the Work, but
248
- excluding communication that is conspicuously marked or otherwise
249
- designated in writing by the copyright owner as "Not a Contribution."
250
-
251
- "Contributor" shall mean Licensor and any individual or Legal Entity
252
- on behalf of whom a Contribution has been received by Licensor and
253
- subsequently incorporated within the Work.
254
-
255
- 2. Grant of Copyright License. Subject to the terms and conditions of
256
- this License, each Contributor hereby grants to You a perpetual,
257
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
258
- copyright license to reproduce, prepare Derivative Works of,
259
- publicly display, publicly perform, sublicense, and distribute the
260
- Work and such Derivative Works in Source or Object form.
261
-
262
- 3. Grant of Patent License. Subject to the terms and conditions of
263
- this License, each Contributor hereby grants to You a perpetual,
264
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
265
- (except as stated in this section) patent license to make, have made,
266
- use, offer to sell, sell, import, and otherwise transfer the Work,
267
- where such license applies only to those patent claims licensable
268
- by such Contributor that are necessarily infringed by their
269
- Contribution(s) alone or by combination of their Contribution(s)
270
- with the Work to which such Contribution(s) was submitted. If You
271
- institute patent litigation against any entity (including a
272
- cross-claim or counterclaim in a lawsuit) alleging that the Work
273
- or a Contribution incorporated within the Work constitutes direct
274
- or contributory patent infringement, then any patent licenses
275
- granted to You under this License for that Work shall terminate
276
- as of the date such litigation is filed.
277
-
278
- 4. Redistribution. You may reproduce and distribute copies of the
279
- Work or Derivative Works thereof in any medium, with or without
280
- modifications, and in Source or Object form, provided that You
281
- meet the following conditions:
282
-
283
- (a) You must give any other recipients of the Work or
284
- Derivative Works a copy of this License; and
285
-
286
- (b) You must cause any modified files to carry prominent notices
287
- stating that You changed the files; and
288
-
289
- (c) You must retain, in the Source form of any Derivative Works
290
- that You distribute, all copyright, patent, trademark, and
291
- attribution notices from the Source form of the Work,
292
- excluding those notices that do not pertain to any part of
293
- the Derivative Works; and
294
-
295
- (d) If the Work includes a "NOTICE" text file as part of its
296
- distribution, then any Derivative Works that You distribute must
297
- include a readable copy of the attribution notices contained
298
- within such NOTICE file, excluding those notices that do not
299
- pertain to any part of the Derivative Works, in at least one
300
- of the following places: within a NOTICE text file distributed
301
- as part of the Derivative Works; within the Source form or
302
- documentation, if provided along with the Derivative Works; or,
303
- within a display generated by the Derivative Works, if and
304
- wherever such third-party notices normally appear. The contents
305
- of the NOTICE file are for informational purposes only and
306
- do not modify the License. You may add Your own attribution
307
- notices within Derivative Works that You distribute, alongside
308
- or as an addendum to the NOTICE text from the Work, provided
309
- that such additional attribution notices cannot be construed
310
- as modifying the License.
311
-
312
- You may add Your own copyright statement to Your modifications and
313
- may provide additional or different license terms and conditions
314
- for use, reproduction, or distribution of Your modifications, or
315
- for any such Derivative Works as a whole, provided Your use,
316
- reproduction, and distribution of the Work otherwise complies with
317
- the conditions stated in this License.
318
-
319
- 5. Submission of Contributions. Unless You explicitly state otherwise,
320
- any Contribution intentionally submitted for inclusion in the Work
321
- by You to the Licensor shall be under the terms and conditions of
322
- this License, without any additional terms or conditions.
323
- Notwithstanding the above, nothing herein shall supersede or modify
324
- the terms of any separate license agreement you may have executed
325
- with Licensor regarding such Contributions.
326
-
327
- 6. Trademarks. This License does not grant permission to use the trade
328
- names, trademarks, service marks, or product names of the Licensor,
329
- except as required for reasonable and customary use in describing the
330
- origin of the Work and reproducing the content of the NOTICE file.
331
-
332
- 7. Disclaimer of Warranty. Unless required by applicable law or
333
- agreed to in writing, Licensor provides the Work (and each
334
- Contributor provides its Contributions) on an "AS IS" BASIS,
335
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
336
- implied, including, without limitation, any warranties or conditions
337
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
338
- PARTICULAR PURPOSE. You are solely responsible for determining the
339
- appropriateness of using or redistributing the Work and assume any
340
- risks associated with Your exercise of permissions under this License.
341
-
342
- 8. Limitation of Liability. In no event and under no legal theory,
343
- whether in tort (including negligence), contract, or otherwise,
344
- unless required by applicable law (such as deliberate and grossly
345
- negligent acts) or agreed to in writing, shall any Contributor be
346
- liable to You for damages, including any direct, indirect, special,
347
- incidental, or consequential damages of any character arising as a
348
- result of this License or out of the use or inability to use the
349
- Work (including but not limited to damages for loss of goodwill,
350
- work stoppage, computer failure or malfunction, or any and all
351
- other commercial damages or losses), even if such Contributor
352
- has been advised of the possibility of such damages.
353
-
354
- 9. Accepting Warranty or Additional Liability. While redistributing
355
- the Work or Derivative Works thereof, You may choose to offer,
356
- and charge a fee for, acceptance of support, warranty, indemnity,
357
- or other liability obligations and/or rights consistent with this
358
- License. However, in accepting such obligations, You may act only
359
- on Your own behalf and on Your sole responsibility, not on behalf
360
- of any other Contributor, and only if You agree to indemnify,
361
- defend, and hold each Contributor harmless for any liability
362
- incurred by, or claims asserted against, such Contributor by reason
363
- of your accepting any such warranty or additional liability.
364
-
365
- END OF TERMS AND CONDITIONS
366
-
367
- APPENDIX: How to apply the Apache License to your work.
368
-
369
- To apply the Apache License to your work, attach the following
370
- boilerplate notice, with the fields enclosed by brackets "[]"
371
- replaced with your own identifying information. (Don't include
372
- the brackets!) The text should be enclosed in the appropriate
373
- comment syntax for the file format. We also recommend that a
374
- file or class name and description of purpose be included on the
375
- same "printed page" as the copyright notice for easier
376
- identification within third-party archives.
377
-
378
- Copyright [2021] [SwinIR Authors]
379
-
380
- Licensed under the Apache License, Version 2.0 (the "License");
381
- you may not use this file except in compliance with the License.
382
- You may obtain a copy of the License at
383
-
384
- http://www.apache.org/licenses/LICENSE-2.0
385
-
386
- Unless required by applicable law or agreed to in writing, software
387
- distributed under the License is distributed on an "AS IS" BASIS,
388
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
389
- See the License for the specific language governing permissions and
390
- limitations under the License.
391
- </pre>
392
-
393
- <h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
394
- <small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
395
- <pre>
396
- MIT License
397
-
398
- Copyright (c) 2023 Alex Birch
399
- Copyright (c) 2023 Amin Rezaei
400
-
401
- Permission is hereby granted, free of charge, to any person obtaining a copy
402
- of this software and associated documentation files (the "Software"), to deal
403
- in the Software without restriction, including without limitation the rights
404
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
405
- copies of the Software, and to permit persons to whom the Software is
406
- furnished to do so, subject to the following conditions:
407
-
408
- The above copyright notice and this permission notice shall be included in all
409
- copies or substantial portions of the Software.
410
-
411
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
412
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
413
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
414
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
415
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
416
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
417
- SOFTWARE.
418
- </pre>
419
-
 
1
+ <style>
2
+ #licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
3
+ #licenses small {font-size: 0.95em; opacity: 0.85;}
4
+ #licenses pre { margin: 1em 0 2em 0;}
5
+ </style>
6
+
7
+ <h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
8
+ <small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
9
+ <pre>
10
+ S-Lab License 1.0
11
+
12
+ Copyright 2022 S-Lab
13
+
14
+ Redistribution and use for non-commercial purpose in source and
15
+ binary forms, with or without modification, are permitted provided
16
+ that the following conditions are met:
17
+
18
+ 1. Redistributions of source code must retain the above copyright
19
+ notice, this list of conditions and the following disclaimer.
20
+
21
+ 2. Redistributions in binary form must reproduce the above copyright
22
+ notice, this list of conditions and the following disclaimer in
23
+ the documentation and/or other materials provided with the
24
+ distribution.
25
+
26
+ 3. Neither the name of the copyright holder nor the names of its
27
+ contributors may be used to endorse or promote products derived
28
+ from this software without specific prior written permission.
29
+
30
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
31
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
32
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
33
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
34
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
35
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
36
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
37
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
38
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
39
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
40
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
41
+
42
+ In the event that redistribution and/or use for commercial purpose in
43
+ source or binary forms, with or without modification is required,
44
+ please contact the contributor(s) of the work.
45
+ </pre>
46
+
47
+
48
+ <h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
49
+ <small>Code for architecture and reading models copied.</small>
50
+ <pre>
51
+ MIT License
52
+
53
+ Copyright (c) 2021 victorca25
54
+
55
+ Permission is hereby granted, free of charge, to any person obtaining a copy
56
+ of this software and associated documentation files (the "Software"), to deal
57
+ in the Software without restriction, including without limitation the rights
58
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
59
+ copies of the Software, and to permit persons to whom the Software is
60
+ furnished to do so, subject to the following conditions:
61
+
62
+ The above copyright notice and this permission notice shall be included in all
63
+ copies or substantial portions of the Software.
64
+
65
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
66
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
67
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
68
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
69
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
70
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
71
+ SOFTWARE.
72
+ </pre>
73
+
74
+ <h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
75
+ <small>Some code is copied to support ESRGAN models.</small>
76
+ <pre>
77
+ BSD 3-Clause License
78
+
79
+ Copyright (c) 2021, Xintao Wang
80
+ All rights reserved.
81
+
82
+ Redistribution and use in source and binary forms, with or without
83
+ modification, are permitted provided that the following conditions are met:
84
+
85
+ 1. Redistributions of source code must retain the above copyright notice, this
86
+ list of conditions and the following disclaimer.
87
+
88
+ 2. Redistributions in binary form must reproduce the above copyright notice,
89
+ this list of conditions and the following disclaimer in the documentation
90
+ and/or other materials provided with the distribution.
91
+
92
+ 3. Neither the name of the copyright holder nor the names of its
93
+ contributors may be used to endorse or promote products derived from
94
+ this software without specific prior written permission.
95
+
96
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
97
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
98
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
99
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
100
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
101
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
102
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
103
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
104
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
105
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
106
+ </pre>
107
+
108
+ <h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
109
+ <small>Some code for compatibility with OSX is taken from lstein's repository.</small>
110
+ <pre>
111
+ MIT License
112
+
113
+ Copyright (c) 2022 InvokeAI Team
114
+
115
+ Permission is hereby granted, free of charge, to any person obtaining a copy
116
+ of this software and associated documentation files (the "Software"), to deal
117
+ in the Software without restriction, including without limitation the rights
118
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
119
+ copies of the Software, and to permit persons to whom the Software is
120
+ furnished to do so, subject to the following conditions:
121
+
122
+ The above copyright notice and this permission notice shall be included in all
123
+ copies or substantial portions of the Software.
124
+
125
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
126
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
127
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
128
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
129
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
130
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
131
+ SOFTWARE.
132
+ </pre>
133
+
134
+ <h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
135
+ <small>Code added by contributors, most likely copied from this repository.</small>
136
+ <pre>
137
+ MIT License
138
+
139
+ Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
140
+
141
+ Permission is hereby granted, free of charge, to any person obtaining a copy
142
+ of this software and associated documentation files (the "Software"), to deal
143
+ in the Software without restriction, including without limitation the rights
144
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
145
+ copies of the Software, and to permit persons to whom the Software is
146
+ furnished to do so, subject to the following conditions:
147
+
148
+ The above copyright notice and this permission notice shall be included in all
149
+ copies or substantial portions of the Software.
150
+
151
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
152
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
153
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
154
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
155
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
156
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
157
+ SOFTWARE.
158
+ </pre>
159
+
160
+ <h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
161
+ <small>Some small amounts of code borrowed and reworked.</small>
162
+ <pre>
163
+ MIT License
164
+
165
+ Copyright (c) 2022 pharmapsychotic
166
+
167
+ Permission is hereby granted, free of charge, to any person obtaining a copy
168
+ of this software and associated documentation files (the "Software"), to deal
169
+ in the Software without restriction, including without limitation the rights
170
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
171
+ copies of the Software, and to permit persons to whom the Software is
172
+ furnished to do so, subject to the following conditions:
173
+
174
+ The above copyright notice and this permission notice shall be included in all
175
+ copies or substantial portions of the Software.
176
+
177
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
178
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
179
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
180
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
181
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
182
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
183
+ SOFTWARE.
184
+ </pre>
185
+
186
+ <h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
187
+ <small>Code added by contributors, most likely copied from this repository.</small>
188
+
189
+ <pre>
190
+ Apache License
191
+ Version 2.0, January 2004
192
+ http://www.apache.org/licenses/
193
+
194
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
195
+
196
+ 1. Definitions.
197
+
198
+ "License" shall mean the terms and conditions for use, reproduction,
199
+ and distribution as defined by Sections 1 through 9 of this document.
200
+
201
+ "Licensor" shall mean the copyright owner or entity authorized by
202
+ the copyright owner that is granting the License.
203
+
204
+ "Legal Entity" shall mean the union of the acting entity and all
205
+ other entities that control, are controlled by, or are under common
206
+ control with that entity. For the purposes of this definition,
207
+ "control" means (i) the power, direct or indirect, to cause the
208
+ direction or management of such entity, whether by contract or
209
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
210
+ outstanding shares, or (iii) beneficial ownership of such entity.
211
+
212
+ "You" (or "Your") shall mean an individual or Legal Entity
213
+ exercising permissions granted by this License.
214
+
215
+ "Source" form shall mean the preferred form for making modifications,
216
+ including but not limited to software source code, documentation
217
+ source, and configuration files.
218
+
219
+ "Object" form shall mean any form resulting from mechanical
220
+ transformation or translation of a Source form, including but
221
+ not limited to compiled object code, generated documentation,
222
+ and conversions to other media types.
223
+
224
+ "Work" shall mean the work of authorship, whether in Source or
225
+ Object form, made available under the License, as indicated by a
226
+ copyright notice that is included in or attached to the work
227
+ (an example is provided in the Appendix below).
228
+
229
+ "Derivative Works" shall mean any work, whether in Source or Object
230
+ form, that is based on (or derived from) the Work and for which the
231
+ editorial revisions, annotations, elaborations, or other modifications
232
+ represent, as a whole, an original work of authorship. For the purposes
233
+ of this License, Derivative Works shall not include works that remain
234
+ separable from, or merely link (or bind by name) to the interfaces of,
235
+ the Work and Derivative Works thereof.
236
+
237
+ "Contribution" shall mean any work of authorship, including
238
+ the original version of the Work and any modifications or additions
239
+ to that Work or Derivative Works thereof, that is intentionally
240
+ submitted to Licensor for inclusion in the Work by the copyright owner
241
+ or by an individual or Legal Entity authorized to submit on behalf of
242
+ the copyright owner. For the purposes of this definition, "submitted"
243
+ means any form of electronic, verbal, or written communication sent
244
+ to the Licensor or its representatives, including but not limited to
245
+ communication on electronic mailing lists, source code control systems,
246
+ and issue tracking systems that are managed by, or on behalf of, the
247
+ Licensor for the purpose of discussing and improving the Work, but
248
+ excluding communication that is conspicuously marked or otherwise
249
+ designated in writing by the copyright owner as "Not a Contribution."
250
+
251
+ "Contributor" shall mean Licensor and any individual or Legal Entity
252
+ on behalf of whom a Contribution has been received by Licensor and
253
+ subsequently incorporated within the Work.
254
+
255
+ 2. Grant of Copyright License. Subject to the terms and conditions of
256
+ this License, each Contributor hereby grants to You a perpetual,
257
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
258
+ copyright license to reproduce, prepare Derivative Works of,
259
+ publicly display, publicly perform, sublicense, and distribute the
260
+ Work and such Derivative Works in Source or Object form.
261
+
262
+ 3. Grant of Patent License. Subject to the terms and conditions of
263
+ this License, each Contributor hereby grants to You a perpetual,
264
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
265
+ (except as stated in this section) patent license to make, have made,
266
+ use, offer to sell, sell, import, and otherwise transfer the Work,
267
+ where such license applies only to those patent claims licensable
268
+ by such Contributor that are necessarily infringed by their
269
+ Contribution(s) alone or by combination of their Contribution(s)
270
+ with the Work to which such Contribution(s) was submitted. If You
271
+ institute patent litigation against any entity (including a
272
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
273
+ or a Contribution incorporated within the Work constitutes direct
274
+ or contributory patent infringement, then any patent licenses
275
+ granted to You under this License for that Work shall terminate
276
+ as of the date such litigation is filed.
277
+
278
+ 4. Redistribution. You may reproduce and distribute copies of the
279
+ Work or Derivative Works thereof in any medium, with or without
280
+ modifications, and in Source or Object form, provided that You
281
+ meet the following conditions:
282
+
283
+ (a) You must give any other recipients of the Work or
284
+ Derivative Works a copy of this License; and
285
+
286
+ (b) You must cause any modified files to carry prominent notices
287
+ stating that You changed the files; and
288
+
289
+ (c) You must retain, in the Source form of any Derivative Works
290
+ that You distribute, all copyright, patent, trademark, and
291
+ attribution notices from the Source form of the Work,
292
+ excluding those notices that do not pertain to any part of
293
+ the Derivative Works; and
294
+
295
+ (d) If the Work includes a "NOTICE" text file as part of its
296
+ distribution, then any Derivative Works that You distribute must
297
+ include a readable copy of the attribution notices contained
298
+ within such NOTICE file, excluding those notices that do not
299
+ pertain to any part of the Derivative Works, in at least one
300
+ of the following places: within a NOTICE text file distributed
301
+ as part of the Derivative Works; within the Source form or
302
+ documentation, if provided along with the Derivative Works; or,
303
+ within a display generated by the Derivative Works, if and
304
+ wherever such third-party notices normally appear. The contents
305
+ of the NOTICE file are for informational purposes only and
306
+ do not modify the License. You may add Your own attribution
307
+ notices within Derivative Works that You distribute, alongside
308
+ or as an addendum to the NOTICE text from the Work, provided
309
+ that such additional attribution notices cannot be construed
310
+ as modifying the License.
311
+
312
+ You may add Your own copyright statement to Your modifications and
313
+ may provide additional or different license terms and conditions
314
+ for use, reproduction, or distribution of Your modifications, or
315
+ for any such Derivative Works as a whole, provided Your use,
316
+ reproduction, and distribution of the Work otherwise complies with
317
+ the conditions stated in this License.
318
+
319
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
320
+ any Contribution intentionally submitted for inclusion in the Work
321
+ by You to the Licensor shall be under the terms and conditions of
322
+ this License, without any additional terms or conditions.
323
+ Notwithstanding the above, nothing herein shall supersede or modify
324
+ the terms of any separate license agreement you may have executed
325
+ with Licensor regarding such Contributions.
326
+
327
+ 6. Trademarks. This License does not grant permission to use the trade
328
+ names, trademarks, service marks, or product names of the Licensor,
329
+ except as required for reasonable and customary use in describing the
330
+ origin of the Work and reproducing the content of the NOTICE file.
331
+
332
+ 7. Disclaimer of Warranty. Unless required by applicable law or
333
+ agreed to in writing, Licensor provides the Work (and each
334
+ Contributor provides its Contributions) on an "AS IS" BASIS,
335
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
336
+ implied, including, without limitation, any warranties or conditions
337
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
338
+ PARTICULAR PURPOSE. You are solely responsible for determining the
339
+ appropriateness of using or redistributing the Work and assume any
340
+ risks associated with Your exercise of permissions under this License.
341
+
342
+ 8. Limitation of Liability. In no event and under no legal theory,
343
+ whether in tort (including negligence), contract, or otherwise,
344
+ unless required by applicable law (such as deliberate and grossly
345
+ negligent acts) or agreed to in writing, shall any Contributor be
346
+ liable to You for damages, including any direct, indirect, special,
347
+ incidental, or consequential damages of any character arising as a
348
+ result of this License or out of the use or inability to use the
349
+ Work (including but not limited to damages for loss of goodwill,
350
+ work stoppage, computer failure or malfunction, or any and all
351
+ other commercial damages or losses), even if such Contributor
352
+ has been advised of the possibility of such damages.
353
+
354
+ 9. Accepting Warranty or Additional Liability. While redistributing
355
+ the Work or Derivative Works thereof, You may choose to offer,
356
+ and charge a fee for, acceptance of support, warranty, indemnity,
357
+ or other liability obligations and/or rights consistent with this
358
+ License. However, in accepting such obligations, You may act only
359
+ on Your own behalf and on Your sole responsibility, not on behalf
360
+ of any other Contributor, and only if You agree to indemnify,
361
+ defend, and hold each Contributor harmless for any liability
362
+ incurred by, or claims asserted against, such Contributor by reason
363
+ of your accepting any such warranty or additional liability.
364
+
365
+ END OF TERMS AND CONDITIONS
366
+
367
+ APPENDIX: How to apply the Apache License to your work.
368
+
369
+ To apply the Apache License to your work, attach the following
370
+ boilerplate notice, with the fields enclosed by brackets "[]"
371
+ replaced with your own identifying information. (Don't include
372
+ the brackets!) The text should be enclosed in the appropriate
373
+ comment syntax for the file format. We also recommend that a
374
+ file or class name and description of purpose be included on the
375
+ same "printed page" as the copyright notice for easier
376
+ identification within third-party archives.
377
+
378
+ Copyright [2021] [SwinIR Authors]
379
+
380
+ Licensed under the Apache License, Version 2.0 (the "License");
381
+ you may not use this file except in compliance with the License.
382
+ You may obtain a copy of the License at
383
+
384
+ http://www.apache.org/licenses/LICENSE-2.0
385
+
386
+ Unless required by applicable law or agreed to in writing, software
387
+ distributed under the License is distributed on an "AS IS" BASIS,
388
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
389
+ See the License for the specific language governing permissions and
390
+ limitations under the License.
391
+ </pre>
392
+
393
+ <h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
394
+ <small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
395
+ <pre>
396
+ MIT License
397
+
398
+ Copyright (c) 2023 Alex Birch
399
+ Copyright (c) 2023 Amin Rezaei
400
+
401
+ Permission is hereby granted, free of charge, to any person obtaining a copy
402
+ of this software and associated documentation files (the "Software"), to deal
403
+ in the Software without restriction, including without limitation the rights
404
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
405
+ copies of the Software, and to permit persons to whom the Software is
406
+ furnished to do so, subject to the following conditions:
407
+
408
+ The above copyright notice and this permission notice shall be included in all
409
+ copies or substantial portions of the Software.
410
+
411
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
412
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
413
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
414
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
415
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
416
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
417
+ SOFTWARE.
418
+ </pre>
419
+
420
+ <h2><a href="https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/LICENSE">Scaled Dot Product Attention</a></h2>
421
+ <small>Some small amounts of code borrowed and reworked.</small>
422
+ <pre>
423
+ Copyright 2023 The HuggingFace Team. All rights reserved.
424
+
425
+ Licensed under the Apache License, Version 2.0 (the "License");
426
+ you may not use this file except in compliance with the License.
427
+ You may obtain a copy of the License at
428
+
429
+ http://www.apache.org/licenses/LICENSE-2.0
430
+
431
+ Unless required by applicable law or agreed to in writing, software
432
+ distributed under the License is distributed on an "AS IS" BASIS,
433
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
434
+ See the License for the specific language governing permissions and
435
+ limitations under the License.
436
+
437
+ Apache License
438
+ Version 2.0, January 2004
439
+ http://www.apache.org/licenses/
440
+
441
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
442
+
443
+ 1. Definitions.
444
+
445
+ "License" shall mean the terms and conditions for use, reproduction,
446
+ and distribution as defined by Sections 1 through 9 of this document.
447
+
448
+ "Licensor" shall mean the copyright owner or entity authorized by
449
+ the copyright owner that is granting the License.
450
+
451
+ "Legal Entity" shall mean the union of the acting entity and all
452
+ other entities that control, are controlled by, or are under common
453
+ control with that entity. For the purposes of this definition,
454
+ "control" means (i) the power, direct or indirect, to cause the
455
+ direction or management of such entity, whether by contract or
456
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
457
+ outstanding shares, or (iii) beneficial ownership of such entity.
458
+
459
+ "You" (or "Your") shall mean an individual or Legal Entity
460
+ exercising permissions granted by this License.
461
+
462
+ "Source" form shall mean the preferred form for making modifications,
463
+ including but not limited to software source code, documentation
464
+ source, and configuration files.
465
+
466
+ "Object" form shall mean any form resulting from mechanical
467
+ transformation or translation of a Source form, including but
468
+ not limited to compiled object code, generated documentation,
469
+ and conversions to other media types.
470
+
471
+ "Work" shall mean the work of authorship, whether in Source or
472
+ Object form, made available under the License, as indicated by a
473
+ copyright notice that is included in or attached to the work
474
+ (an example is provided in the Appendix below).
475
+
476
+ "Derivative Works" shall mean any work, whether in Source or Object
477
+ form, that is based on (or derived from) the Work and for which the
478
+ editorial revisions, annotations, elaborations, or other modifications
479
+ represent, as a whole, an original work of authorship. For the purposes
480
+ of this License, Derivative Works shall not include works that remain
481
+ separable from, or merely link (or bind by name) to the interfaces of,
482
+ the Work and Derivative Works thereof.
483
+
484
+ "Contribution" shall mean any work of authorship, including
485
+ the original version of the Work and any modifications or additions
486
+ to that Work or Derivative Works thereof, that is intentionally
487
+ submitted to Licensor for inclusion in the Work by the copyright owner
488
+ or by an individual or Legal Entity authorized to submit on behalf of
489
+ the copyright owner. For the purposes of this definition, "submitted"
490
+ means any form of electronic, verbal, or written communication sent
491
+ to the Licensor or its representatives, including but not limited to
492
+ communication on electronic mailing lists, source code control systems,
493
+ and issue tracking systems that are managed by, or on behalf of, the
494
+ Licensor for the purpose of discussing and improving the Work, but
495
+ excluding communication that is conspicuously marked or otherwise
496
+ designated in writing by the copyright owner as "Not a Contribution."
497
+
498
+ "Contributor" shall mean Licensor and any individual or Legal Entity
499
+ on behalf of whom a Contribution has been received by Licensor and
500
+ subsequently incorporated within the Work.
501
+
502
+ 2. Grant of Copyright License. Subject to the terms and conditions of
503
+ this License, each Contributor hereby grants to You a perpetual,
504
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
505
+ copyright license to reproduce, prepare Derivative Works of,
506
+ publicly display, publicly perform, sublicense, and distribute the
507
+ Work and such Derivative Works in Source or Object form.
508
+
509
+ 3. Grant of Patent License. Subject to the terms and conditions of
510
+ this License, each Contributor hereby grants to You a perpetual,
511
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
512
+ (except as stated in this section) patent license to make, have made,
513
+ use, offer to sell, sell, import, and otherwise transfer the Work,
514
+ where such license applies only to those patent claims licensable
515
+ by such Contributor that are necessarily infringed by their
516
+ Contribution(s) alone or by combination of their Contribution(s)
517
+ with the Work to which such Contribution(s) was submitted. If You
518
+ institute patent litigation against any entity (including a
519
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
520
+ or a Contribution incorporated within the Work constitutes direct
521
+ or contributory patent infringement, then any patent licenses
522
+ granted to You under this License for that Work shall terminate
523
+ as of the date such litigation is filed.
524
+
525
+ 4. Redistribution. You may reproduce and distribute copies of the
526
+ Work or Derivative Works thereof in any medium, with or without
527
+ modifications, and in Source or Object form, provided that You
528
+ meet the following conditions:
529
+
530
+ (a) You must give any other recipients of the Work or
531
+ Derivative Works a copy of this License; and
532
+
533
+ (b) You must cause any modified files to carry prominent notices
534
+ stating that You changed the files; and
535
+
536
+ (c) You must retain, in the Source form of any Derivative Works
537
+ that You distribute, all copyright, patent, trademark, and
538
+ attribution notices from the Source form of the Work,
539
+ excluding those notices that do not pertain to any part of
540
+ the Derivative Works; and
541
+
542
+ (d) If the Work includes a "NOTICE" text file as part of its
543
+ distribution, then any Derivative Works that You distribute must
544
+ include a readable copy of the attribution notices contained
545
+ within such NOTICE file, excluding those notices that do not
546
+ pertain to any part of the Derivative Works, in at least one
547
+ of the following places: within a NOTICE text file distributed
548
+ as part of the Derivative Works; within the Source form or
549
+ documentation, if provided along with the Derivative Works; or,
550
+ within a display generated by the Derivative Works, if and
551
+ wherever such third-party notices normally appear. The contents
552
+ of the NOTICE file are for informational purposes only and
553
+ do not modify the License. You may add Your own attribution
554
+ notices within Derivative Works that You distribute, alongside
555
+ or as an addendum to the NOTICE text from the Work, provided
556
+ that such additional attribution notices cannot be construed
557
+ as modifying the License.
558
+
559
+ You may add Your own copyright statement to Your modifications and
560
+ may provide additional or different license terms and conditions
561
+ for use, reproduction, or distribution of Your modifications, or
562
+ for any such Derivative Works as a whole, provided Your use,
563
+ reproduction, and distribution of the Work otherwise complies with
564
+ the conditions stated in this License.
565
+
566
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
567
+ any Contribution intentionally submitted for inclusion in the Work
568
+ by You to the Licensor shall be under the terms and conditions of
569
+ this License, without any additional terms or conditions.
570
+ Notwithstanding the above, nothing herein shall supersede or modify
571
+ the terms of any separate license agreement you may have executed
572
+ with Licensor regarding such Contributions.
573
+
574
+ 6. Trademarks. This License does not grant permission to use the trade
575
+ names, trademarks, service marks, or product names of the Licensor,
576
+ except as required for reasonable and customary use in describing the
577
+ origin of the Work and reproducing the content of the NOTICE file.
578
+
579
+ 7. Disclaimer of Warranty. Unless required by applicable law or
580
+ agreed to in writing, Licensor provides the Work (and each
581
+ Contributor provides its Contributions) on an "AS IS" BASIS,
582
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
583
+ implied, including, without limitation, any warranties or conditions
584
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
585
+ PARTICULAR PURPOSE. You are solely responsible for determining the
586
+ appropriateness of using or redistributing the Work and assume any
587
+ risks associated with Your exercise of permissions under this License.
588
+
589
+ 8. Limitation of Liability. In no event and under no legal theory,
590
+ whether in tort (including negligence), contract, or otherwise,
591
+ unless required by applicable law (such as deliberate and grossly
592
+ negligent acts) or agreed to in writing, shall any Contributor be
593
+ liable to You for damages, including any direct, indirect, special,
594
+ incidental, or consequential damages of any character arising as a
595
+ result of this License or out of the use or inability to use the
596
+ Work (including but not limited to damages for loss of goodwill,
597
+ work stoppage, computer failure or malfunction, or any and all
598
+ other commercial damages or losses), even if such Contributor
599
+ has been advised of the possibility of such damages.
600
+
601
+ 9. Accepting Warranty or Additional Liability. While redistributing
602
+ the Work or Derivative Works thereof, You may choose to offer,
603
+ and charge a fee for, acceptance of support, warranty, indemnity,
604
+ or other liability obligations and/or rights consistent with this
605
+ License. However, in accepting such obligations, You may act only
606
+ on Your own behalf and on Your sole responsibility, not on behalf
607
+ of any other Contributor, and only if You agree to indemnify,
608
+ defend, and hold each Contributor harmless for any liability
609
+ incurred by, or claims asserted against, such Contributor by reason
610
+ of your accepting any such warranty or additional liability.
611
+
612
+ END OF TERMS AND CONDITIONS
613
+
614
+ APPENDIX: How to apply the Apache License to your work.
615
+
616
+ To apply the Apache License to your work, attach the following
617
+ boilerplate notice, with the fields enclosed by brackets "[]"
618
+ replaced with your own identifying information. (Don't include
619
+ the brackets!) The text should be enclosed in the appropriate
620
+ comment syntax for the file format. We also recommend that a
621
+ file or class name and description of purpose be included on the
622
+ same "printed page" as the copyright notice for easier
623
+ identification within third-party archives.
624
+
625
+ Copyright [yyyy] [name of copyright owner]
626
+
627
+ Licensed under the Apache License, Version 2.0 (the "License");
628
+ you may not use this file except in compliance with the License.
629
+ You may obtain a copy of the License at
630
+
631
+ http://www.apache.org/licenses/LICENSE-2.0
632
+
633
+ Unless required by applicable law or agreed to in writing, software
634
+ distributed under the License is distributed on an "AS IS" BASIS,
635
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
636
+ See the License for the specific language governing permissions and
637
+ limitations under the License.
638
+ </pre>
sd/stable-diffusion-webui/javascript/aspectRatioOverlay.js CHANGED
@@ -1,113 +1,113 @@
1
-
2
- let currentWidth = null;
3
- let currentHeight = null;
4
- let arFrameTimeout = setTimeout(function(){},0);
5
-
6
- function dimensionChange(e, is_width, is_height){
7
-
8
- if(is_width){
9
- currentWidth = e.target.value*1.0
10
- }
11
- if(is_height){
12
- currentHeight = e.target.value*1.0
13
- }
14
-
15
- var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
16
-
17
- if(!inImg2img){
18
- return;
19
- }
20
-
21
- var targetElement = null;
22
-
23
- var tabIndex = get_tab_index('mode_img2img')
24
- if(tabIndex == 0){ // img2img
25
- targetElement = gradioApp().querySelector('div[data-testid=image] img');
26
- } else if(tabIndex == 1){ //Sketch
27
- targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
28
- } else if(tabIndex == 2){ // Inpaint
29
- targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
30
- } else if(tabIndex == 3){ // Inpaint sketch
31
- targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
32
- }
33
-
34
-
35
- if(targetElement){
36
-
37
- var arPreviewRect = gradioApp().querySelector('#imageARPreview');
38
- if(!arPreviewRect){
39
- arPreviewRect = document.createElement('div')
40
- arPreviewRect.id = "imageARPreview";
41
- gradioApp().getRootNode().appendChild(arPreviewRect)
42
- }
43
-
44
-
45
-
46
- var viewportOffset = targetElement.getBoundingClientRect();
47
-
48
- viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
49
-
50
- scaledx = targetElement.naturalWidth*viewportscale
51
- scaledy = targetElement.naturalHeight*viewportscale
52
-
53
- cleintRectTop = (viewportOffset.top+window.scrollY)
54
- cleintRectLeft = (viewportOffset.left+window.scrollX)
55
- cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
56
- cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
57
-
58
- viewRectTop = cleintRectCentreY-(scaledy/2)
59
- viewRectLeft = cleintRectCentreX-(scaledx/2)
60
- arRectWidth = scaledx
61
- arRectHeight = scaledy
62
-
63
- arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
64
- arscaledx = currentWidth*arscale
65
- arscaledy = currentHeight*arscale
66
-
67
- arRectTop = cleintRectCentreY-(arscaledy/2)
68
- arRectLeft = cleintRectCentreX-(arscaledx/2)
69
- arRectWidth = arscaledx
70
- arRectHeight = arscaledy
71
-
72
- arPreviewRect.style.top = arRectTop+'px';
73
- arPreviewRect.style.left = arRectLeft+'px';
74
- arPreviewRect.style.width = arRectWidth+'px';
75
- arPreviewRect.style.height = arRectHeight+'px';
76
-
77
- clearTimeout(arFrameTimeout);
78
- arFrameTimeout = setTimeout(function(){
79
- arPreviewRect.style.display = 'none';
80
- },2000);
81
-
82
- arPreviewRect.style.display = 'block';
83
-
84
- }
85
-
86
- }
87
-
88
-
89
- onUiUpdate(function(){
90
- var arPreviewRect = gradioApp().querySelector('#imageARPreview');
91
- if(arPreviewRect){
92
- arPreviewRect.style.display = 'none';
93
- }
94
- var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
95
- if(inImg2img){
96
- let inputs = gradioApp().querySelectorAll('input');
97
- inputs.forEach(function(e){
98
- var is_width = e.parentElement.id == "img2img_width"
99
- var is_height = e.parentElement.id == "img2img_height"
100
-
101
- if((is_width || is_height) && !e.classList.contains('scrollwatch')){
102
- e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
103
- e.classList.add('scrollwatch')
104
- }
105
- if(is_width){
106
- currentWidth = e.value*1.0
107
- }
108
- if(is_height){
109
- currentHeight = e.value*1.0
110
- }
111
- })
112
- }
113
- });
 
1
+
2
+ let currentWidth = null;
3
+ let currentHeight = null;
4
+ let arFrameTimeout = setTimeout(function(){},0);
5
+
6
+ function dimensionChange(e, is_width, is_height){
7
+
8
+ if(is_width){
9
+ currentWidth = e.target.value*1.0
10
+ }
11
+ if(is_height){
12
+ currentHeight = e.target.value*1.0
13
+ }
14
+
15
+ var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
16
+
17
+ if(!inImg2img){
18
+ return;
19
+ }
20
+
21
+ var targetElement = null;
22
+
23
+ var tabIndex = get_tab_index('mode_img2img')
24
+ if(tabIndex == 0){ // img2img
25
+ targetElement = gradioApp().querySelector('div[data-testid=image] img');
26
+ } else if(tabIndex == 1){ //Sketch
27
+ targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
28
+ } else if(tabIndex == 2){ // Inpaint
29
+ targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
30
+ } else if(tabIndex == 3){ // Inpaint sketch
31
+ targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
32
+ }
33
+
34
+
35
+ if(targetElement){
36
+
37
+ var arPreviewRect = gradioApp().querySelector('#imageARPreview');
38
+ if(!arPreviewRect){
39
+ arPreviewRect = document.createElement('div')
40
+ arPreviewRect.id = "imageARPreview";
41
+ gradioApp().getRootNode().appendChild(arPreviewRect)
42
+ }
43
+
44
+
45
+
46
+ var viewportOffset = targetElement.getBoundingClientRect();
47
+
48
+ viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
49
+
50
+ scaledx = targetElement.naturalWidth*viewportscale
51
+ scaledy = targetElement.naturalHeight*viewportscale
52
+
53
+ cleintRectTop = (viewportOffset.top+window.scrollY)
54
+ cleintRectLeft = (viewportOffset.left+window.scrollX)
55
+ cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
56
+ cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
57
+
58
+ viewRectTop = cleintRectCentreY-(scaledy/2)
59
+ viewRectLeft = cleintRectCentreX-(scaledx/2)
60
+ arRectWidth = scaledx
61
+ arRectHeight = scaledy
62
+
63
+ arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
64
+ arscaledx = currentWidth*arscale
65
+ arscaledy = currentHeight*arscale
66
+
67
+ arRectTop = cleintRectCentreY-(arscaledy/2)
68
+ arRectLeft = cleintRectCentreX-(arscaledx/2)
69
+ arRectWidth = arscaledx
70
+ arRectHeight = arscaledy
71
+
72
+ arPreviewRect.style.top = arRectTop+'px';
73
+ arPreviewRect.style.left = arRectLeft+'px';
74
+ arPreviewRect.style.width = arRectWidth+'px';
75
+ arPreviewRect.style.height = arRectHeight+'px';
76
+
77
+ clearTimeout(arFrameTimeout);
78
+ arFrameTimeout = setTimeout(function(){
79
+ arPreviewRect.style.display = 'none';
80
+ },2000);
81
+
82
+ arPreviewRect.style.display = 'block';
83
+
84
+ }
85
+
86
+ }
87
+
88
+
89
+ onUiUpdate(function(){
90
+ var arPreviewRect = gradioApp().querySelector('#imageARPreview');
91
+ if(arPreviewRect){
92
+ arPreviewRect.style.display = 'none';
93
+ }
94
+ var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
95
+ if(inImg2img){
96
+ let inputs = gradioApp().querySelectorAll('input');
97
+ inputs.forEach(function(e){
98
+ var is_width = e.parentElement.id == "img2img_width"
99
+ var is_height = e.parentElement.id == "img2img_height"
100
+
101
+ if((is_width || is_height) && !e.classList.contains('scrollwatch')){
102
+ e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
103
+ e.classList.add('scrollwatch')
104
+ }
105
+ if(is_width){
106
+ currentWidth = e.value*1.0
107
+ }
108
+ if(is_height){
109
+ currentHeight = e.value*1.0
110
+ }
111
+ })
112
+ }
113
+ });
sd/stable-diffusion-webui/javascript/contextMenus.js CHANGED
@@ -1,177 +1,177 @@
1
-
2
- contextMenuInit = function(){
3
- let eventListenerApplied=false;
4
- let menuSpecs = new Map();
5
-
6
- const uid = function(){
7
- return Date.now().toString(36) + Math.random().toString(36).substr(2);
8
- }
9
-
10
- function showContextMenu(event,element,menuEntries){
11
- let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
12
- let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
13
-
14
- let oldMenu = gradioApp().querySelector('#context-menu')
15
- if(oldMenu){
16
- oldMenu.remove()
17
- }
18
-
19
- let tabButton = uiCurrentTab
20
- let baseStyle = window.getComputedStyle(tabButton)
21
-
22
- const contextMenu = document.createElement('nav')
23
- contextMenu.id = "context-menu"
24
- contextMenu.style.background = baseStyle.background
25
- contextMenu.style.color = baseStyle.color
26
- contextMenu.style.fontFamily = baseStyle.fontFamily
27
- contextMenu.style.top = posy+'px'
28
- contextMenu.style.left = posx+'px'
29
-
30
-
31
-
32
- const contextMenuList = document.createElement('ul')
33
- contextMenuList.className = 'context-menu-items';
34
- contextMenu.append(contextMenuList);
35
-
36
- menuEntries.forEach(function(entry){
37
- let contextMenuEntry = document.createElement('a')
38
- contextMenuEntry.innerHTML = entry['name']
39
- contextMenuEntry.addEventListener("click", function(e) {
40
- entry['func']();
41
- })
42
- contextMenuList.append(contextMenuEntry);
43
-
44
- })
45
-
46
- gradioApp().getRootNode().appendChild(contextMenu)
47
-
48
- let menuWidth = contextMenu.offsetWidth + 4;
49
- let menuHeight = contextMenu.offsetHeight + 4;
50
-
51
- let windowWidth = window.innerWidth;
52
- let windowHeight = window.innerHeight;
53
-
54
- if ( (windowWidth - posx) < menuWidth ) {
55
- contextMenu.style.left = windowWidth - menuWidth + "px";
56
- }
57
-
58
- if ( (windowHeight - posy) < menuHeight ) {
59
- contextMenu.style.top = windowHeight - menuHeight + "px";
60
- }
61
-
62
- }
63
-
64
- function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
65
-
66
- currentItems = menuSpecs.get(targetElementSelector)
67
-
68
- if(!currentItems){
69
- currentItems = []
70
- menuSpecs.set(targetElementSelector,currentItems);
71
- }
72
- let newItem = {'id':targetElementSelector+'_'+uid(),
73
- 'name':entryName,
74
- 'func':entryFunction,
75
- 'isNew':true}
76
-
77
- currentItems.push(newItem)
78
- return newItem['id']
79
- }
80
-
81
- function removeContextMenuOption(uid){
82
- menuSpecs.forEach(function(v,k) {
83
- let index = -1
84
- v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
85
- if(index>=0){
86
- v.splice(index, 1);
87
- }
88
- })
89
- }
90
-
91
- function addContextMenuEventListener(){
92
- if(eventListenerApplied){
93
- return;
94
- }
95
- gradioApp().addEventListener("click", function(e) {
96
- let source = e.composedPath()[0]
97
- if(source.id && source.id.indexOf('check_progress')>-1){
98
- return
99
- }
100
-
101
- let oldMenu = gradioApp().querySelector('#context-menu')
102
- if(oldMenu){
103
- oldMenu.remove()
104
- }
105
- });
106
- gradioApp().addEventListener("contextmenu", function(e) {
107
- let oldMenu = gradioApp().querySelector('#context-menu')
108
- if(oldMenu){
109
- oldMenu.remove()
110
- }
111
- menuSpecs.forEach(function(v,k) {
112
- if(e.composedPath()[0].matches(k)){
113
- showContextMenu(e,e.composedPath()[0],v)
114
- e.preventDefault()
115
- return
116
- }
117
- })
118
- });
119
- eventListenerApplied=true
120
-
121
- }
122
-
123
- return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
124
- }
125
-
126
- initResponse = contextMenuInit();
127
- appendContextMenuOption = initResponse[0];
128
- removeContextMenuOption = initResponse[1];
129
- addContextMenuEventListener = initResponse[2];
130
-
131
- (function(){
132
- //Start example Context Menu Items
133
- let generateOnRepeat = function(genbuttonid,interruptbuttonid){
134
- let genbutton = gradioApp().querySelector(genbuttonid);
135
- let interruptbutton = gradioApp().querySelector(interruptbuttonid);
136
- if(!interruptbutton.offsetParent){
137
- genbutton.click();
138
- }
139
- clearInterval(window.generateOnRepeatInterval)
140
- window.generateOnRepeatInterval = setInterval(function(){
141
- if(!interruptbutton.offsetParent){
142
- genbutton.click();
143
- }
144
- },
145
- 500)
146
- }
147
-
148
- appendContextMenuOption('#txt2img_generate','Generate forever',function(){
149
- generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
150
- })
151
- appendContextMenuOption('#img2img_generate','Generate forever',function(){
152
- generateOnRepeat('#img2img_generate','#img2img_interrupt');
153
- })
154
-
155
- let cancelGenerateForever = function(){
156
- clearInterval(window.generateOnRepeatInterval)
157
- }
158
-
159
- appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
160
- appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
161
- appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
162
- appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
163
-
164
- appendContextMenuOption('#roll','Roll three',
165
- function(){
166
- let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
167
- setTimeout(function(){rollbutton.click()},100)
168
- setTimeout(function(){rollbutton.click()},200)
169
- setTimeout(function(){rollbutton.click()},300)
170
- }
171
- )
172
- })();
173
- //End example Context Menu Items
174
-
175
- onUiUpdate(function(){
176
- addContextMenuEventListener()
177
- });
 
1
+
2
+ contextMenuInit = function(){
3
+ let eventListenerApplied=false;
4
+ let menuSpecs = new Map();
5
+
6
+ const uid = function(){
7
+ return Date.now().toString(36) + Math.random().toString(36).substr(2);
8
+ }
9
+
10
+ function showContextMenu(event,element,menuEntries){
11
+ let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
12
+ let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
13
+
14
+ let oldMenu = gradioApp().querySelector('#context-menu')
15
+ if(oldMenu){
16
+ oldMenu.remove()
17
+ }
18
+
19
+ let tabButton = uiCurrentTab
20
+ let baseStyle = window.getComputedStyle(tabButton)
21
+
22
+ const contextMenu = document.createElement('nav')
23
+ contextMenu.id = "context-menu"
24
+ contextMenu.style.background = baseStyle.background
25
+ contextMenu.style.color = baseStyle.color
26
+ contextMenu.style.fontFamily = baseStyle.fontFamily
27
+ contextMenu.style.top = posy+'px'
28
+ contextMenu.style.left = posx+'px'
29
+
30
+
31
+
32
+ const contextMenuList = document.createElement('ul')
33
+ contextMenuList.className = 'context-menu-items';
34
+ contextMenu.append(contextMenuList);
35
+
36
+ menuEntries.forEach(function(entry){
37
+ let contextMenuEntry = document.createElement('a')
38
+ contextMenuEntry.innerHTML = entry['name']
39
+ contextMenuEntry.addEventListener("click", function(e) {
40
+ entry['func']();
41
+ })
42
+ contextMenuList.append(contextMenuEntry);
43
+
44
+ })
45
+
46
+ gradioApp().getRootNode().appendChild(contextMenu)
47
+
48
+ let menuWidth = contextMenu.offsetWidth + 4;
49
+ let menuHeight = contextMenu.offsetHeight + 4;
50
+
51
+ let windowWidth = window.innerWidth;
52
+ let windowHeight = window.innerHeight;
53
+
54
+ if ( (windowWidth - posx) < menuWidth ) {
55
+ contextMenu.style.left = windowWidth - menuWidth + "px";
56
+ }
57
+
58
+ if ( (windowHeight - posy) < menuHeight ) {
59
+ contextMenu.style.top = windowHeight - menuHeight + "px";
60
+ }
61
+
62
+ }
63
+
64
+ function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
65
+
66
+ currentItems = menuSpecs.get(targetElementSelector)
67
+
68
+ if(!currentItems){
69
+ currentItems = []
70
+ menuSpecs.set(targetElementSelector,currentItems);
71
+ }
72
+ let newItem = {'id':targetElementSelector+'_'+uid(),
73
+ 'name':entryName,
74
+ 'func':entryFunction,
75
+ 'isNew':true}
76
+
77
+ currentItems.push(newItem)
78
+ return newItem['id']
79
+ }
80
+
81
+ function removeContextMenuOption(uid){
82
+ menuSpecs.forEach(function(v,k) {
83
+ let index = -1
84
+ v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
85
+ if(index>=0){
86
+ v.splice(index, 1);
87
+ }
88
+ })
89
+ }
90
+
91
+ function addContextMenuEventListener(){
92
+ if(eventListenerApplied){
93
+ return;
94
+ }
95
+ gradioApp().addEventListener("click", function(e) {
96
+ let source = e.composedPath()[0]
97
+ if(source.id && source.id.indexOf('check_progress')>-1){
98
+ return
99
+ }
100
+
101
+ let oldMenu = gradioApp().querySelector('#context-menu')
102
+ if(oldMenu){
103
+ oldMenu.remove()
104
+ }
105
+ });
106
+ gradioApp().addEventListener("contextmenu", function(e) {
107
+ let oldMenu = gradioApp().querySelector('#context-menu')
108
+ if(oldMenu){
109
+ oldMenu.remove()
110
+ }
111
+ menuSpecs.forEach(function(v,k) {
112
+ if(e.composedPath()[0].matches(k)){
113
+ showContextMenu(e,e.composedPath()[0],v)
114
+ e.preventDefault()
115
+ return
116
+ }
117
+ })
118
+ });
119
+ eventListenerApplied=true
120
+
121
+ }
122
+
123
+ return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
124
+ }
125
+
126
+ initResponse = contextMenuInit();
127
+ appendContextMenuOption = initResponse[0];
128
+ removeContextMenuOption = initResponse[1];
129
+ addContextMenuEventListener = initResponse[2];
130
+
131
+ (function(){
132
+ //Start example Context Menu Items
133
+ let generateOnRepeat = function(genbuttonid,interruptbuttonid){
134
+ let genbutton = gradioApp().querySelector(genbuttonid);
135
+ let interruptbutton = gradioApp().querySelector(interruptbuttonid);
136
+ if(!interruptbutton.offsetParent){
137
+ genbutton.click();
138
+ }
139
+ clearInterval(window.generateOnRepeatInterval)
140
+ window.generateOnRepeatInterval = setInterval(function(){
141
+ if(!interruptbutton.offsetParent){
142
+ genbutton.click();
143
+ }
144
+ },
145
+ 500)
146
+ }
147
+
148
+ appendContextMenuOption('#txt2img_generate','Generate forever',function(){
149
+ generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
150
+ })
151
+ appendContextMenuOption('#img2img_generate','Generate forever',function(){
152
+ generateOnRepeat('#img2img_generate','#img2img_interrupt');
153
+ })
154
+
155
+ let cancelGenerateForever = function(){
156
+ clearInterval(window.generateOnRepeatInterval)
157
+ }
158
+
159
+ appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
160
+ appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
161
+ appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
162
+ appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
163
+
164
+ appendContextMenuOption('#roll','Roll three',
165
+ function(){
166
+ let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
167
+ setTimeout(function(){rollbutton.click()},100)
168
+ setTimeout(function(){rollbutton.click()},200)
169
+ setTimeout(function(){rollbutton.click()},300)
170
+ }
171
+ )
172
+ })();
173
+ //End example Context Menu Items
174
+
175
+ onUiUpdate(function(){
176
+ addContextMenuEventListener()
177
+ });
sd/stable-diffusion-webui/javascript/edit-attention.js CHANGED
@@ -1,96 +1,96 @@
1
- function keyupEditAttention(event){
2
- let target = event.originalTarget || event.composedPath()[0];
3
- if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
4
- if (! (event.metaKey || event.ctrlKey)) return;
5
-
6
- let isPlus = event.key == "ArrowUp"
7
- let isMinus = event.key == "ArrowDown"
8
- if (!isPlus && !isMinus) return;
9
-
10
- let selectionStart = target.selectionStart;
11
- let selectionEnd = target.selectionEnd;
12
- let text = target.value;
13
-
14
- function selectCurrentParenthesisBlock(OPEN, CLOSE){
15
- if (selectionStart !== selectionEnd) return false;
16
-
17
- // Find opening parenthesis around current cursor
18
- const before = text.substring(0, selectionStart);
19
- let beforeParen = before.lastIndexOf(OPEN);
20
- if (beforeParen == -1) return false;
21
- let beforeParenClose = before.lastIndexOf(CLOSE);
22
- while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
23
- beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
24
- beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
25
- }
26
-
27
- // Find closing parenthesis around current cursor
28
- const after = text.substring(selectionStart);
29
- let afterParen = after.indexOf(CLOSE);
30
- if (afterParen == -1) return false;
31
- let afterParenOpen = after.indexOf(OPEN);
32
- while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
33
- afterParen = after.indexOf(CLOSE, afterParen + 1);
34
- afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
35
- }
36
- if (beforeParen === -1 || afterParen === -1) return false;
37
-
38
- // Set the selection to the text between the parenthesis
39
- const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
40
- const lastColon = parenContent.lastIndexOf(":");
41
- selectionStart = beforeParen + 1;
42
- selectionEnd = selectionStart + lastColon;
43
- target.setSelectionRange(selectionStart, selectionEnd);
44
- return true;
45
- }
46
-
47
- // If the user hasn't selected anything, let's select their current parenthesis block
48
- if(! selectCurrentParenthesisBlock('<', '>')){
49
- selectCurrentParenthesisBlock('(', ')')
50
- }
51
-
52
- event.preventDefault();
53
-
54
- closeCharacter = ')'
55
- delta = opts.keyedit_precision_attention
56
-
57
- if (selectionStart > 0 && text[selectionStart - 1] == '<'){
58
- closeCharacter = '>'
59
- delta = opts.keyedit_precision_extra
60
- } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
61
-
62
- // do not include spaces at the end
63
- while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
64
- selectionEnd -= 1;
65
- }
66
- if(selectionStart == selectionEnd){
67
- return
68
- }
69
-
70
- text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
71
-
72
- selectionStart += 1;
73
- selectionEnd += 1;
74
- }
75
-
76
- end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
77
- weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
78
- if (isNaN(weight)) return;
79
-
80
- weight += isPlus ? delta : -delta;
81
- weight = parseFloat(weight.toPrecision(12));
82
- if(String(weight).length == 1) weight += ".0"
83
-
84
- text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
85
-
86
- target.focus();
87
- target.value = text;
88
- target.selectionStart = selectionStart;
89
- target.selectionEnd = selectionEnd;
90
-
91
- updateInput(target)
92
- }
93
-
94
- addEventListener('keydown', (event) => {
95
- keyupEditAttention(event);
96
  });
 
1
+ function keyupEditAttention(event){
2
+ let target = event.originalTarget || event.composedPath()[0];
3
+ if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
4
+ if (! (event.metaKey || event.ctrlKey)) return;
5
+
6
+ let isPlus = event.key == "ArrowUp"
7
+ let isMinus = event.key == "ArrowDown"
8
+ if (!isPlus && !isMinus) return;
9
+
10
+ let selectionStart = target.selectionStart;
11
+ let selectionEnd = target.selectionEnd;
12
+ let text = target.value;
13
+
14
+ function selectCurrentParenthesisBlock(OPEN, CLOSE){
15
+ if (selectionStart !== selectionEnd) return false;
16
+
17
+ // Find opening parenthesis around current cursor
18
+ const before = text.substring(0, selectionStart);
19
+ let beforeParen = before.lastIndexOf(OPEN);
20
+ if (beforeParen == -1) return false;
21
+ let beforeParenClose = before.lastIndexOf(CLOSE);
22
+ while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
23
+ beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
24
+ beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
25
+ }
26
+
27
+ // Find closing parenthesis around current cursor
28
+ const after = text.substring(selectionStart);
29
+ let afterParen = after.indexOf(CLOSE);
30
+ if (afterParen == -1) return false;
31
+ let afterParenOpen = after.indexOf(OPEN);
32
+ while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
33
+ afterParen = after.indexOf(CLOSE, afterParen + 1);
34
+ afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
35
+ }
36
+ if (beforeParen === -1 || afterParen === -1) return false;
37
+
38
+ // Set the selection to the text between the parenthesis
39
+ const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
40
+ const lastColon = parenContent.lastIndexOf(":");
41
+ selectionStart = beforeParen + 1;
42
+ selectionEnd = selectionStart + lastColon;
43
+ target.setSelectionRange(selectionStart, selectionEnd);
44
+ return true;
45
+ }
46
+
47
+ // If the user hasn't selected anything, let's select their current parenthesis block
48
+ if(! selectCurrentParenthesisBlock('<', '>')){
49
+ selectCurrentParenthesisBlock('(', ')')
50
+ }
51
+
52
+ event.preventDefault();
53
+
54
+ closeCharacter = ')'
55
+ delta = opts.keyedit_precision_attention
56
+
57
+ if (selectionStart > 0 && text[selectionStart - 1] == '<'){
58
+ closeCharacter = '>'
59
+ delta = opts.keyedit_precision_extra
60
+ } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
61
+
62
+ // do not include spaces at the end
63
+ while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
64
+ selectionEnd -= 1;
65
+ }
66
+ if(selectionStart == selectionEnd){
67
+ return
68
+ }
69
+
70
+ text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
71
+
72
+ selectionStart += 1;
73
+ selectionEnd += 1;
74
+ }
75
+
76
+ end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
77
+ weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
78
+ if (isNaN(weight)) return;
79
+
80
+ weight += isPlus ? delta : -delta;
81
+ weight = parseFloat(weight.toPrecision(12));
82
+ if(String(weight).length == 1) weight += ".0"
83
+
84
+ text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
85
+
86
+ target.focus();
87
+ target.value = text;
88
+ target.selectionStart = selectionStart;
89
+ target.selectionEnd = selectionEnd;
90
+
91
+ updateInput(target)
92
+ }
93
+
94
+ addEventListener('keydown', (event) => {
95
+ keyupEditAttention(event);
96
  });
sd/stable-diffusion-webui/javascript/extensions.js CHANGED
@@ -1,49 +1,49 @@
1
-
2
- function extensions_apply(_, _){
3
- var disable = []
4
- var update = []
5
-
6
- gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
7
- if(x.name.startsWith("enable_") && ! x.checked)
8
- disable.push(x.name.substr(7))
9
-
10
- if(x.name.startsWith("update_") && x.checked)
11
- update.push(x.name.substr(7))
12
- })
13
-
14
- restart_reload()
15
-
16
- return [JSON.stringify(disable), JSON.stringify(update)]
17
- }
18
-
19
- function extensions_check(){
20
- var disable = []
21
-
22
- gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
23
- if(x.name.startsWith("enable_") && ! x.checked)
24
- disable.push(x.name.substr(7))
25
- })
26
-
27
- gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
28
- x.innerHTML = "Loading..."
29
- })
30
-
31
-
32
- var id = randomId()
33
- requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
34
-
35
- })
36
-
37
- return [id, JSON.stringify(disable)]
38
- }
39
-
40
- function install_extension_from_index(button, url){
41
- button.disabled = "disabled"
42
- button.value = "Installing..."
43
-
44
- textarea = gradioApp().querySelector('#extension_to_install textarea')
45
- textarea.value = url
46
- updateInput(textarea)
47
-
48
- gradioApp().querySelector('#install_extension_button').click()
49
- }
 
1
+
2
+ function extensions_apply(_, _){
3
+ var disable = []
4
+ var update = []
5
+
6
+ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
7
+ if(x.name.startsWith("enable_") && ! x.checked)
8
+ disable.push(x.name.substr(7))
9
+
10
+ if(x.name.startsWith("update_") && x.checked)
11
+ update.push(x.name.substr(7))
12
+ })
13
+
14
+ restart_reload()
15
+
16
+ return [JSON.stringify(disable), JSON.stringify(update)]
17
+ }
18
+
19
+ function extensions_check(){
20
+ var disable = []
21
+
22
+ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
23
+ if(x.name.startsWith("enable_") && ! x.checked)
24
+ disable.push(x.name.substr(7))
25
+ })
26
+
27
+ gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
28
+ x.innerHTML = "Loading..."
29
+ })
30
+
31
+
32
+ var id = randomId()
33
+ requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
34
+
35
+ })
36
+
37
+ return [id, JSON.stringify(disable)]
38
+ }
39
+
40
+ function install_extension_from_index(button, url){
41
+ button.disabled = "disabled"
42
+ button.value = "Installing..."
43
+
44
+ textarea = gradioApp().querySelector('#extension_to_install textarea')
45
+ textarea.value = url
46
+ updateInput(textarea)
47
+
48
+ gradioApp().querySelector('#install_extension_button').click()
49
+ }
sd/stable-diffusion-webui/javascript/extraNetworks.js CHANGED
@@ -1,107 +1,107 @@
1
-
2
- function setupExtraNetworksForTab(tabname){
3
- gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
4
-
5
- var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
6
- var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
7
- var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
8
- var close = gradioApp().getElementById(tabname+'_extra_close')
9
-
10
- search.classList.add('search')
11
- tabs.appendChild(search)
12
- tabs.appendChild(refresh)
13
- tabs.appendChild(close)
14
-
15
- search.addEventListener("input", function(evt){
16
- searchTerm = search.value.toLowerCase()
17
-
18
- gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
19
- text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
20
- elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
21
- })
22
- });
23
- }
24
-
25
- var activePromptTextarea = {};
26
-
27
- function setupExtraNetworks(){
28
- setupExtraNetworksForTab('txt2img')
29
- setupExtraNetworksForTab('img2img')
30
-
31
- function registerPrompt(tabname, id){
32
- var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
33
-
34
- if (! activePromptTextarea[tabname]){
35
- activePromptTextarea[tabname] = textarea
36
- }
37
-
38
- textarea.addEventListener("focus", function(){
39
- activePromptTextarea[tabname] = textarea;
40
- });
41
- }
42
-
43
- registerPrompt('txt2img', 'txt2img_prompt')
44
- registerPrompt('txt2img', 'txt2img_neg_prompt')
45
- registerPrompt('img2img', 'img2img_prompt')
46
- registerPrompt('img2img', 'img2img_neg_prompt')
47
- }
48
-
49
- onUiLoaded(setupExtraNetworks)
50
-
51
- var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
52
- var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
53
-
54
- function tryToRemoveExtraNetworkFromPrompt(textarea, text){
55
- var m = text.match(re_extranet)
56
- if(! m) return false
57
-
58
- var partToSearch = m[1]
59
- var replaced = false
60
- var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
61
- m = found.match(re_extranet);
62
- if(m[1] == partToSearch){
63
- replaced = true;
64
- return ""
65
- }
66
- return found;
67
- })
68
-
69
- if(replaced){
70
- textarea.value = newTextareaText
71
- return true;
72
- }
73
-
74
- return false
75
- }
76
-
77
- function cardClicked(tabname, textToAdd, allowNegativePrompt){
78
- var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
79
-
80
- if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
81
- textarea.value = textarea.value + " " + textToAdd
82
- }
83
-
84
- updateInput(textarea)
85
- }
86
-
87
- function saveCardPreview(event, tabname, filename){
88
- var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
89
- var button = gradioApp().getElementById(tabname + '_save_preview')
90
-
91
- textarea.value = filename
92
- updateInput(textarea)
93
-
94
- button.click()
95
-
96
- event.stopPropagation()
97
- event.preventDefault()
98
- }
99
-
100
- function extraNetworksSearchButton(tabs_id, event){
101
- searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
102
- button = event.target
103
- text = button.classList.contains("search-all") ? "" : button.textContent.trim()
104
-
105
- searchTextarea.value = text
106
- updateInput(searchTextarea)
107
  }
 
1
+
2
+ function setupExtraNetworksForTab(tabname){
3
+ gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
4
+
5
+ var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
6
+ var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
7
+ var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
8
+ var close = gradioApp().getElementById(tabname+'_extra_close')
9
+
10
+ search.classList.add('search')
11
+ tabs.appendChild(search)
12
+ tabs.appendChild(refresh)
13
+ tabs.appendChild(close)
14
+
15
+ search.addEventListener("input", function(evt){
16
+ searchTerm = search.value.toLowerCase()
17
+
18
+ gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
19
+ text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
20
+ elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
21
+ })
22
+ });
23
+ }
24
+
25
+ var activePromptTextarea = {};
26
+
27
+ function setupExtraNetworks(){
28
+ setupExtraNetworksForTab('txt2img')
29
+ setupExtraNetworksForTab('img2img')
30
+
31
+ function registerPrompt(tabname, id){
32
+ var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
33
+
34
+ if (! activePromptTextarea[tabname]){
35
+ activePromptTextarea[tabname] = textarea
36
+ }
37
+
38
+ textarea.addEventListener("focus", function(){
39
+ activePromptTextarea[tabname] = textarea;
40
+ });
41
+ }
42
+
43
+ registerPrompt('txt2img', 'txt2img_prompt')
44
+ registerPrompt('txt2img', 'txt2img_neg_prompt')
45
+ registerPrompt('img2img', 'img2img_prompt')
46
+ registerPrompt('img2img', 'img2img_neg_prompt')
47
+ }
48
+
49
+ onUiLoaded(setupExtraNetworks)
50
+
51
+ var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
52
+ var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
53
+
54
+ function tryToRemoveExtraNetworkFromPrompt(textarea, text){
55
+ var m = text.match(re_extranet)
56
+ if(! m) return false
57
+
58
+ var partToSearch = m[1]
59
+ var replaced = false
60
+ var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
61
+ m = found.match(re_extranet);
62
+ if(m[1] == partToSearch){
63
+ replaced = true;
64
+ return ""
65
+ }
66
+ return found;
67
+ })
68
+
69
+ if(replaced){
70
+ textarea.value = newTextareaText
71
+ return true;
72
+ }
73
+
74
+ return false
75
+ }
76
+
77
+ function cardClicked(tabname, textToAdd, allowNegativePrompt){
78
+ var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
79
+
80
+ if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
81
+ textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd
82
+ }
83
+
84
+ updateInput(textarea)
85
+ }
86
+
87
+ function saveCardPreview(event, tabname, filename){
88
+ var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
89
+ var button = gradioApp().getElementById(tabname + '_save_preview')
90
+
91
+ textarea.value = filename
92
+ updateInput(textarea)
93
+
94
+ button.click()
95
+
96
+ event.stopPropagation()
97
+ event.preventDefault()
98
+ }
99
+
100
+ function extraNetworksSearchButton(tabs_id, event){
101
+ searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
102
+ button = event.target
103
+ text = button.classList.contains("search-all") ? "" : button.textContent.trim()
104
+
105
+ searchTextarea.value = text
106
+ updateInput(searchTextarea)
107
  }
sd/stable-diffusion-webui/javascript/hints.js CHANGED
@@ -6,6 +6,7 @@ titles = {
6
  "GFPGAN": "Restore low quality faces using GFPGAN neural network",
7
  "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
8
  "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
 
9
  "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
10
 
11
  "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
 
6
  "GFPGAN": "Restore low quality faces using GFPGAN neural network",
7
  "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
8
  "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
9
+ "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
10
  "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
11
 
12
  "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
sd/stable-diffusion-webui/javascript/hires_fix.js CHANGED
@@ -1,22 +1,22 @@
1
-
2
- function setInactive(elem, inactive){
3
- if(inactive){
4
- elem.classList.add('inactive')
5
- } else{
6
- elem.classList.remove('inactive')
7
- }
8
- }
9
-
10
- function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
11
- hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
12
- hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
13
- hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
14
-
15
- gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
16
-
17
- setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
18
- setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
19
- setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
20
-
21
- return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
22
- }
 
1
+
2
+ function setInactive(elem, inactive){
3
+ if(inactive){
4
+ elem.classList.add('inactive')
5
+ } else{
6
+ elem.classList.remove('inactive')
7
+ }
8
+ }
9
+
10
+ function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
11
+ hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
12
+ hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
13
+ hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
14
+
15
+ gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
16
+
17
+ setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
18
+ setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
19
+ setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
20
+
21
+ return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
22
+ }
sd/stable-diffusion-webui/javascript/localization.js CHANGED
@@ -1,165 +1,165 @@
1
-
2
- // localization = {} -- the dict with translations is created by the backend
3
-
4
- ignore_ids_for_localization={
5
- setting_sd_hypernetwork: 'OPTION',
6
- setting_sd_model_checkpoint: 'OPTION',
7
- setting_realesrgan_enabled_models: 'OPTION',
8
- modelmerger_primary_model_name: 'OPTION',
9
- modelmerger_secondary_model_name: 'OPTION',
10
- modelmerger_tertiary_model_name: 'OPTION',
11
- train_embedding: 'OPTION',
12
- train_hypernetwork: 'OPTION',
13
- txt2img_styles: 'OPTION',
14
- img2img_styles: 'OPTION',
15
- setting_random_artist_categories: 'SPAN',
16
- setting_face_restoration_model: 'SPAN',
17
- setting_realesrgan_enabled_models: 'SPAN',
18
- extras_upscaler_1: 'SPAN',
19
- extras_upscaler_2: 'SPAN',
20
- }
21
-
22
- re_num = /^[\.\d]+$/
23
- re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
24
-
25
- original_lines = {}
26
- translated_lines = {}
27
-
28
- function textNodesUnder(el){
29
- var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
30
- while(n=walk.nextNode()) a.push(n);
31
- return a;
32
- }
33
-
34
- function canBeTranslated(node, text){
35
- if(! text) return false;
36
- if(! node.parentElement) return false;
37
-
38
- parentType = node.parentElement.nodeName
39
- if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
40
-
41
- if (parentType=='OPTION' || parentType=='SPAN'){
42
- pnode = node
43
- for(var level=0; level<4; level++){
44
- pnode = pnode.parentElement
45
- if(! pnode) break;
46
-
47
- if(ignore_ids_for_localization[pnode.id] == parentType) return false;
48
- }
49
- }
50
-
51
- if(re_num.test(text)) return false;
52
- if(re_emoji.test(text)) return false;
53
- return true
54
- }
55
-
56
- function getTranslation(text){
57
- if(! text) return undefined
58
-
59
- if(translated_lines[text] === undefined){
60
- original_lines[text] = 1
61
- }
62
-
63
- tl = localization[text]
64
- if(tl !== undefined){
65
- translated_lines[tl] = 1
66
- }
67
-
68
- return tl
69
- }
70
-
71
- function processTextNode(node){
72
- text = node.textContent.trim()
73
-
74
- if(! canBeTranslated(node, text)) return
75
-
76
- tl = getTranslation(text)
77
- if(tl !== undefined){
78
- node.textContent = tl
79
- }
80
- }
81
-
82
- function processNode(node){
83
- if(node.nodeType == 3){
84
- processTextNode(node)
85
- return
86
- }
87
-
88
- if(node.title){
89
- tl = getTranslation(node.title)
90
- if(tl !== undefined){
91
- node.title = tl
92
- }
93
- }
94
-
95
- if(node.placeholder){
96
- tl = getTranslation(node.placeholder)
97
- if(tl !== undefined){
98
- node.placeholder = tl
99
- }
100
- }
101
-
102
- textNodesUnder(node).forEach(function(node){
103
- processTextNode(node)
104
- })
105
- }
106
-
107
- function dumpTranslations(){
108
- dumped = {}
109
- if (localization.rtl) {
110
- dumped.rtl = true
111
- }
112
-
113
- Object.keys(original_lines).forEach(function(text){
114
- if(dumped[text] !== undefined) return
115
-
116
- dumped[text] = localization[text] || text
117
- })
118
-
119
- return dumped
120
- }
121
-
122
- onUiUpdate(function(m){
123
- m.forEach(function(mutation){
124
- mutation.addedNodes.forEach(function(node){
125
- processNode(node)
126
- })
127
- });
128
- })
129
-
130
-
131
- document.addEventListener("DOMContentLoaded", function() {
132
- processNode(gradioApp())
133
-
134
- if (localization.rtl) { // if the language is from right to left,
135
- (new MutationObserver((mutations, observer) => { // wait for the style to load
136
- mutations.forEach(mutation => {
137
- mutation.addedNodes.forEach(node => {
138
- if (node.tagName === 'STYLE') {
139
- observer.disconnect();
140
-
141
- for (const x of node.sheet.rules) { // find all rtl media rules
142
- if (Array.from(x.media || []).includes('rtl')) {
143
- x.media.appendMedium('all'); // enable them
144
- }
145
- }
146
- }
147
- })
148
- });
149
- })).observe(gradioApp(), { childList: true });
150
- }
151
- })
152
-
153
- function download_localization() {
154
- text = JSON.stringify(dumpTranslations(), null, 4)
155
-
156
- var element = document.createElement('a');
157
- element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
158
- element.setAttribute('download', "localization.json");
159
- element.style.display = 'none';
160
- document.body.appendChild(element);
161
-
162
- element.click();
163
-
164
- document.body.removeChild(element);
165
- }
 
1
+
2
+ // localization = {} -- the dict with translations is created by the backend
3
+
4
+ ignore_ids_for_localization={
5
+ setting_sd_hypernetwork: 'OPTION',
6
+ setting_sd_model_checkpoint: 'OPTION',
7
+ setting_realesrgan_enabled_models: 'OPTION',
8
+ modelmerger_primary_model_name: 'OPTION',
9
+ modelmerger_secondary_model_name: 'OPTION',
10
+ modelmerger_tertiary_model_name: 'OPTION',
11
+ train_embedding: 'OPTION',
12
+ train_hypernetwork: 'OPTION',
13
+ txt2img_styles: 'OPTION',
14
+ img2img_styles: 'OPTION',
15
+ setting_random_artist_categories: 'SPAN',
16
+ setting_face_restoration_model: 'SPAN',
17
+ setting_realesrgan_enabled_models: 'SPAN',
18
+ extras_upscaler_1: 'SPAN',
19
+ extras_upscaler_2: 'SPAN',
20
+ }
21
+
22
+ re_num = /^[\.\d]+$/
23
+ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
24
+
25
+ original_lines = {}
26
+ translated_lines = {}
27
+
28
+ function textNodesUnder(el){
29
+ var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
30
+ while(n=walk.nextNode()) a.push(n);
31
+ return a;
32
+ }
33
+
34
+ function canBeTranslated(node, text){
35
+ if(! text) return false;
36
+ if(! node.parentElement) return false;
37
+
38
+ parentType = node.parentElement.nodeName
39
+ if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
40
+
41
+ if (parentType=='OPTION' || parentType=='SPAN'){
42
+ pnode = node
43
+ for(var level=0; level<4; level++){
44
+ pnode = pnode.parentElement
45
+ if(! pnode) break;
46
+
47
+ if(ignore_ids_for_localization[pnode.id] == parentType) return false;
48
+ }
49
+ }
50
+
51
+ if(re_num.test(text)) return false;
52
+ if(re_emoji.test(text)) return false;
53
+ return true
54
+ }
55
+
56
+ function getTranslation(text){
57
+ if(! text) return undefined
58
+
59
+ if(translated_lines[text] === undefined){
60
+ original_lines[text] = 1
61
+ }
62
+
63
+ tl = localization[text]
64
+ if(tl !== undefined){
65
+ translated_lines[tl] = 1
66
+ }
67
+
68
+ return tl
69
+ }
70
+
71
+ function processTextNode(node){
72
+ text = node.textContent.trim()
73
+
74
+ if(! canBeTranslated(node, text)) return
75
+
76
+ tl = getTranslation(text)
77
+ if(tl !== undefined){
78
+ node.textContent = tl
79
+ }
80
+ }
81
+
82
+ function processNode(node){
83
+ if(node.nodeType == 3){
84
+ processTextNode(node)
85
+ return
86
+ }
87
+
88
+ if(node.title){
89
+ tl = getTranslation(node.title)
90
+ if(tl !== undefined){
91
+ node.title = tl
92
+ }
93
+ }
94
+
95
+ if(node.placeholder){
96
+ tl = getTranslation(node.placeholder)
97
+ if(tl !== undefined){
98
+ node.placeholder = tl
99
+ }
100
+ }
101
+
102
+ textNodesUnder(node).forEach(function(node){
103
+ processTextNode(node)
104
+ })
105
+ }
106
+
107
+ function dumpTranslations(){
108
+ dumped = {}
109
+ if (localization.rtl) {
110
+ dumped.rtl = true
111
+ }
112
+
113
+ Object.keys(original_lines).forEach(function(text){
114
+ if(dumped[text] !== undefined) return
115
+
116
+ dumped[text] = localization[text] || text
117
+ })
118
+
119
+ return dumped
120
+ }
121
+
122
+ onUiUpdate(function(m){
123
+ m.forEach(function(mutation){
124
+ mutation.addedNodes.forEach(function(node){
125
+ processNode(node)
126
+ })
127
+ });
128
+ })
129
+
130
+
131
+ document.addEventListener("DOMContentLoaded", function() {
132
+ processNode(gradioApp())
133
+
134
+ if (localization.rtl) { // if the language is from right to left,
135
+ (new MutationObserver((mutations, observer) => { // wait for the style to load
136
+ mutations.forEach(mutation => {
137
+ mutation.addedNodes.forEach(node => {
138
+ if (node.tagName === 'STYLE') {
139
+ observer.disconnect();
140
+
141
+ for (const x of node.sheet.rules) { // find all rtl media rules
142
+ if (Array.from(x.media || []).includes('rtl')) {
143
+ x.media.appendMedium('all'); // enable them
144
+ }
145
+ }
146
+ }
147
+ })
148
+ });
149
+ })).observe(gradioApp(), { childList: true });
150
+ }
151
+ })
152
+
153
+ function download_localization() {
154
+ text = JSON.stringify(dumpTranslations(), null, 4)
155
+
156
+ var element = document.createElement('a');
157
+ element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
158
+ element.setAttribute('download', "localization.json");
159
+ element.style.display = 'none';
160
+ document.body.appendChild(element);
161
+
162
+ element.click();
163
+
164
+ document.body.removeChild(element);
165
+ }
sd/stable-diffusion-webui/javascript/notification.js CHANGED
@@ -15,7 +15,7 @@ onUiUpdate(function(){
15
  }
16
  }
17
 
18
- const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
19
 
20
  if (galleryPreviews == null) return;
21
 
 
15
  }
16
  }
17
 
18
+ const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] img.h-full.w-full.overflow-hidden');
19
 
20
  if (galleryPreviews == null) return;
21
 
sd/stable-diffusion-webui/javascript/progressbar.js CHANGED
@@ -139,7 +139,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
139
 
140
  var divProgress = document.createElement('div')
141
  divProgress.className='progressDiv'
142
- divProgress.style.display = opts.show_progressbar ? "" : "none"
143
  var divInner = document.createElement('div')
144
  divInner.className='progress'
145
 
 
139
 
140
  var divProgress = document.createElement('div')
141
  divProgress.className='progressDiv'
142
+ divProgress.style.display = opts.show_progressbar ? "block" : "none"
143
  var divInner = document.createElement('div')
144
  divInner.className='progress'
145
 
sd/stable-diffusion-webui/javascript/textualInversion.js CHANGED
@@ -1,17 +1,17 @@
1
-
2
-
3
-
4
- function start_training_textual_inversion(){
5
- gradioApp().querySelector('#ti_error').innerHTML=''
6
-
7
- var id = randomId()
8
- requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
9
- gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
10
- })
11
-
12
- var res = args_to_array(arguments)
13
-
14
- res[0] = id
15
-
16
- return res
17
- }
 
1
+
2
+
3
+
4
+ function start_training_textual_inversion(){
5
+ gradioApp().querySelector('#ti_error').innerHTML=''
6
+
7
+ var id = randomId()
8
+ requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
9
+ gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
10
+ })
11
+
12
+ var res = args_to_array(arguments)
13
+
14
+ res[0] = id
15
+
16
+ return res
17
+ }
sd/stable-diffusion-webui/launch.py CHANGED
@@ -1,361 +1,375 @@
1
- # this script installs necessary requirements and launches main program in webui.py
2
- import subprocess
3
- import os
4
- import sys
5
- import importlib.util
6
- import shlex
7
- import platform
8
- import argparse
9
- import json
10
-
11
- dir_repos = "repositories"
12
- dir_extensions = "extensions"
13
- python = sys.executable
14
- git = os.environ.get('GIT', "git")
15
- index_url = os.environ.get('INDEX_URL', "")
16
- stored_commit_hash = None
17
- skip_install = False
18
-
19
-
20
- def check_python_version():
21
- is_windows = platform.system() == "Windows"
22
- major = sys.version_info.major
23
- minor = sys.version_info.minor
24
- micro = sys.version_info.micro
25
-
26
- if is_windows:
27
- supported_minors = [10]
28
- else:
29
- supported_minors = [7, 8, 9, 10, 11]
30
-
31
- if not (major == 3 and minor in supported_minors):
32
- import modules.errors
33
-
34
- modules.errors.print_error_explanation(f"""
35
- INCOMPATIBLE PYTHON VERSION
36
-
37
- This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
38
- If you encounter an error with "RuntimeError: Couldn't install torch." message,
39
- or any other error regarding unsuccessful package (library) installation,
40
- please downgrade (or upgrade) to the latest version of 3.10 Python
41
- and delete current Python and "venv" folder in WebUI's directory.
42
-
43
- You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
44
-
45
- {"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
46
-
47
- Use --skip-python-version-check to suppress this warning.
48
- """)
49
-
50
-
51
- def commit_hash():
52
- global stored_commit_hash
53
-
54
- if stored_commit_hash is not None:
55
- return stored_commit_hash
56
-
57
- try:
58
- stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
59
- except Exception:
60
- stored_commit_hash = "<none>"
61
-
62
- return stored_commit_hash
63
-
64
-
65
- def extract_arg(args, name):
66
- return [x for x in args if x != name], name in args
67
-
68
-
69
- def extract_opt(args, name):
70
- opt = None
71
- is_present = False
72
- if name in args:
73
- is_present = True
74
- idx = args.index(name)
75
- del args[idx]
76
- if idx < len(args) and args[idx][0] != "-":
77
- opt = args[idx]
78
- del args[idx]
79
- return args, is_present, opt
80
-
81
-
82
- def run(command, desc=None, errdesc=None, custom_env=None, live=False):
83
- if desc is not None:
84
- print(desc)
85
-
86
- if live:
87
- result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
88
- if result.returncode != 0:
89
- raise RuntimeError(f"""{errdesc or 'Error running command'}.
90
- Command: {command}
91
- Error code: {result.returncode}""")
92
-
93
- return ""
94
-
95
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
96
-
97
- if result.returncode != 0:
98
-
99
- message = f"""{errdesc or 'Error running command'}.
100
- Command: {command}
101
- Error code: {result.returncode}
102
- stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
103
- stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
104
- """
105
- raise RuntimeError(message)
106
-
107
- return result.stdout.decode(encoding="utf8", errors="ignore")
108
-
109
-
110
- def check_run(command):
111
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
112
- return result.returncode == 0
113
-
114
-
115
- def is_installed(package):
116
- try:
117
- spec = importlib.util.find_spec(package)
118
- except ModuleNotFoundError:
119
- return False
120
-
121
- return spec is not None
122
-
123
-
124
- def repo_dir(name):
125
- return os.path.join(dir_repos, name)
126
-
127
-
128
- def run_python(code, desc=None, errdesc=None):
129
- return run(f'"{python}" -c "{code}"', desc, errdesc)
130
-
131
-
132
- def run_pip(args, desc=None):
133
- if skip_install:
134
- return
135
-
136
- index_url_line = f' --index-url {index_url}' if index_url != '' else ''
137
- return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
138
-
139
-
140
- def check_run_python(code):
141
- return check_run(f'"{python}" -c "{code}"')
142
-
143
-
144
- def git_clone(url, dir, name, commithash=None):
145
- # TODO clone into temporary dir and move if successful
146
-
147
- if os.path.exists(dir):
148
- if commithash is None:
149
- return
150
-
151
- current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
152
- if current_hash == commithash:
153
- return
154
-
155
- run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
156
- run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
157
- return
158
-
159
- run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
160
-
161
- if commithash is not None:
162
- run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
163
-
164
-
165
- def version_check(commit):
166
- try:
167
- import requests
168
- commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
169
- if commit != "<none>" and commits['commit']['sha'] != commit:
170
- print("--------------------------------------------------------")
171
- print("| You are not up to date with the most recent release. |")
172
- print("| Consider running `git pull` to update. |")
173
- print("--------------------------------------------------------")
174
- elif commits['commit']['sha'] == commit:
175
- print("You are up to date with the most recent release.")
176
- else:
177
- print("Not a git clone, can't perform version check.")
178
- except Exception as e:
179
- print("version check failed", e)
180
-
181
-
182
- def run_extension_installer(extension_dir):
183
- path_installer = os.path.join(extension_dir, "install.py")
184
- if not os.path.isfile(path_installer):
185
- return
186
-
187
- try:
188
- env = os.environ.copy()
189
- env['PYTHONPATH'] = os.path.abspath(".")
190
-
191
- print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
192
- except Exception as e:
193
- print(e, file=sys.stderr)
194
-
195
-
196
- def list_extensions(settings_file):
197
- settings = {}
198
-
199
- try:
200
- if os.path.isfile(settings_file):
201
- with open(settings_file, "r", encoding="utf8") as file:
202
- settings = json.load(file)
203
- except Exception as e:
204
- print(e, file=sys.stderr)
205
-
206
- disabled_extensions = set(settings.get('disabled_extensions', []))
207
-
208
- return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
209
-
210
-
211
- def run_extensions_installers(settings_file):
212
- if not os.path.isdir(dir_extensions):
213
- return
214
-
215
- for dirname_extension in list_extensions(settings_file):
216
- run_extension_installer(os.path.join(dir_extensions, dirname_extension))
217
-
218
-
219
- def prepare_environment():
220
- global skip_install
221
-
222
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
223
- requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
224
- commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
225
-
226
- xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
227
- gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
228
- clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
229
- openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
230
-
231
- stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
232
- taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
233
- k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
234
- codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
235
- blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
236
-
237
- stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
238
- taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
239
- k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
240
- codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
241
- blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
242
-
243
- sys.argv += shlex.split(commandline_args)
244
-
245
- parser = argparse.ArgumentParser(add_help=False)
246
- parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
247
- args, _ = parser.parse_known_args(sys.argv)
248
-
249
- sys.argv, _ = extract_arg(sys.argv, '-f')
250
- sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
251
- sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
252
- sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
253
- sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
254
- sys.argv, update_check = extract_arg(sys.argv, '--update-check')
255
- sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
256
- sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
257
- xformers = '--xformers' in sys.argv
258
- ngrok = '--ngrok' in sys.argv
259
-
260
- if not skip_python_version_check:
261
- check_python_version()
262
-
263
- commit = commit_hash()
264
-
265
- print(f"Python {sys.version}")
266
- print(f"Commit hash: {commit}")
267
-
268
- if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
269
- run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
270
-
271
- if not skip_torch_cuda_test:
272
- run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
273
-
274
- if not is_installed("gfpgan"):
275
- run_pip(f"install {gfpgan_package}", "gfpgan")
276
-
277
- if not is_installed("clip"):
278
- run_pip(f"install {clip_package}", "clip")
279
-
280
- if not is_installed("open_clip"):
281
- run_pip(f"install {openclip_package}", "open_clip")
282
-
283
- if (not is_installed("xformers") or reinstall_xformers) and xformers:
284
- if platform.system() == "Windows":
285
- if platform.python_version().startswith("3.10"):
286
- run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
287
- else:
288
- print("Installation of xformers is not supported in this version of Python.")
289
- print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
290
- if not is_installed("xformers"):
291
- exit(0)
292
- elif platform.system() == "Linux":
293
- run_pip(f"install {xformers_package}", "xformers")
294
-
295
- if not is_installed("pyngrok") and ngrok:
296
- run_pip("install pyngrok", "ngrok")
297
-
298
- os.makedirs(dir_repos, exist_ok=True)
299
-
300
- git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
301
- git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
302
- git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
303
- git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
304
- git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
305
-
306
- if not is_installed("lpips"):
307
- run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
308
-
309
- run_pip(f"install -r {requirements_file}", "requirements for Web UI")
310
-
311
- run_extensions_installers(settings_file=args.ui_settings_file)
312
-
313
- if update_check:
314
- version_check(commit)
315
-
316
- if "--exit" in sys.argv:
317
- print("Exiting because of --exit argument")
318
- exit(0)
319
-
320
- if run_tests:
321
- exitcode = tests(test_dir)
322
- exit(exitcode)
323
-
324
-
325
- def tests(test_dir):
326
- if "--api" not in sys.argv:
327
- sys.argv.append("--api")
328
- if "--ckpt" not in sys.argv:
329
- sys.argv.append("--ckpt")
330
- sys.argv.append("./test/test_files/empty.pt")
331
- if "--skip-torch-cuda-test" not in sys.argv:
332
- sys.argv.append("--skip-torch-cuda-test")
333
- if "--disable-nan-check" not in sys.argv:
334
- sys.argv.append("--disable-nan-check")
335
-
336
- print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
337
-
338
- os.environ['COMMANDLINE_ARGS'] = ""
339
- with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
340
- proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
341
-
342
- import test.server_poll
343
- exitcode = test.server_poll.run_tests(proc, test_dir)
344
-
345
- print(f"Stopping Web UI process with id {proc.pid}")
346
- proc.kill()
347
- return exitcode
348
-
349
-
350
- def start():
351
- print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
352
- import webui
353
- if '--nowebui' in sys.argv:
354
- webui.api_only()
355
- else:
356
- webui.webui()
357
-
358
-
359
- if __name__ == "__main__":
360
- prepare_environment()
361
- start()

1
+ # this script installs necessary requirements and launches main program in webui.py
2
+ import subprocess
3
+ import os
4
+ import sys
5
+ import importlib.util
6
+ import shlex
7
+ import platform
8
+ import argparse
9
+ import json
10
+
11
+ dir_repos = "repositories"
12
+ dir_extensions = "extensions"
13
+ python = sys.executable
14
+ git = os.environ.get('GIT', "git")
15
+ index_url = os.environ.get('INDEX_URL', "")
16
+ stored_commit_hash = None
17
+ skip_install = False
18
+
19
+
20
+ def check_python_version():
21
+ is_windows = platform.system() == "Windows"
22
+ major = sys.version_info.major
23
+ minor = sys.version_info.minor
24
+ micro = sys.version_info.micro
25
+
26
+ if is_windows:
27
+ supported_minors = [10]
28
+ else:
29
+ supported_minors = [7, 8, 9, 10, 11]
30
+
31
+ if not (major == 3 and minor in supported_minors):
32
+ import modules.errors
33
+
34
+ modules.errors.print_error_explanation(f"""
35
+ INCOMPATIBLE PYTHON VERSION
36
+
37
+ This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
38
+ If you encounter an error with "RuntimeError: Couldn't install torch." message,
39
+ or any other error regarding unsuccessful package (library) installation,
40
+ please downgrade (or upgrade) to the latest version of 3.10 Python
41
+ and delete current Python and "venv" folder in WebUI's directory.
42
+
43
+ You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
44
+
45
+ {"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
46
+
47
+ Use --skip-python-version-check to suppress this warning.
48
+ """)
49
+
50
+
51
+ def commit_hash():
52
+ global stored_commit_hash
53
+
54
+ if stored_commit_hash is not None:
55
+ return stored_commit_hash
56
+
57
+ try:
58
+ stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
59
+ except Exception:
60
+ stored_commit_hash = "<none>"
61
+
62
+ return stored_commit_hash
63
+
64
+
65
+ def extract_arg(args, name):
66
+ return [x for x in args if x != name], name in args
67
+
68
+
69
+ def extract_opt(args, name):
70
+ opt = None
71
+ is_present = False
72
+ if name in args:
73
+ is_present = True
74
+ idx = args.index(name)
75
+ del args[idx]
76
+ if idx < len(args) and args[idx][0] != "-":
77
+ opt = args[idx]
78
+ del args[idx]
79
+ return args, is_present, opt
80
+
81
+
82
+ def run(command, desc=None, errdesc=None, custom_env=None, live=False):
83
+ if desc is not None:
84
+ print(desc)
85
+
86
+ if live:
87
+ result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
88
+ if result.returncode != 0:
89
+ raise RuntimeError(f"""{errdesc or 'Error running command'}.
90
+ Command: {command}
91
+ Error code: {result.returncode}""")
92
+
93
+ return ""
94
+
95
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
96
+
97
+ if result.returncode != 0:
98
+
99
+ message = f"""{errdesc or 'Error running command'}.
100
+ Command: {command}
101
+ Error code: {result.returncode}
102
+ stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
103
+ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
104
+ """
105
+ raise RuntimeError(message)
106
+
107
+ return result.stdout.decode(encoding="utf8", errors="ignore")
108
+
109
+
110
+ def check_run(command):
111
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
112
+ return result.returncode == 0
113
+
114
+
115
+ def is_installed(package):
116
+ try:
117
+ spec = importlib.util.find_spec(package)
118
+ except ModuleNotFoundError:
119
+ return False
120
+
121
+ return spec is not None
122
+
123
+
124
+ def repo_dir(name):
125
+ return os.path.join(dir_repos, name)
126
+
127
+
128
+ def run_python(code, desc=None, errdesc=None):
129
+ return run(f'"{python}" -c "{code}"', desc, errdesc)
130
+
131
+
132
+ def run_pip(args, desc=None):
133
+ if skip_install:
134
+ return
135
+
136
+ index_url_line = f' --index-url {index_url}' if index_url != '' else ''
137
+ return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
138
+
139
+
140
+ def check_run_python(code):
141
+ return check_run(f'"{python}" -c "{code}"')
142
+
143
+
144
+ def git_clone(url, dir, name, commithash=None):
145
+ # TODO clone into temporary dir and move if successful
146
+
147
+ if os.path.exists(dir):
148
+ if commithash is None:
149
+ return
150
+
151
+ current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
152
+ if current_hash == commithash:
153
+ return
154
+
155
+ run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
156
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
157
+ return
158
+
159
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
160
+
161
+ if commithash is not None:
162
+ run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
163
+
164
+
165
+ def git_pull_recursive(dir):
166
+ for subdir, _, _ in os.walk(dir):
167
+ if os.path.exists(os.path.join(subdir, '.git')):
168
+ try:
169
+ output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
170
+ print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
171
+ except subprocess.CalledProcessError as e:
172
+ print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
173
+
174
+
175
+ def version_check(commit):
176
+ try:
177
+ import requests
178
+ commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
179
+ if commit != "<none>" and commits['commit']['sha'] != commit:
180
+ print("--------------------------------------------------------")
181
+ print("| You are not up to date with the most recent release. |")
182
+ print("| Consider running `git pull` to update. |")
183
+ print("--------------------------------------------------------")
184
+ elif commits['commit']['sha'] == commit:
185
+ print("You are up to date with the most recent release.")
186
+ else:
187
+ print("Not a git clone, can't perform version check.")
188
+ except Exception as e:
189
+ print("version check failed", e)
190
+
191
+
192
+ def run_extension_installer(extension_dir):
193
+ path_installer = os.path.join(extension_dir, "install.py")
194
+ if not os.path.isfile(path_installer):
195
+ return
196
+
197
+ try:
198
+ env = os.environ.copy()
199
+ env['PYTHONPATH'] = os.path.abspath(".")
200
+
201
+ print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
202
+ except Exception as e:
203
+ print(e, file=sys.stderr)
204
+
205
+
206
+ def list_extensions(settings_file):
207
+ settings = {}
208
+
209
+ try:
210
+ if os.path.isfile(settings_file):
211
+ with open(settings_file, "r", encoding="utf8") as file:
212
+ settings = json.load(file)
213
+ except Exception as e:
214
+ print(e, file=sys.stderr)
215
+
216
+ disabled_extensions = set(settings.get('disabled_extensions', []))
217
+
218
+ return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
219
+
220
+
221
+ def run_extensions_installers(settings_file):
222
+ if not os.path.isdir(dir_extensions):
223
+ return
224
+
225
+ for dirname_extension in list_extensions(settings_file):
226
+ run_extension_installer(os.path.join(dir_extensions, dirname_extension))
227
+
228
+
229
+ def prepare_environment():
230
+ global skip_install
231
+
232
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
233
+ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
234
+ commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
235
+
236
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
237
+ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
238
+ clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
239
+ openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
240
+
241
+ stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
242
+ taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
243
+ k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
244
+ codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
245
+ blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
246
+
247
+ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
248
+ taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
249
+ k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
250
+ codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
251
+ blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
252
+
253
+ sys.argv += shlex.split(commandline_args)
254
+
255
+ parser = argparse.ArgumentParser(add_help=False)
256
+ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
257
+ args, _ = parser.parse_known_args(sys.argv)
258
+
259
+ sys.argv, _ = extract_arg(sys.argv, '-f')
260
+ sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
261
+ sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
262
+ sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
263
+ sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
264
+ sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
265
+ sys.argv, update_check = extract_arg(sys.argv, '--update-check')
266
+ sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
267
+ sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
268
+ xformers = '--xformers' in sys.argv
269
+ ngrok = '--ngrok' in sys.argv
270
+
271
+ if not skip_python_version_check:
272
+ check_python_version()
273
+
274
+ commit = commit_hash()
275
+
276
+ print(f"Python {sys.version}")
277
+ print(f"Commit hash: {commit}")
278
+
279
+ if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
280
+ run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
281
+
282
+ if not skip_torch_cuda_test:
283
+ run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
284
+
285
+ if not is_installed("gfpgan"):
286
+ run_pip(f"install {gfpgan_package}", "gfpgan")
287
+
288
+ if not is_installed("clip"):
289
+ run_pip(f"install {clip_package}", "clip")
290
+
291
+ if not is_installed("open_clip"):
292
+ run_pip(f"install {openclip_package}", "open_clip")
293
+
294
+ if (not is_installed("xformers") or reinstall_xformers) and xformers:
295
+ if platform.system() == "Windows":
296
+ if platform.python_version().startswith("3.10"):
297
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
298
+ else:
299
+ print("Installation of xformers is not supported in this version of Python.")
300
+ print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
301
+ if not is_installed("xformers"):
302
+ exit(0)
303
+ elif platform.system() == "Linux":
304
+ run_pip(f"install {xformers_package}", "xformers")
305
+
306
+ if not is_installed("pyngrok") and ngrok:
307
+ run_pip("install pyngrok", "ngrok")
308
+
309
+ os.makedirs(dir_repos, exist_ok=True)
310
+
311
+ git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
312
+ git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
313
+ git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
314
+ git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
315
+ git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
316
+
317
+ if not is_installed("lpips"):
318
+ run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
319
+
320
+ run_pip(f"install -r {requirements_file}", "requirements for Web UI")
321
+
322
+ run_extensions_installers(settings_file=args.ui_settings_file)
323
+
324
+ if update_check:
325
+ version_check(commit)
326
+
327
+ if update_all_extensions:
328
+ git_pull_recursive(dir_extensions)
329
+
330
+ if "--exit" in sys.argv:
331
+ print("Exiting because of --exit argument")
332
+ exit(0)
333
+
334
+ if run_tests:
335
+ exitcode = tests(test_dir)
336
+ exit(exitcode)
337
+
338
+
339
+ def tests(test_dir):
340
+ if "--api" not in sys.argv:
341
+ sys.argv.append("--api")
342
+ if "--ckpt" not in sys.argv:
343
+ sys.argv.append("--ckpt")
344
+ sys.argv.append("./test/test_files/empty.pt")
345
+ if "--skip-torch-cuda-test" not in sys.argv:
346
+ sys.argv.append("--skip-torch-cuda-test")
347
+ if "--disable-nan-check" not in sys.argv:
348
+ sys.argv.append("--disable-nan-check")
349
+
350
+ print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
351
+
352
+ os.environ['COMMANDLINE_ARGS'] = ""
353
+ with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
354
+ proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
355
+
356
+ import test.server_poll
357
+ exitcode = test.server_poll.run_tests(proc, test_dir)
358
+
359
+ print(f"Stopping Web UI process with id {proc.pid}")
360
+ proc.kill()
361
+ return exitcode
362
+
363
+
364
+ def start():
365
+ print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
366
+ import webui
367
+ if '--nowebui' in sys.argv:
368
+ webui.api_only()
369
+ else:
370
+ webui.webui()
371
+
372
+
373
+ if __name__ == "__main__":
374
+ prepare_environment()
375
+ start()
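
A minimal sketch (not part of this commit) of how the new --update-all-extensions path can be exercised on its own; it assumes the working directory is the webui root, so launch.py is importable and the "extensions" folder exists:

# Sketch only: drive the helper added in this commit directly.
# git_pull_recursive() walks every sub-directory under the given folder and
# runs `git pull --autostash` in each one that contains a .git checkout,
# printing the pull output (or the error) per repository.
from launch import git_pull_recursive

git_pull_recursive("extensions")

In normal use the same call is made by prepare_environment() when the webui is launched with --update-all-extensions (for example via COMMANDLINE_ARGS), which extract_arg() strips from sys.argv before webui.py is started.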
sd/stable-diffusion-webui/modules/api/api.py CHANGED
@@ -150,6 +150,7 @@ class Api:
150
  self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
151
  self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
152
  self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
 
153
 
154
  def add_api_route(self, path: str, endpoint, **kwargs):
155
  if shared.cmd_opts.api_auth:
@@ -174,36 +175,44 @@ class Api:
174
  script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
175
  script = script_runner.selectable_scripts[script_idx]
176
  return script, script_idx
 
177
 
178
  def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
179
  script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
180
 
181
- populate = txt2imgreq.copy(update={ # Override __init__ params
182
  "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
183
- "do_not_save_samples": True,
184
- "do_not_save_grid": True
185
- }
186
- )
187
  if populate.sampler_name:
188
  populate.sampler_index = None # prevent a warning later on
189
 
190
  args = vars(populate)
191
  args.pop('script_name', None)
192
 
 
 
 
193
  with self.queue_lock:
194
  p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
 
 
195
 
196
  shared.state.begin()
197
  if script is not None:
198
- p.outpath_grids = opts.outdir_txt2img_grids
199
- p.outpath_samples = opts.outdir_txt2img_samples
200
  p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
201
  processed = scripts.scripts_txt2img.run(p, *p.script_args)
202
  else:
203
  processed = process_images(p)
204
  shared.state.end()
205
 
206
- b64images = list(map(encode_pil_to_base64, processed.images))
207
 
208
  return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
209
 
@@ -218,13 +227,12 @@ class Api:
218
  if mask:
219
  mask = decode_base64_to_image(mask)
220
 
221
- populate = img2imgreq.copy(update={ # Override __init__ params
222
  "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
223
- "do_not_save_samples": True,
224
- "do_not_save_grid": True,
225
- "mask": mask
226
- }
227
- )
228
  if populate.sampler_name:
229
  populate.sampler_index = None # prevent a warning later on
230
 
@@ -232,21 +240,24 @@ class Api:
232
  args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
233
  args.pop('script_name', None)
234
 
 
 
 
235
  with self.queue_lock:
236
  p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
237
  p.init_images = [decode_base64_to_image(x) for x in init_images]
 
 
238
 
239
  shared.state.begin()
240
  if script is not None:
241
- p.outpath_grids = opts.outdir_img2img_grids
242
- p.outpath_samples = opts.outdir_img2img_samples
243
  p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
244
  processed = scripts.scripts_img2img.run(p, *p.script_args)
245
  else:
246
  processed = process_images(p)
247
  shared.state.end()
248
 
249
- b64images = list(map(encode_pil_to_base64, processed.images))
250
 
251
  if not img2imgreq.include_init_images:
252
  img2imgreq.init_images = None
 
150
  self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
151
  self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
152
  self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
153
+ self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
154
 
155
  def add_api_route(self, path: str, endpoint, **kwargs):
156
  if shared.cmd_opts.api_auth:
 
175
  script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
176
  script = script_runner.selectable_scripts[script_idx]
177
  return script, script_idx
178
+
179
+ def get_scripts_list(self):
180
+ t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
181
+ i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
182
+
183
+ return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
184
 
185
  def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
186
  script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
187
 
188
+ populate = txt2imgreq.copy(update={ # Override __init__ params
189
  "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
190
+ "do_not_save_samples": not txt2imgreq.save_images,
191
+ "do_not_save_grid": not txt2imgreq.save_images,
192
+ })
 
193
  if populate.sampler_name:
194
  populate.sampler_index = None # prevent a warning later on
195
 
196
  args = vars(populate)
197
  args.pop('script_name', None)
198
 
199
+ send_images = args.pop('send_images', True)
200
+ args.pop('save_images', None)
201
+
202
  with self.queue_lock:
203
  p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
204
+ p.outpath_grids = opts.outdir_txt2img_grids
205
+ p.outpath_samples = opts.outdir_txt2img_samples
206
 
207
  shared.state.begin()
208
  if script is not None:
 
 
209
  p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
210
  processed = scripts.scripts_txt2img.run(p, *p.script_args)
211
  else:
212
  processed = process_images(p)
213
  shared.state.end()
214
 
215
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
216
 
217
  return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
218
 
 
227
  if mask:
228
  mask = decode_base64_to_image(mask)
229
 
230
+ populate = img2imgreq.copy(update={ # Override __init__ params
231
  "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
232
+ "do_not_save_samples": not img2imgreq.save_images,
233
+ "do_not_save_grid": not img2imgreq.save_images,
234
+ "mask": mask,
235
+ })
 
236
  if populate.sampler_name:
237
  populate.sampler_index = None # prevent a warning later on
238
 
 
240
  args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
241
  args.pop('script_name', None)
242
 
243
+ send_images = args.pop('send_images', True)
244
+ args.pop('save_images', None)
245
+
246
  with self.queue_lock:
247
  p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
248
  p.init_images = [decode_base64_to_image(x) for x in init_images]
249
+ p.outpath_grids = opts.outdir_img2img_grids
250
+ p.outpath_samples = opts.outdir_img2img_samples
251
 
252
  shared.state.begin()
253
  if script is not None:
 
 
254
  p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
255
  processed = scripts.scripts_img2img.run(p, *p.script_args)
256
  else:
257
  processed = process_images(p)
258
  shared.state.end()
259
 
260
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
261
 
262
  if not img2imgreq.include_init_images:
263
  img2imgreq.init_images = None
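Taken together, the api.py changes add a read-only scripts listing endpoint and route two new request fields (send_images, save_images) into the processing objects. A hedged client-side sketch, assuming the webui is running locally with --api enabled (the base URL and the use of the requests library are assumptions; the endpoint path and field names come from the diff above):

import requests

base_url = "http://127.0.0.1:7860"  # assumed default address

# New endpoint: titles of the selectable scripts for txt2img and img2img.
script_titles = requests.get(f"{base_url}/sdapi/v1/scripts").json()
print(script_titles["txt2img"], script_titles["img2img"])

# New fields: send_images controls whether base64 images come back in the response,
# save_images controls whether samples/grids are written to the configured output dirs.
payload = {"prompt": "a photo of a cat", "steps": 20, "send_images": True, "save_images": False}
result = requests.post(f"{base_url}/sdapi/v1/txt2img", json=payload).json()
print(len(result["images"]), result["info"])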
sd/stable-diffusion-webui/modules/api/models.py CHANGED
@@ -14,8 +14,8 @@ API_NOT_ALLOWED = [
     "outpath_samples",
     "outpath_grids",
     "sampler_index",
-    "do_not_save_samples",
-    "do_not_save_grid",
+    # "do_not_save_samples",
+    # "do_not_save_grid",
     "extra_generation_params",
     "overlay_images",
     "do_not_reload_embeddings",
@@ -100,13 +100,29 @@ class PydanticModelGenerator:
 StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
     "StableDiffusionProcessingTxt2Img",
     StableDiffusionProcessingTxt2Img,
-    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+    [
+        {"key": "sampler_index", "type": str, "default": "Euler"},
+        {"key": "script_name", "type": str, "default": None},
+        {"key": "script_args", "type": list, "default": []},
+        {"key": "send_images", "type": bool, "default": True},
+        {"key": "save_images", "type": bool, "default": False},
+    ]
 ).generate_model()
 
 StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
     "StableDiffusionProcessingImg2Img",
     StableDiffusionProcessingImg2Img,
-    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+    [
+        {"key": "sampler_index", "type": str, "default": "Euler"},
+        {"key": "init_images", "type": list, "default": None},
+        {"key": "denoising_strength", "type": float, "default": 0.75},
+        {"key": "mask", "type": str, "default": None},
+        {"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
+        {"key": "script_name", "type": str, "default": None},
+        {"key": "script_args", "type": list, "default": []},
+        {"key": "send_images", "type": bool, "default": True},
+        {"key": "save_images", "type": bool, "default": False},
+    ]
 ).generate_model()
 
 class TextToImageResponse(BaseModel):
@@ -267,3 +283,7 @@ class EmbeddingsResponse(BaseModel):
 class MemoryResponse(BaseModel):
     ram: dict = Field(title="RAM", description="System memory stats")
     cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
+
+class ScriptsList(BaseModel):
+    txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
+    img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
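The models.py side of the same change reformats the per-model field lists, registers the two new keys, and adds the ScriptsList response model. A short sketch of what that means on the wire (the script titles are illustrative, not taken from the diff):

# Defaults generated for the new keys: images are returned, nothing is saved to disk.
minimal_txt2img_request = {
    "prompt": "a photo of a cat",
    # "send_images": True,   # default
    # "save_images": False,  # default
}

# Shape of a /sdapi/v1/scripts response matching the ScriptsList model above.
example_scripts_response = {
    "txt2img": ["prompt matrix", "x/y/z plot"],
    "img2img": ["outpainting mk2", "sd upscale"],
}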
sd/stable-diffusion-webui/modules/call_queue.py CHANGED
@@ -1,109 +1,109 @@
1
- import html
2
- import sys
3
- import threading
4
- import traceback
5
- import time
6
-
7
- from modules import shared, progress
8
-
9
- queue_lock = threading.Lock()
10
-
11
-
12
- def wrap_queued_call(func):
13
- def f(*args, **kwargs):
14
- with queue_lock:
15
- res = func(*args, **kwargs)
16
-
17
- return res
18
-
19
- return f
20
-
21
-
22
- def wrap_gradio_gpu_call(func, extra_outputs=None):
23
- def f(*args, **kwargs):
24
-
25
- # if the first argument is a string that says "task(...)", it is treated as a job id
26
- if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
27
- id_task = args[0]
28
- progress.add_task_to_queue(id_task)
29
- else:
30
- id_task = None
31
-
32
- with queue_lock:
33
- shared.state.begin()
34
- progress.start_task(id_task)
35
-
36
- try:
37
- res = func(*args, **kwargs)
38
- finally:
39
- progress.finish_task(id_task)
40
-
41
- shared.state.end()
42
-
43
- return res
44
-
45
- return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
46
-
47
-
48
- def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
49
- def f(*args, extra_outputs_array=extra_outputs, **kwargs):
50
- run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
51
- if run_memmon:
52
- shared.mem_mon.monitor()
53
- t = time.perf_counter()
54
-
55
- try:
56
- res = list(func(*args, **kwargs))
57
- except Exception as e:
58
- # When printing out our debug argument list, do not print out more than a MB of text
59
- max_debug_str_len = 131072 # (1024*1024)/8
60
-
61
- print("Error completing request", file=sys.stderr)
62
- argStr = f"Arguments: {str(args)} {str(kwargs)}"
63
- print(argStr[:max_debug_str_len], file=sys.stderr)
64
- if len(argStr) > max_debug_str_len:
65
- print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
66
-
67
- print(traceback.format_exc(), file=sys.stderr)
68
-
69
- shared.state.job = ""
70
- shared.state.job_count = 0
71
-
72
- if extra_outputs_array is None:
73
- extra_outputs_array = [None, '']
74
-
75
- res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
76
-
77
- shared.state.skipped = False
78
- shared.state.interrupted = False
79
- shared.state.job_count = 0
80
-
81
- if not add_stats:
82
- return tuple(res)
83
-
84
- elapsed = time.perf_counter() - t
85
- elapsed_m = int(elapsed // 60)
86
- elapsed_s = elapsed % 60
87
- elapsed_text = f"{elapsed_s:.2f}s"
88
- if elapsed_m > 0:
89
- elapsed_text = f"{elapsed_m}m "+elapsed_text
90
-
91
- if run_memmon:
92
- mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
93
- active_peak = mem_stats['active_peak']
94
- reserved_peak = mem_stats['reserved_peak']
95
- sys_peak = mem_stats['system_peak']
96
- sys_total = mem_stats['total']
97
- sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
98
-
99
- vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
100
- else:
101
- vram_html = ''
102
-
103
- # last item is always HTML
104
- res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
105
-
106
- return tuple(res)
107
-
108
- return f
109
-
 
1
+ import html
2
+ import sys
3
+ import threading
4
+ import traceback
5
+ import time
6
+
7
+ from modules import shared, progress
8
+
9
+ queue_lock = threading.Lock()
10
+
11
+
12
+ def wrap_queued_call(func):
13
+ def f(*args, **kwargs):
14
+ with queue_lock:
15
+ res = func(*args, **kwargs)
16
+
17
+ return res
18
+
19
+ return f
20
+
21
+
22
+ def wrap_gradio_gpu_call(func, extra_outputs=None):
23
+ def f(*args, **kwargs):
24
+
25
+ # if the first argument is a string that says "task(...)", it is treated as a job id
26
+ if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
27
+ id_task = args[0]
28
+ progress.add_task_to_queue(id_task)
29
+ else:
30
+ id_task = None
31
+
32
+ with queue_lock:
33
+ shared.state.begin()
34
+ progress.start_task(id_task)
35
+
36
+ try:
37
+ res = func(*args, **kwargs)
38
+ finally:
39
+ progress.finish_task(id_task)
40
+
41
+ shared.state.end()
42
+
43
+ return res
44
+
45
+ return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
46
+
47
+
48
+ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
49
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
50
+ run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
51
+ if run_memmon:
52
+ shared.mem_mon.monitor()
53
+ t = time.perf_counter()
54
+
55
+ try:
56
+ res = list(func(*args, **kwargs))
57
+ except Exception as e:
58
+ # When printing out our debug argument list, do not print out more than a MB of text
59
+ max_debug_str_len = 131072 # (1024*1024)/8
60
+
61
+ print("Error completing request", file=sys.stderr)
62
+ argStr = f"Arguments: {str(args)} {str(kwargs)}"
63
+ print(argStr[:max_debug_str_len], file=sys.stderr)
64
+ if len(argStr) > max_debug_str_len:
65
+ print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
66
+
67
+ print(traceback.format_exc(), file=sys.stderr)
68
+
69
+ shared.state.job = ""
70
+ shared.state.job_count = 0
71
+
72
+ if extra_outputs_array is None:
73
+ extra_outputs_array = [None, '']
74
+
75
+ res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
76
+
77
+ shared.state.skipped = False
78
+ shared.state.interrupted = False
79
+ shared.state.job_count = 0
80
+
81
+ if not add_stats:
82
+ return tuple(res)
83
+
84
+ elapsed = time.perf_counter() - t
85
+ elapsed_m = int(elapsed // 60)
86
+ elapsed_s = elapsed % 60
87
+ elapsed_text = f"{elapsed_s:.2f}s"
88
+ if elapsed_m > 0:
89
+ elapsed_text = f"{elapsed_m}m "+elapsed_text
90
+
91
+ if run_memmon:
92
+ mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
93
+ active_peak = mem_stats['active_peak']
94
+ reserved_peak = mem_stats['reserved_peak']
95
+ sys_peak = mem_stats['system_peak']
96
+ sys_total = mem_stats['total']
97
+ sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
98
+
99
+ vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
100
+ else:
101
+ vram_html = ''
102
+
103
+ # last item is always HTML
104
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
105
+
106
+ return tuple(res)
107
+
108
+ return f
109
+
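call_queue.py is unchanged in substance (the whole file is re-added verbatim). For orientation, a hedged sketch of how its wrappers are used: wrap_gradio_gpu_call treats a first argument of the form "task(<id>)" as a job id for the progress queue, and wrap_gradio_call expects the wrapped function to return a sequence whose last element is an HTML string, to which timing/VRAM stats are appended. The function body below is illustrative and only runs inside the webui process, where modules.shared and modules.progress are importable.

from modules.call_queue import wrap_gradio_gpu_call

def generate(id_task, prompt):
    # a real handler would run a generation job here
    return [None, "", f"<p>done: {prompt}</p>"]  # last element must be HTML

wrapped = wrap_gradio_gpu_call(generate, extra_outputs=[None, ""])
images, gen_info, html = wrapped("task(demo-1)", "a photo of a cat")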
sd/stable-diffusion-webui/modules/codeformer_model.py CHANGED
@@ -1,143 +1,143 @@
1
- import os
2
- import sys
3
- import traceback
4
-
5
- import cv2
6
- import torch
7
-
8
- import modules.face_restoration
9
- import modules.shared
10
- from modules import shared, devices, modelloader
11
- from modules.paths import models_path
12
-
13
- # codeformer people made a choice to include modified basicsr library to their project which makes
14
- # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
15
- # I am making a choice to include some files from codeformer to work around this issue.
16
- model_dir = "Codeformer"
17
- model_path = os.path.join(models_path, model_dir)
18
- model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
19
-
20
- have_codeformer = False
21
- codeformer = None
22
-
23
-
24
- def setup_model(dirname):
25
- global model_path
26
- if not os.path.exists(model_path):
27
- os.makedirs(model_path)
28
-
29
- path = modules.paths.paths.get("CodeFormer", None)
30
- if path is None:
31
- return
32
-
33
- try:
34
- from torchvision.transforms.functional import normalize
35
- from modules.codeformer.codeformer_arch import CodeFormer
36
- from basicsr.utils.download_util import load_file_from_url
37
- from basicsr.utils import imwrite, img2tensor, tensor2img
38
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
39
- from facelib.detection.retinaface import retinaface
40
- from modules.shared import cmd_opts
41
-
42
- net_class = CodeFormer
43
-
44
- class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
45
- def name(self):
46
- return "CodeFormer"
47
-
48
- def __init__(self, dirname):
49
- self.net = None
50
- self.face_helper = None
51
- self.cmd_dir = dirname
52
-
53
- def create_models(self):
54
-
55
- if self.net is not None and self.face_helper is not None:
56
- self.net.to(devices.device_codeformer)
57
- return self.net, self.face_helper
58
- model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
59
- if len(model_paths) != 0:
60
- ckpt_path = model_paths[0]
61
- else:
62
- print("Unable to load codeformer model.")
63
- return None, None
64
- net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
65
- checkpoint = torch.load(ckpt_path)['params_ema']
66
- net.load_state_dict(checkpoint)
67
- net.eval()
68
-
69
- if hasattr(retinaface, 'device'):
70
- retinaface.device = devices.device_codeformer
71
- face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
72
-
73
- self.net = net
74
- self.face_helper = face_helper
75
-
76
- return net, face_helper
77
-
78
- def send_model_to(self, device):
79
- self.net.to(device)
80
- self.face_helper.face_det.to(device)
81
- self.face_helper.face_parse.to(device)
82
-
83
- def restore(self, np_image, w=None):
84
- np_image = np_image[:, :, ::-1]
85
-
86
- original_resolution = np_image.shape[0:2]
87
-
88
- self.create_models()
89
- if self.net is None or self.face_helper is None:
90
- return np_image
91
-
92
- self.send_model_to(devices.device_codeformer)
93
-
94
- self.face_helper.clean_all()
95
- self.face_helper.read_image(np_image)
96
- self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
97
- self.face_helper.align_warp_face()
98
-
99
- for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
100
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
101
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
102
- cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
103
-
104
- try:
105
- with torch.no_grad():
106
- output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
107
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
108
- del output
109
- torch.cuda.empty_cache()
110
- except Exception as error:
111
- print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
112
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
113
-
114
- restored_face = restored_face.astype('uint8')
115
- self.face_helper.add_restored_face(restored_face)
116
-
117
- self.face_helper.get_inverse_affine(None)
118
-
119
- restored_img = self.face_helper.paste_faces_to_input_image()
120
- restored_img = restored_img[:, :, ::-1]
121
-
122
- if original_resolution != restored_img.shape[0:2]:
123
- restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
124
-
125
- self.face_helper.clean_all()
126
-
127
- if shared.opts.face_restoration_unload:
128
- self.send_model_to(devices.cpu)
129
-
130
- return restored_img
131
-
132
- global have_codeformer
133
- have_codeformer = True
134
-
135
- global codeformer
136
- codeformer = FaceRestorerCodeFormer(dirname)
137
- shared.face_restorers.append(codeformer)
138
-
139
- except Exception:
140
- print("Error setting up CodeFormer:", file=sys.stderr)
141
- print(traceback.format_exc(), file=sys.stderr)
142
-
143
- # sys.path = stored_sys_path
 
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import cv2
6
+ import torch
7
+
8
+ import modules.face_restoration
9
+ import modules.shared
10
+ from modules import shared, devices, modelloader
11
+ from modules.paths import models_path
12
+
13
+ # codeformer people made a choice to include modified basicsr library to their project which makes
14
+ # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
15
+ # I am making a choice to include some files from codeformer to work around this issue.
16
+ model_dir = "Codeformer"
17
+ model_path = os.path.join(models_path, model_dir)
18
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
19
+
20
+ have_codeformer = False
21
+ codeformer = None
22
+
23
+
24
+ def setup_model(dirname):
25
+ global model_path
26
+ if not os.path.exists(model_path):
27
+ os.makedirs(model_path)
28
+
29
+ path = modules.paths.paths.get("CodeFormer", None)
30
+ if path is None:
31
+ return
32
+
33
+ try:
34
+ from torchvision.transforms.functional import normalize
35
+ from modules.codeformer.codeformer_arch import CodeFormer
36
+ from basicsr.utils.download_util import load_file_from_url
37
+ from basicsr.utils import imwrite, img2tensor, tensor2img
38
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
39
+ from facelib.detection.retinaface import retinaface
40
+ from modules.shared import cmd_opts
41
+
42
+ net_class = CodeFormer
43
+
44
+ class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
45
+ def name(self):
46
+ return "CodeFormer"
47
+
48
+ def __init__(self, dirname):
49
+ self.net = None
50
+ self.face_helper = None
51
+ self.cmd_dir = dirname
52
+
53
+ def create_models(self):
54
+
55
+ if self.net is not None and self.face_helper is not None:
56
+ self.net.to(devices.device_codeformer)
57
+ return self.net, self.face_helper
58
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
59
+ if len(model_paths) != 0:
60
+ ckpt_path = model_paths[0]
61
+ else:
62
+ print("Unable to load codeformer model.")
63
+ return None, None
64
+ net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
65
+ checkpoint = torch.load(ckpt_path)['params_ema']
66
+ net.load_state_dict(checkpoint)
67
+ net.eval()
68
+
69
+ if hasattr(retinaface, 'device'):
70
+ retinaface.device = devices.device_codeformer
71
+ face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
72
+
73
+ self.net = net
74
+ self.face_helper = face_helper
75
+
76
+ return net, face_helper
77
+
78
+ def send_model_to(self, device):
79
+ self.net.to(device)
80
+ self.face_helper.face_det.to(device)
81
+ self.face_helper.face_parse.to(device)
82
+
83
+ def restore(self, np_image, w=None):
84
+ np_image = np_image[:, :, ::-1]
85
+
86
+ original_resolution = np_image.shape[0:2]
87
+
88
+ self.create_models()
89
+ if self.net is None or self.face_helper is None:
90
+ return np_image
91
+
92
+ self.send_model_to(devices.device_codeformer)
93
+
94
+ self.face_helper.clean_all()
95
+ self.face_helper.read_image(np_image)
96
+ self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
97
+ self.face_helper.align_warp_face()
98
+
99
+ for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
100
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
101
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
102
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
103
+
104
+ try:
105
+ with torch.no_grad():
106
+ output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
107
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
108
+ del output
109
+ torch.cuda.empty_cache()
110
+ except Exception as error:
111
+ print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
112
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
113
+
114
+ restored_face = restored_face.astype('uint8')
115
+ self.face_helper.add_restored_face(restored_face)
116
+
117
+ self.face_helper.get_inverse_affine(None)
118
+
119
+ restored_img = self.face_helper.paste_faces_to_input_image()
120
+ restored_img = restored_img[:, :, ::-1]
121
+
122
+ if original_resolution != restored_img.shape[0:2]:
123
+ restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
124
+
125
+ self.face_helper.clean_all()
126
+
127
+ if shared.opts.face_restoration_unload:
128
+ self.send_model_to(devices.cpu)
129
+
130
+ return restored_img
131
+
132
+ global have_codeformer
133
+ have_codeformer = True
134
+
135
+ global codeformer
136
+ codeformer = FaceRestorerCodeFormer(dirname)
137
+ shared.face_restorers.append(codeformer)
138
+
139
+ except Exception:
140
+ print("Error setting up CodeFormer:", file=sys.stderr)
141
+ print(traceback.format_exc(), file=sys.stderr)
142
+
143
+ # sys.path = stored_sys_path
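Like call_queue.py, this file is re-added wholesale; the one visible functional change is on line 58, where load_models() now passes ext_filter=['.pth'] so only .pth files are considered when locating the CodeFormer checkpoint. A hedged usage sketch (runs inside the webui process; the model directory argument is an assumption, and the zero image is a stand-in for a real photo):

import numpy as np
from modules import codeformer_model

codeformer_model.setup_model("models/Codeformer")            # assumed model dirname
img = np.zeros((512, 512, 3), dtype=np.uint8)                # HxWx3 RGB uint8 input
restored = codeformer_model.codeformer.restore(img, w=0.5)   # w: CodeFormer fidelity weight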
sd/stable-diffusion-webui/modules/deepbooru_model.py CHANGED
@@ -1,678 +1,678 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from modules import devices
6
-
7
- # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
8
-
9
-
10
- class DeepDanbooruModel(nn.Module):
11
- def __init__(self):
12
- super(DeepDanbooruModel, self).__init__()
13
-
14
- self.tags = []
15
-
16
- self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
17
- self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
18
- self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
19
- self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
20
- self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
21
- self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
22
- self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
23
- self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
24
- self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
25
- self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
26
- self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
27
- self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
28
- self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
29
- self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
30
- self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
31
- self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
32
- self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
33
- self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
34
- self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
35
- self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
36
- self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
37
- self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
38
- self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
39
- self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
40
- self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
41
- self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
42
- self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
43
- self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
44
- self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
45
- self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
46
- self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
47
- self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
48
- self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
49
- self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
50
- self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
51
- self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
52
- self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
53
- self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
54
- self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
55
- self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
56
- self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
57
- self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
58
- self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
59
- self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
60
- self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
61
- self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
62
- self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
63
- self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
64
- self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
65
- self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
66
- self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
67
- self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
68
- self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
69
- self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
70
- self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
71
- self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
72
- self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
73
- self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
74
- self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
75
- self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
76
- self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
77
- self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
78
- self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
79
- self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
80
- self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
81
- self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
82
- self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
83
- self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
84
- self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
85
- self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
86
- self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
87
- self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
88
- self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
89
- self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
90
- self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
91
- self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
92
- self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
93
- self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
94
- self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
95
- self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
96
- self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
97
- self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
98
- self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
99
- self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
100
- self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
101
- self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
102
- self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
103
- self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
104
- self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
105
- self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
106
- self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
107
- self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
108
- self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
109
- self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
110
- self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
111
- self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
112
- self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
113
- self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
114
- self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
115
- self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
116
- self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
117
- self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
118
- self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
119
- self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
120
- self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
121
- self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
122
- self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
123
- self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
124
- self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
125
- self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
126
- self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
127
- self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
128
- self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
129
- self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
130
- self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
131
- self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
132
- self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
133
- self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
134
- self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
135
- self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
136
- self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
137
- self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
138
- self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
139
- self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
140
- self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
141
- self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
142
- self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
143
- self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
144
- self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
145
- self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
146
- self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
147
- self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
148
- self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
149
- self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
150
- self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
151
- self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
152
- self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
153
- self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
154
- self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
155
- self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
156
- self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
157
- self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
158
- self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
159
- self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
160
- self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
161
- self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
162
- self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
163
- self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
164
- self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
165
- self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
166
- self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
167
- self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
168
- self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
169
- self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
170
- self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
171
- self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
172
- self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
173
- self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
174
- self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
175
- self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
176
- self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
177
- self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
178
- self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
179
- self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
180
- self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
181
- self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
182
- self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
183
- self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
184
- self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
185
- self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
186
- self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
187
- self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
188
- self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
189
- self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
190
- self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
191
- self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
192
- self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
193
- self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
194
- self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
195
- self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
196
-
197
- def forward(self, *inputs):
198
- t_358, = inputs
199
- t_359 = t_358.permute(*[0, 3, 1, 2])
200
- t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
201
- t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
202
- t_361 = F.relu(t_360)
203
- t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
204
- t_362 = self.n_MaxPool_0(t_361)
205
- t_363 = self.n_Conv_1(t_362)
206
- t_364 = self.n_Conv_2(t_362)
207
- t_365 = F.relu(t_364)
208
- t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
209
- t_366 = self.n_Conv_3(t_365_padded)
210
- t_367 = F.relu(t_366)
211
- t_368 = self.n_Conv_4(t_367)
212
- t_369 = torch.add(t_368, t_363)
213
- t_370 = F.relu(t_369)
214
- t_371 = self.n_Conv_5(t_370)
215
- t_372 = F.relu(t_371)
216
- t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
217
- t_373 = self.n_Conv_6(t_372_padded)
218
- t_374 = F.relu(t_373)
219
- t_375 = self.n_Conv_7(t_374)
220
- t_376 = torch.add(t_375, t_370)
221
- t_377 = F.relu(t_376)
222
- t_378 = self.n_Conv_8(t_377)
223
- t_379 = F.relu(t_378)
224
- t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
225
- t_380 = self.n_Conv_9(t_379_padded)
226
- t_381 = F.relu(t_380)
227
- t_382 = self.n_Conv_10(t_381)
228
- t_383 = torch.add(t_382, t_377)
229
- t_384 = F.relu(t_383)
230
- t_385 = self.n_Conv_11(t_384)
231
- t_386 = self.n_Conv_12(t_384)
232
- t_387 = F.relu(t_386)
233
- t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
234
- t_388 = self.n_Conv_13(t_387_padded)
235
- t_389 = F.relu(t_388)
236
- t_390 = self.n_Conv_14(t_389)
237
- t_391 = torch.add(t_390, t_385)
238
- t_392 = F.relu(t_391)
239
- t_393 = self.n_Conv_15(t_392)
240
- t_394 = F.relu(t_393)
241
- t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
242
- t_395 = self.n_Conv_16(t_394_padded)
243
- t_396 = F.relu(t_395)
244
- t_397 = self.n_Conv_17(t_396)
245
- t_398 = torch.add(t_397, t_392)
246
- t_399 = F.relu(t_398)
247
- t_400 = self.n_Conv_18(t_399)
248
- t_401 = F.relu(t_400)
249
- t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
250
- t_402 = self.n_Conv_19(t_401_padded)
251
- t_403 = F.relu(t_402)
252
- t_404 = self.n_Conv_20(t_403)
253
- t_405 = torch.add(t_404, t_399)
254
- t_406 = F.relu(t_405)
255
- t_407 = self.n_Conv_21(t_406)
256
- t_408 = F.relu(t_407)
257
- t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
258
- t_409 = self.n_Conv_22(t_408_padded)
259
- t_410 = F.relu(t_409)
260
- t_411 = self.n_Conv_23(t_410)
261
- t_412 = torch.add(t_411, t_406)
262
- t_413 = F.relu(t_412)
263
- t_414 = self.n_Conv_24(t_413)
264
- t_415 = F.relu(t_414)
265
- t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
266
- t_416 = self.n_Conv_25(t_415_padded)
267
- t_417 = F.relu(t_416)
268
- t_418 = self.n_Conv_26(t_417)
269
- t_419 = torch.add(t_418, t_413)
270
- t_420 = F.relu(t_419)
271
- t_421 = self.n_Conv_27(t_420)
272
- t_422 = F.relu(t_421)
273
- t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
274
- t_423 = self.n_Conv_28(t_422_padded)
275
- t_424 = F.relu(t_423)
276
- t_425 = self.n_Conv_29(t_424)
277
- t_426 = torch.add(t_425, t_420)
278
- t_427 = F.relu(t_426)
279
- t_428 = self.n_Conv_30(t_427)
280
- t_429 = F.relu(t_428)
281
- t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
282
- t_430 = self.n_Conv_31(t_429_padded)
283
- t_431 = F.relu(t_430)
284
- t_432 = self.n_Conv_32(t_431)
285
- t_433 = torch.add(t_432, t_427)
286
- t_434 = F.relu(t_433)
287
- t_435 = self.n_Conv_33(t_434)
288
- t_436 = F.relu(t_435)
289
- t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
290
- t_437 = self.n_Conv_34(t_436_padded)
291
- t_438 = F.relu(t_437)
292
- t_439 = self.n_Conv_35(t_438)
293
- t_440 = torch.add(t_439, t_434)
294
- t_441 = F.relu(t_440)
295
- t_442 = self.n_Conv_36(t_441)
296
- t_443 = self.n_Conv_37(t_441)
297
- t_444 = F.relu(t_443)
298
- t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
299
- t_445 = self.n_Conv_38(t_444_padded)
300
- t_446 = F.relu(t_445)
301
- t_447 = self.n_Conv_39(t_446)
302
- t_448 = torch.add(t_447, t_442)
303
- t_449 = F.relu(t_448)
304
- t_450 = self.n_Conv_40(t_449)
305
- t_451 = F.relu(t_450)
306
- t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
307
- t_452 = self.n_Conv_41(t_451_padded)
308
- t_453 = F.relu(t_452)
309
- t_454 = self.n_Conv_42(t_453)
310
- t_455 = torch.add(t_454, t_449)
311
- t_456 = F.relu(t_455)
312
- t_457 = self.n_Conv_43(t_456)
313
- t_458 = F.relu(t_457)
314
- t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
315
- t_459 = self.n_Conv_44(t_458_padded)
316
- t_460 = F.relu(t_459)
317
- t_461 = self.n_Conv_45(t_460)
318
- t_462 = torch.add(t_461, t_456)
319
- t_463 = F.relu(t_462)
320
- t_464 = self.n_Conv_46(t_463)
321
- t_465 = F.relu(t_464)
322
- t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
323
- t_466 = self.n_Conv_47(t_465_padded)
324
- t_467 = F.relu(t_466)
325
- t_468 = self.n_Conv_48(t_467)
326
- t_469 = torch.add(t_468, t_463)
327
- t_470 = F.relu(t_469)
328
- t_471 = self.n_Conv_49(t_470)
329
- t_472 = F.relu(t_471)
330
- t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
331
- t_473 = self.n_Conv_50(t_472_padded)
332
- t_474 = F.relu(t_473)
333
- t_475 = self.n_Conv_51(t_474)
334
- t_476 = torch.add(t_475, t_470)
335
- t_477 = F.relu(t_476)
336
- t_478 = self.n_Conv_52(t_477)
337
- t_479 = F.relu(t_478)
338
- t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
339
- t_480 = self.n_Conv_53(t_479_padded)
340
- t_481 = F.relu(t_480)
341
- t_482 = self.n_Conv_54(t_481)
342
- t_483 = torch.add(t_482, t_477)
343
- t_484 = F.relu(t_483)
344
- t_485 = self.n_Conv_55(t_484)
345
- t_486 = F.relu(t_485)
346
- t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
347
- t_487 = self.n_Conv_56(t_486_padded)
348
- t_488 = F.relu(t_487)
349
- t_489 = self.n_Conv_57(t_488)
350
- t_490 = torch.add(t_489, t_484)
351
- t_491 = F.relu(t_490)
352
- t_492 = self.n_Conv_58(t_491)
353
- t_493 = F.relu(t_492)
354
- t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
355
- t_494 = self.n_Conv_59(t_493_padded)
356
- t_495 = F.relu(t_494)
357
- t_496 = self.n_Conv_60(t_495)
358
- t_497 = torch.add(t_496, t_491)
359
- t_498 = F.relu(t_497)
360
- t_499 = self.n_Conv_61(t_498)
361
- t_500 = F.relu(t_499)
362
- t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
363
- t_501 = self.n_Conv_62(t_500_padded)
364
- t_502 = F.relu(t_501)
365
- t_503 = self.n_Conv_63(t_502)
366
- t_504 = torch.add(t_503, t_498)
367
- t_505 = F.relu(t_504)
368
- t_506 = self.n_Conv_64(t_505)
369
- t_507 = F.relu(t_506)
370
- t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
371
- t_508 = self.n_Conv_65(t_507_padded)
372
- t_509 = F.relu(t_508)
373
- t_510 = self.n_Conv_66(t_509)
374
- t_511 = torch.add(t_510, t_505)
375
- t_512 = F.relu(t_511)
376
- t_513 = self.n_Conv_67(t_512)
377
- t_514 = F.relu(t_513)
378
- t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
379
- t_515 = self.n_Conv_68(t_514_padded)
380
- t_516 = F.relu(t_515)
381
- t_517 = self.n_Conv_69(t_516)
382
- t_518 = torch.add(t_517, t_512)
383
- t_519 = F.relu(t_518)
384
- t_520 = self.n_Conv_70(t_519)
385
- t_521 = F.relu(t_520)
386
- t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
387
- t_522 = self.n_Conv_71(t_521_padded)
388
- t_523 = F.relu(t_522)
389
- t_524 = self.n_Conv_72(t_523)
390
- t_525 = torch.add(t_524, t_519)
391
- t_526 = F.relu(t_525)
392
- t_527 = self.n_Conv_73(t_526)
393
- t_528 = F.relu(t_527)
394
- t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
395
- t_529 = self.n_Conv_74(t_528_padded)
396
- t_530 = F.relu(t_529)
397
- t_531 = self.n_Conv_75(t_530)
398
- t_532 = torch.add(t_531, t_526)
399
- t_533 = F.relu(t_532)
400
- t_534 = self.n_Conv_76(t_533)
401
- t_535 = F.relu(t_534)
402
- t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
403
- t_536 = self.n_Conv_77(t_535_padded)
404
- t_537 = F.relu(t_536)
405
- t_538 = self.n_Conv_78(t_537)
406
- t_539 = torch.add(t_538, t_533)
407
- t_540 = F.relu(t_539)
408
- t_541 = self.n_Conv_79(t_540)
409
- t_542 = F.relu(t_541)
410
- t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
411
- t_543 = self.n_Conv_80(t_542_padded)
412
- t_544 = F.relu(t_543)
413
- t_545 = self.n_Conv_81(t_544)
414
- t_546 = torch.add(t_545, t_540)
415
- t_547 = F.relu(t_546)
416
- t_548 = self.n_Conv_82(t_547)
417
- t_549 = F.relu(t_548)
418
- t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
419
- t_550 = self.n_Conv_83(t_549_padded)
420
- t_551 = F.relu(t_550)
421
- t_552 = self.n_Conv_84(t_551)
422
- t_553 = torch.add(t_552, t_547)
423
- t_554 = F.relu(t_553)
424
- t_555 = self.n_Conv_85(t_554)
425
- t_556 = F.relu(t_555)
426
- t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
427
- t_557 = self.n_Conv_86(t_556_padded)
428
- t_558 = F.relu(t_557)
429
- t_559 = self.n_Conv_87(t_558)
430
- t_560 = torch.add(t_559, t_554)
431
- t_561 = F.relu(t_560)
432
- t_562 = self.n_Conv_88(t_561)
433
- t_563 = F.relu(t_562)
434
- t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
435
- t_564 = self.n_Conv_89(t_563_padded)
436
- t_565 = F.relu(t_564)
437
- t_566 = self.n_Conv_90(t_565)
438
- t_567 = torch.add(t_566, t_561)
439
- t_568 = F.relu(t_567)
440
- t_569 = self.n_Conv_91(t_568)
441
- t_570 = F.relu(t_569)
442
- t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
443
- t_571 = self.n_Conv_92(t_570_padded)
444
- t_572 = F.relu(t_571)
445
- t_573 = self.n_Conv_93(t_572)
446
- t_574 = torch.add(t_573, t_568)
447
- t_575 = F.relu(t_574)
448
- t_576 = self.n_Conv_94(t_575)
449
- t_577 = F.relu(t_576)
450
- t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
451
- t_578 = self.n_Conv_95(t_577_padded)
452
- t_579 = F.relu(t_578)
453
- t_580 = self.n_Conv_96(t_579)
454
- t_581 = torch.add(t_580, t_575)
455
- t_582 = F.relu(t_581)
456
- t_583 = self.n_Conv_97(t_582)
457
- t_584 = F.relu(t_583)
458
- t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
459
- t_585 = self.n_Conv_98(t_584_padded)
460
- t_586 = F.relu(t_585)
461
- t_587 = self.n_Conv_99(t_586)
462
- t_588 = self.n_Conv_100(t_582)
463
- t_589 = torch.add(t_587, t_588)
464
- t_590 = F.relu(t_589)
465
- t_591 = self.n_Conv_101(t_590)
466
- t_592 = F.relu(t_591)
467
- t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
468
- t_593 = self.n_Conv_102(t_592_padded)
469
- t_594 = F.relu(t_593)
470
- t_595 = self.n_Conv_103(t_594)
471
- t_596 = torch.add(t_595, t_590)
472
- t_597 = F.relu(t_596)
473
- t_598 = self.n_Conv_104(t_597)
474
- t_599 = F.relu(t_598)
475
- t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
476
- t_600 = self.n_Conv_105(t_599_padded)
477
- t_601 = F.relu(t_600)
478
- t_602 = self.n_Conv_106(t_601)
479
- t_603 = torch.add(t_602, t_597)
480
- t_604 = F.relu(t_603)
481
- t_605 = self.n_Conv_107(t_604)
482
- t_606 = F.relu(t_605)
483
- t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
484
- t_607 = self.n_Conv_108(t_606_padded)
485
- t_608 = F.relu(t_607)
486
- t_609 = self.n_Conv_109(t_608)
487
- t_610 = torch.add(t_609, t_604)
488
- t_611 = F.relu(t_610)
489
- t_612 = self.n_Conv_110(t_611)
490
- t_613 = F.relu(t_612)
491
- t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
492
- t_614 = self.n_Conv_111(t_613_padded)
493
- t_615 = F.relu(t_614)
494
- t_616 = self.n_Conv_112(t_615)
495
- t_617 = torch.add(t_616, t_611)
496
- t_618 = F.relu(t_617)
497
- t_619 = self.n_Conv_113(t_618)
498
- t_620 = F.relu(t_619)
499
- t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
500
- t_621 = self.n_Conv_114(t_620_padded)
501
- t_622 = F.relu(t_621)
502
- t_623 = self.n_Conv_115(t_622)
503
- t_624 = torch.add(t_623, t_618)
504
- t_625 = F.relu(t_624)
505
- t_626 = self.n_Conv_116(t_625)
506
- t_627 = F.relu(t_626)
507
- t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
508
- t_628 = self.n_Conv_117(t_627_padded)
509
- t_629 = F.relu(t_628)
510
- t_630 = self.n_Conv_118(t_629)
511
- t_631 = torch.add(t_630, t_625)
512
- t_632 = F.relu(t_631)
513
- t_633 = self.n_Conv_119(t_632)
514
- t_634 = F.relu(t_633)
515
- t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
516
- t_635 = self.n_Conv_120(t_634_padded)
517
- t_636 = F.relu(t_635)
518
- t_637 = self.n_Conv_121(t_636)
519
- t_638 = torch.add(t_637, t_632)
520
- t_639 = F.relu(t_638)
521
- t_640 = self.n_Conv_122(t_639)
522
- t_641 = F.relu(t_640)
523
- t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
524
- t_642 = self.n_Conv_123(t_641_padded)
525
- t_643 = F.relu(t_642)
526
- t_644 = self.n_Conv_124(t_643)
527
- t_645 = torch.add(t_644, t_639)
528
- t_646 = F.relu(t_645)
529
- t_647 = self.n_Conv_125(t_646)
530
- t_648 = F.relu(t_647)
531
- t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
532
- t_649 = self.n_Conv_126(t_648_padded)
533
- t_650 = F.relu(t_649)
534
- t_651 = self.n_Conv_127(t_650)
535
- t_652 = torch.add(t_651, t_646)
536
- t_653 = F.relu(t_652)
537
- t_654 = self.n_Conv_128(t_653)
538
- t_655 = F.relu(t_654)
539
- t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
540
- t_656 = self.n_Conv_129(t_655_padded)
541
- t_657 = F.relu(t_656)
542
- t_658 = self.n_Conv_130(t_657)
543
- t_659 = torch.add(t_658, t_653)
544
- t_660 = F.relu(t_659)
545
- t_661 = self.n_Conv_131(t_660)
546
- t_662 = F.relu(t_661)
547
- t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
548
- t_663 = self.n_Conv_132(t_662_padded)
549
- t_664 = F.relu(t_663)
550
- t_665 = self.n_Conv_133(t_664)
551
- t_666 = torch.add(t_665, t_660)
552
- t_667 = F.relu(t_666)
553
- t_668 = self.n_Conv_134(t_667)
554
- t_669 = F.relu(t_668)
555
- t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
556
- t_670 = self.n_Conv_135(t_669_padded)
557
- t_671 = F.relu(t_670)
558
- t_672 = self.n_Conv_136(t_671)
559
- t_673 = torch.add(t_672, t_667)
560
- t_674 = F.relu(t_673)
561
- t_675 = self.n_Conv_137(t_674)
562
- t_676 = F.relu(t_675)
563
- t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
564
- t_677 = self.n_Conv_138(t_676_padded)
565
- t_678 = F.relu(t_677)
566
- t_679 = self.n_Conv_139(t_678)
567
- t_680 = torch.add(t_679, t_674)
568
- t_681 = F.relu(t_680)
569
- t_682 = self.n_Conv_140(t_681)
570
- t_683 = F.relu(t_682)
571
- t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
572
- t_684 = self.n_Conv_141(t_683_padded)
573
- t_685 = F.relu(t_684)
574
- t_686 = self.n_Conv_142(t_685)
575
- t_687 = torch.add(t_686, t_681)
576
- t_688 = F.relu(t_687)
577
- t_689 = self.n_Conv_143(t_688)
578
- t_690 = F.relu(t_689)
579
- t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
580
- t_691 = self.n_Conv_144(t_690_padded)
581
- t_692 = F.relu(t_691)
582
- t_693 = self.n_Conv_145(t_692)
583
- t_694 = torch.add(t_693, t_688)
584
- t_695 = F.relu(t_694)
585
- t_696 = self.n_Conv_146(t_695)
586
- t_697 = F.relu(t_696)
587
- t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
588
- t_698 = self.n_Conv_147(t_697_padded)
589
- t_699 = F.relu(t_698)
590
- t_700 = self.n_Conv_148(t_699)
591
- t_701 = torch.add(t_700, t_695)
592
- t_702 = F.relu(t_701)
593
- t_703 = self.n_Conv_149(t_702)
594
- t_704 = F.relu(t_703)
595
- t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
596
- t_705 = self.n_Conv_150(t_704_padded)
597
- t_706 = F.relu(t_705)
598
- t_707 = self.n_Conv_151(t_706)
599
- t_708 = torch.add(t_707, t_702)
600
- t_709 = F.relu(t_708)
601
- t_710 = self.n_Conv_152(t_709)
602
- t_711 = F.relu(t_710)
603
- t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
604
- t_712 = self.n_Conv_153(t_711_padded)
605
- t_713 = F.relu(t_712)
606
- t_714 = self.n_Conv_154(t_713)
607
- t_715 = torch.add(t_714, t_709)
608
- t_716 = F.relu(t_715)
609
- t_717 = self.n_Conv_155(t_716)
610
- t_718 = F.relu(t_717)
611
- t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
612
- t_719 = self.n_Conv_156(t_718_padded)
613
- t_720 = F.relu(t_719)
614
- t_721 = self.n_Conv_157(t_720)
615
- t_722 = torch.add(t_721, t_716)
616
- t_723 = F.relu(t_722)
617
- t_724 = self.n_Conv_158(t_723)
618
- t_725 = self.n_Conv_159(t_723)
619
- t_726 = F.relu(t_725)
620
- t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
621
- t_727 = self.n_Conv_160(t_726_padded)
622
- t_728 = F.relu(t_727)
623
- t_729 = self.n_Conv_161(t_728)
624
- t_730 = torch.add(t_729, t_724)
625
- t_731 = F.relu(t_730)
626
- t_732 = self.n_Conv_162(t_731)
627
- t_733 = F.relu(t_732)
628
- t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
629
- t_734 = self.n_Conv_163(t_733_padded)
630
- t_735 = F.relu(t_734)
631
- t_736 = self.n_Conv_164(t_735)
632
- t_737 = torch.add(t_736, t_731)
633
- t_738 = F.relu(t_737)
634
- t_739 = self.n_Conv_165(t_738)
635
- t_740 = F.relu(t_739)
636
- t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
637
- t_741 = self.n_Conv_166(t_740_padded)
638
- t_742 = F.relu(t_741)
639
- t_743 = self.n_Conv_167(t_742)
640
- t_744 = torch.add(t_743, t_738)
641
- t_745 = F.relu(t_744)
642
- t_746 = self.n_Conv_168(t_745)
643
- t_747 = self.n_Conv_169(t_745)
644
- t_748 = F.relu(t_747)
645
- t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
646
- t_749 = self.n_Conv_170(t_748_padded)
647
- t_750 = F.relu(t_749)
648
- t_751 = self.n_Conv_171(t_750)
649
- t_752 = torch.add(t_751, t_746)
650
- t_753 = F.relu(t_752)
651
- t_754 = self.n_Conv_172(t_753)
652
- t_755 = F.relu(t_754)
653
- t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
654
- t_756 = self.n_Conv_173(t_755_padded)
655
- t_757 = F.relu(t_756)
656
- t_758 = self.n_Conv_174(t_757)
657
- t_759 = torch.add(t_758, t_753)
658
- t_760 = F.relu(t_759)
659
- t_761 = self.n_Conv_175(t_760)
660
- t_762 = F.relu(t_761)
661
- t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
662
- t_763 = self.n_Conv_176(t_762_padded)
663
- t_764 = F.relu(t_763)
664
- t_765 = self.n_Conv_177(t_764)
665
- t_766 = torch.add(t_765, t_760)
666
- t_767 = F.relu(t_766)
667
- t_768 = self.n_Conv_178(t_767)
668
- t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
669
- t_770 = torch.squeeze(t_769, 3)
670
- t_770 = torch.squeeze(t_770, 2)
671
- t_771 = torch.sigmoid(t_770)
672
- return t_771
673
-
674
- def load_state_dict(self, state_dict, **kwargs):
675
- self.tags = state_dict.get('tags', [])
676
-
677
- super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
678
-
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from modules import devices
6
+
7
+ # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
8
+
9
+
10
+ class DeepDanbooruModel(nn.Module):
11
+ def __init__(self):
12
+ super(DeepDanbooruModel, self).__init__()
13
+
14
+ self.tags = []
15
+
16
+ self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
17
+ self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
18
+ self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
19
+ self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
20
+ self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
21
+ self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
22
+ self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
23
+ self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
24
+ self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
25
+ self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
26
+ self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
27
+ self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
28
+ self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
29
+ self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
30
+ self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
31
+ self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
32
+ self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
33
+ self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
34
+ self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
35
+ self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
36
+ self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
37
+ self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
38
+ self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
39
+ self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
40
+ self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
41
+ self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
42
+ self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
43
+ self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
44
+ self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
45
+ self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
46
+ self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
47
+ self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
48
+ self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
49
+ self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
50
+ self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
51
+ self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
52
+ self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
53
+ self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
54
+ self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
55
+ self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
56
+ self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
57
+ self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
58
+ self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
59
+ self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
60
+ self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
61
+ self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
62
+ self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
63
+ self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
64
+ self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
65
+ self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
66
+ self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
67
+ self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
68
+ self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
69
+ self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
70
+ self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
71
+ self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
72
+ self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
73
+ self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
74
+ self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
75
+ self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
76
+ self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
77
+ self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
78
+ self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
79
+ self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
80
+ self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
81
+ self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
82
+ self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
83
+ self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
84
+ self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
85
+ self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
86
+ self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
87
+ self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
88
+ self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
89
+ self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
90
+ self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
91
+ self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
92
+ self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
93
+ self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
94
+ self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
95
+ self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
96
+ self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
97
+ self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
98
+ self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
99
+ self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
100
+ self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
101
+ self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
102
+ self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
103
+ self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
104
+ self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
105
+ self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
106
+ self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
107
+ self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
108
+ self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
109
+ self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
110
+ self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
111
+ self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
112
+ self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
113
+ self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
114
+ self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
115
+ self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
116
+ self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
117
+ self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
118
+ self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
119
+ self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
120
+ self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
121
+ self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
122
+ self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
123
+ self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
124
+ self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
125
+ self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
126
+ self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
127
+ self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
128
+ self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
129
+ self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
130
+ self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
131
+ self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
132
+ self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
133
+ self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
134
+ self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
135
+ self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
136
+ self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
137
+ self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
138
+ self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
139
+ self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
140
+ self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
141
+ self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
142
+ self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
143
+ self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
144
+ self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
145
+ self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
146
+ self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
147
+ self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
148
+ self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
149
+ self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
150
+ self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
151
+ self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
152
+ self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
153
+ self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
154
+ self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
155
+ self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
156
+ self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
157
+ self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
158
+ self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
159
+ self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
160
+ self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
161
+ self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
162
+ self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
163
+ self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
164
+ self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
165
+ self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
166
+ self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
167
+ self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
168
+ self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
169
+ self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
170
+ self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
171
+ self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
172
+ self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
173
+ self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
174
+ self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
175
+ self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
176
+ self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
177
+ self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
178
+ self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
179
+ self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
180
+ self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
181
+ self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
182
+ self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
183
+ self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
184
+ self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
185
+ self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
186
+ self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
187
+ self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
188
+ self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
189
+ self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
190
+ self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
191
+ self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
192
+ self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
193
+ self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
194
+ self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
195
+ self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
196
+
197
+ def forward(self, *inputs):
198
+ t_358, = inputs
199
+ t_359 = t_358.permute(*[0, 3, 1, 2])
200
+ t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
201
+ t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
202
+ t_361 = F.relu(t_360)
203
+ t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
204
+ t_362 = self.n_MaxPool_0(t_361)
205
+ t_363 = self.n_Conv_1(t_362)
206
+ t_364 = self.n_Conv_2(t_362)
207
+ t_365 = F.relu(t_364)
208
+ t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
209
+ t_366 = self.n_Conv_3(t_365_padded)
210
+ t_367 = F.relu(t_366)
211
+ t_368 = self.n_Conv_4(t_367)
212
+ t_369 = torch.add(t_368, t_363)
213
+ t_370 = F.relu(t_369)
214
+ t_371 = self.n_Conv_5(t_370)
215
+ t_372 = F.relu(t_371)
216
+ t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
217
+ t_373 = self.n_Conv_6(t_372_padded)
218
+ t_374 = F.relu(t_373)
219
+ t_375 = self.n_Conv_7(t_374)
220
+ t_376 = torch.add(t_375, t_370)
221
+ t_377 = F.relu(t_376)
222
+ t_378 = self.n_Conv_8(t_377)
223
+ t_379 = F.relu(t_378)
224
+ t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
225
+ t_380 = self.n_Conv_9(t_379_padded)
226
+ t_381 = F.relu(t_380)
227
+ t_382 = self.n_Conv_10(t_381)
228
+ t_383 = torch.add(t_382, t_377)
229
+ t_384 = F.relu(t_383)
230
+ t_385 = self.n_Conv_11(t_384)
231
+ t_386 = self.n_Conv_12(t_384)
232
+ t_387 = F.relu(t_386)
233
+ t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
234
+ t_388 = self.n_Conv_13(t_387_padded)
235
+ t_389 = F.relu(t_388)
236
+ t_390 = self.n_Conv_14(t_389)
237
+ t_391 = torch.add(t_390, t_385)
238
+ t_392 = F.relu(t_391)
239
+ t_393 = self.n_Conv_15(t_392)
240
+ t_394 = F.relu(t_393)
241
+ t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
242
+ t_395 = self.n_Conv_16(t_394_padded)
243
+ t_396 = F.relu(t_395)
244
+ t_397 = self.n_Conv_17(t_396)
245
+ t_398 = torch.add(t_397, t_392)
246
+ t_399 = F.relu(t_398)
247
+ t_400 = self.n_Conv_18(t_399)
248
+ t_401 = F.relu(t_400)
249
+ t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
250
+ t_402 = self.n_Conv_19(t_401_padded)
251
+ t_403 = F.relu(t_402)
252
+ t_404 = self.n_Conv_20(t_403)
253
+ t_405 = torch.add(t_404, t_399)
254
+ t_406 = F.relu(t_405)
255
+ t_407 = self.n_Conv_21(t_406)
256
+ t_408 = F.relu(t_407)
257
+ t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
258
+ t_409 = self.n_Conv_22(t_408_padded)
259
+ t_410 = F.relu(t_409)
260
+ t_411 = self.n_Conv_23(t_410)
261
+ t_412 = torch.add(t_411, t_406)
262
+ t_413 = F.relu(t_412)
263
+ t_414 = self.n_Conv_24(t_413)
264
+ t_415 = F.relu(t_414)
265
+ t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
266
+ t_416 = self.n_Conv_25(t_415_padded)
267
+ t_417 = F.relu(t_416)
268
+ t_418 = self.n_Conv_26(t_417)
269
+ t_419 = torch.add(t_418, t_413)
270
+ t_420 = F.relu(t_419)
271
+ t_421 = self.n_Conv_27(t_420)
272
+ t_422 = F.relu(t_421)
273
+ t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
274
+ t_423 = self.n_Conv_28(t_422_padded)
275
+ t_424 = F.relu(t_423)
276
+ t_425 = self.n_Conv_29(t_424)
277
+ t_426 = torch.add(t_425, t_420)
278
+ t_427 = F.relu(t_426)
279
+ t_428 = self.n_Conv_30(t_427)
280
+ t_429 = F.relu(t_428)
281
+ t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
282
+ t_430 = self.n_Conv_31(t_429_padded)
283
+ t_431 = F.relu(t_430)
284
+ t_432 = self.n_Conv_32(t_431)
285
+ t_433 = torch.add(t_432, t_427)
286
+ t_434 = F.relu(t_433)
287
+ t_435 = self.n_Conv_33(t_434)
288
+ t_436 = F.relu(t_435)
289
+ t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
290
+ t_437 = self.n_Conv_34(t_436_padded)
291
+ t_438 = F.relu(t_437)
292
+ t_439 = self.n_Conv_35(t_438)
293
+ t_440 = torch.add(t_439, t_434)
294
+ t_441 = F.relu(t_440)
295
+ t_442 = self.n_Conv_36(t_441)
296
+ t_443 = self.n_Conv_37(t_441)
297
+ t_444 = F.relu(t_443)
298
+ t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
299
+ t_445 = self.n_Conv_38(t_444_padded)
300
+ t_446 = F.relu(t_445)
301
+ t_447 = self.n_Conv_39(t_446)
302
+ t_448 = torch.add(t_447, t_442)
303
+ t_449 = F.relu(t_448)
304
+ t_450 = self.n_Conv_40(t_449)
305
+ t_451 = F.relu(t_450)
306
+ t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
307
+ t_452 = self.n_Conv_41(t_451_padded)
308
+ t_453 = F.relu(t_452)
309
+ t_454 = self.n_Conv_42(t_453)
310
+ t_455 = torch.add(t_454, t_449)
311
+ t_456 = F.relu(t_455)
312
+ t_457 = self.n_Conv_43(t_456)
313
+ t_458 = F.relu(t_457)
314
+ t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
315
+ t_459 = self.n_Conv_44(t_458_padded)
316
+ t_460 = F.relu(t_459)
317
+ t_461 = self.n_Conv_45(t_460)
318
+ t_462 = torch.add(t_461, t_456)
319
+ t_463 = F.relu(t_462)
320
+ t_464 = self.n_Conv_46(t_463)
321
+ t_465 = F.relu(t_464)
322
+ t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
323
+ t_466 = self.n_Conv_47(t_465_padded)
324
+ t_467 = F.relu(t_466)
325
+ t_468 = self.n_Conv_48(t_467)
326
+ t_469 = torch.add(t_468, t_463)
327
+ t_470 = F.relu(t_469)
328
+ t_471 = self.n_Conv_49(t_470)
329
+ t_472 = F.relu(t_471)
330
+ t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
331
+ t_473 = self.n_Conv_50(t_472_padded)
332
+ t_474 = F.relu(t_473)
333
+ t_475 = self.n_Conv_51(t_474)
334
+ t_476 = torch.add(t_475, t_470)
335
+ t_477 = F.relu(t_476)
336
+ t_478 = self.n_Conv_52(t_477)
337
+ t_479 = F.relu(t_478)
338
+ t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
339
+ t_480 = self.n_Conv_53(t_479_padded)
340
+ t_481 = F.relu(t_480)
341
+ t_482 = self.n_Conv_54(t_481)
342
+ t_483 = torch.add(t_482, t_477)
343
+ t_484 = F.relu(t_483)
344
+ t_485 = self.n_Conv_55(t_484)
345
+ t_486 = F.relu(t_485)
346
+ t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
347
+ t_487 = self.n_Conv_56(t_486_padded)
348
+ t_488 = F.relu(t_487)
349
+ t_489 = self.n_Conv_57(t_488)
350
+ t_490 = torch.add(t_489, t_484)
351
+ t_491 = F.relu(t_490)
352
+ t_492 = self.n_Conv_58(t_491)
353
+ t_493 = F.relu(t_492)
354
+ t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
355
+ t_494 = self.n_Conv_59(t_493_padded)
356
+ t_495 = F.relu(t_494)
357
+ t_496 = self.n_Conv_60(t_495)
358
+ t_497 = torch.add(t_496, t_491)
359
+ t_498 = F.relu(t_497)
360
+ t_499 = self.n_Conv_61(t_498)
361
+ t_500 = F.relu(t_499)
362
+ t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
363
+ t_501 = self.n_Conv_62(t_500_padded)
364
+ t_502 = F.relu(t_501)
365
+ t_503 = self.n_Conv_63(t_502)
366
+ t_504 = torch.add(t_503, t_498)
367
+ t_505 = F.relu(t_504)
368
+ t_506 = self.n_Conv_64(t_505)
369
+ t_507 = F.relu(t_506)
370
+ t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
371
+ t_508 = self.n_Conv_65(t_507_padded)
372
+ t_509 = F.relu(t_508)
373
+ t_510 = self.n_Conv_66(t_509)
374
+ t_511 = torch.add(t_510, t_505)
375
+ t_512 = F.relu(t_511)
376
+ t_513 = self.n_Conv_67(t_512)
377
+ t_514 = F.relu(t_513)
378
+ t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
379
+ t_515 = self.n_Conv_68(t_514_padded)
380
+ t_516 = F.relu(t_515)
381
+ t_517 = self.n_Conv_69(t_516)
382
+ t_518 = torch.add(t_517, t_512)
383
+ t_519 = F.relu(t_518)
384
+ t_520 = self.n_Conv_70(t_519)
385
+ t_521 = F.relu(t_520)
386
+ t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
387
+ t_522 = self.n_Conv_71(t_521_padded)
388
+ t_523 = F.relu(t_522)
389
+ t_524 = self.n_Conv_72(t_523)
390
+ t_525 = torch.add(t_524, t_519)
391
+ t_526 = F.relu(t_525)
392
+ t_527 = self.n_Conv_73(t_526)
393
+ t_528 = F.relu(t_527)
394
+ t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
395
+ t_529 = self.n_Conv_74(t_528_padded)
396
+ t_530 = F.relu(t_529)
397
+ t_531 = self.n_Conv_75(t_530)
398
+ t_532 = torch.add(t_531, t_526)
399
+ t_533 = F.relu(t_532)
400
+ t_534 = self.n_Conv_76(t_533)
401
+ t_535 = F.relu(t_534)
402
+ t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
403
+ t_536 = self.n_Conv_77(t_535_padded)
404
+ t_537 = F.relu(t_536)
405
+ t_538 = self.n_Conv_78(t_537)
406
+ t_539 = torch.add(t_538, t_533)
407
+ t_540 = F.relu(t_539)
408
+ t_541 = self.n_Conv_79(t_540)
409
+ t_542 = F.relu(t_541)
410
+ t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
411
+ t_543 = self.n_Conv_80(t_542_padded)
412
+ t_544 = F.relu(t_543)
413
+ t_545 = self.n_Conv_81(t_544)
414
+ t_546 = torch.add(t_545, t_540)
415
+ t_547 = F.relu(t_546)
416
+ t_548 = self.n_Conv_82(t_547)
417
+ t_549 = F.relu(t_548)
418
+ t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
419
+ t_550 = self.n_Conv_83(t_549_padded)
420
+ t_551 = F.relu(t_550)
421
+ t_552 = self.n_Conv_84(t_551)
422
+ t_553 = torch.add(t_552, t_547)
423
+ t_554 = F.relu(t_553)
424
+ t_555 = self.n_Conv_85(t_554)
425
+ t_556 = F.relu(t_555)
426
+ t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
427
+ t_557 = self.n_Conv_86(t_556_padded)
428
+ t_558 = F.relu(t_557)
429
+ t_559 = self.n_Conv_87(t_558)
430
+ t_560 = torch.add(t_559, t_554)
431
+ t_561 = F.relu(t_560)
432
+ t_562 = self.n_Conv_88(t_561)
433
+ t_563 = F.relu(t_562)
434
+ t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
435
+ t_564 = self.n_Conv_89(t_563_padded)
436
+ t_565 = F.relu(t_564)
437
+ t_566 = self.n_Conv_90(t_565)
438
+ t_567 = torch.add(t_566, t_561)
439
+ t_568 = F.relu(t_567)
440
+ t_569 = self.n_Conv_91(t_568)
441
+ t_570 = F.relu(t_569)
442
+ t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
443
+ t_571 = self.n_Conv_92(t_570_padded)
444
+ t_572 = F.relu(t_571)
445
+ t_573 = self.n_Conv_93(t_572)
446
+ t_574 = torch.add(t_573, t_568)
447
+ t_575 = F.relu(t_574)
448
+ t_576 = self.n_Conv_94(t_575)
449
+ t_577 = F.relu(t_576)
450
+ t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
451
+ t_578 = self.n_Conv_95(t_577_padded)
452
+ t_579 = F.relu(t_578)
453
+ t_580 = self.n_Conv_96(t_579)
454
+ t_581 = torch.add(t_580, t_575)
455
+ t_582 = F.relu(t_581)
456
+ t_583 = self.n_Conv_97(t_582)
457
+ t_584 = F.relu(t_583)
458
+ t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
459
+ t_585 = self.n_Conv_98(t_584_padded)
460
+ t_586 = F.relu(t_585)
461
+ t_587 = self.n_Conv_99(t_586)
462
+ t_588 = self.n_Conv_100(t_582)
463
+ t_589 = torch.add(t_587, t_588)
464
+ t_590 = F.relu(t_589)
465
+ t_591 = self.n_Conv_101(t_590)
466
+ t_592 = F.relu(t_591)
467
+ t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
468
+ t_593 = self.n_Conv_102(t_592_padded)
469
+ t_594 = F.relu(t_593)
470
+ t_595 = self.n_Conv_103(t_594)
471
+ t_596 = torch.add(t_595, t_590)
472
+ t_597 = F.relu(t_596)
473
+ t_598 = self.n_Conv_104(t_597)
474
+ t_599 = F.relu(t_598)
475
+ t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
476
+ t_600 = self.n_Conv_105(t_599_padded)
477
+ t_601 = F.relu(t_600)
478
+ t_602 = self.n_Conv_106(t_601)
479
+ t_603 = torch.add(t_602, t_597)
480
+ t_604 = F.relu(t_603)
481
+ t_605 = self.n_Conv_107(t_604)
482
+ t_606 = F.relu(t_605)
483
+ t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
484
+ t_607 = self.n_Conv_108(t_606_padded)
485
+ t_608 = F.relu(t_607)
486
+ t_609 = self.n_Conv_109(t_608)
487
+ t_610 = torch.add(t_609, t_604)
488
+ t_611 = F.relu(t_610)
489
+ t_612 = self.n_Conv_110(t_611)
490
+ t_613 = F.relu(t_612)
491
+ t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
492
+ t_614 = self.n_Conv_111(t_613_padded)
493
+ t_615 = F.relu(t_614)
494
+ t_616 = self.n_Conv_112(t_615)
495
+ t_617 = torch.add(t_616, t_611)
496
+ t_618 = F.relu(t_617)
497
+ t_619 = self.n_Conv_113(t_618)
498
+ t_620 = F.relu(t_619)
499
+ t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
500
+ t_621 = self.n_Conv_114(t_620_padded)
501
+ t_622 = F.relu(t_621)
502
+ t_623 = self.n_Conv_115(t_622)
503
+ t_624 = torch.add(t_623, t_618)
504
+ t_625 = F.relu(t_624)
505
+ t_626 = self.n_Conv_116(t_625)
506
+ t_627 = F.relu(t_626)
507
+ t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
508
+ t_628 = self.n_Conv_117(t_627_padded)
509
+ t_629 = F.relu(t_628)
510
+ t_630 = self.n_Conv_118(t_629)
511
+ t_631 = torch.add(t_630, t_625)
512
+ t_632 = F.relu(t_631)
513
+ t_633 = self.n_Conv_119(t_632)
514
+ t_634 = F.relu(t_633)
515
+ t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
516
+ t_635 = self.n_Conv_120(t_634_padded)
517
+ t_636 = F.relu(t_635)
518
+ t_637 = self.n_Conv_121(t_636)
519
+ t_638 = torch.add(t_637, t_632)
520
+ t_639 = F.relu(t_638)
521
+ t_640 = self.n_Conv_122(t_639)
522
+ t_641 = F.relu(t_640)
523
+ t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
524
+ t_642 = self.n_Conv_123(t_641_padded)
525
+ t_643 = F.relu(t_642)
526
+ t_644 = self.n_Conv_124(t_643)
527
+ t_645 = torch.add(t_644, t_639)
528
+ t_646 = F.relu(t_645)
529
+ t_647 = self.n_Conv_125(t_646)
530
+ t_648 = F.relu(t_647)
531
+ t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
532
+ t_649 = self.n_Conv_126(t_648_padded)
533
+ t_650 = F.relu(t_649)
534
+ t_651 = self.n_Conv_127(t_650)
535
+ t_652 = torch.add(t_651, t_646)
536
+ t_653 = F.relu(t_652)
537
+ t_654 = self.n_Conv_128(t_653)
538
+ t_655 = F.relu(t_654)
539
+ t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
540
+ t_656 = self.n_Conv_129(t_655_padded)
541
+ t_657 = F.relu(t_656)
542
+ t_658 = self.n_Conv_130(t_657)
543
+ t_659 = torch.add(t_658, t_653)
544
+ t_660 = F.relu(t_659)
545
+ t_661 = self.n_Conv_131(t_660)
546
+ t_662 = F.relu(t_661)
547
+ t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
548
+ t_663 = self.n_Conv_132(t_662_padded)
549
+ t_664 = F.relu(t_663)
550
+ t_665 = self.n_Conv_133(t_664)
551
+ t_666 = torch.add(t_665, t_660)
552
+ t_667 = F.relu(t_666)
553
+ t_668 = self.n_Conv_134(t_667)
554
+ t_669 = F.relu(t_668)
555
+ t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
556
+ t_670 = self.n_Conv_135(t_669_padded)
557
+ t_671 = F.relu(t_670)
558
+ t_672 = self.n_Conv_136(t_671)
559
+ t_673 = torch.add(t_672, t_667)
560
+ t_674 = F.relu(t_673)
561
+ t_675 = self.n_Conv_137(t_674)
562
+ t_676 = F.relu(t_675)
563
+ t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
564
+ t_677 = self.n_Conv_138(t_676_padded)
565
+ t_678 = F.relu(t_677)
566
+ t_679 = self.n_Conv_139(t_678)
567
+ t_680 = torch.add(t_679, t_674)
568
+ t_681 = F.relu(t_680)
569
+ t_682 = self.n_Conv_140(t_681)
570
+ t_683 = F.relu(t_682)
571
+ t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
572
+ t_684 = self.n_Conv_141(t_683_padded)
573
+ t_685 = F.relu(t_684)
574
+ t_686 = self.n_Conv_142(t_685)
575
+ t_687 = torch.add(t_686, t_681)
576
+ t_688 = F.relu(t_687)
577
+ t_689 = self.n_Conv_143(t_688)
578
+ t_690 = F.relu(t_689)
579
+ t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
580
+ t_691 = self.n_Conv_144(t_690_padded)
581
+ t_692 = F.relu(t_691)
582
+ t_693 = self.n_Conv_145(t_692)
583
+ t_694 = torch.add(t_693, t_688)
584
+ t_695 = F.relu(t_694)
585
+ t_696 = self.n_Conv_146(t_695)
586
+ t_697 = F.relu(t_696)
587
+ t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
588
+ t_698 = self.n_Conv_147(t_697_padded)
589
+ t_699 = F.relu(t_698)
590
+ t_700 = self.n_Conv_148(t_699)
591
+ t_701 = torch.add(t_700, t_695)
592
+ t_702 = F.relu(t_701)
593
+ t_703 = self.n_Conv_149(t_702)
594
+ t_704 = F.relu(t_703)
595
+ t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
596
+ t_705 = self.n_Conv_150(t_704_padded)
597
+ t_706 = F.relu(t_705)
598
+ t_707 = self.n_Conv_151(t_706)
599
+ t_708 = torch.add(t_707, t_702)
600
+ t_709 = F.relu(t_708)
601
+ t_710 = self.n_Conv_152(t_709)
602
+ t_711 = F.relu(t_710)
603
+ t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
604
+ t_712 = self.n_Conv_153(t_711_padded)
605
+ t_713 = F.relu(t_712)
606
+ t_714 = self.n_Conv_154(t_713)
607
+ t_715 = torch.add(t_714, t_709)
608
+ t_716 = F.relu(t_715)
609
+ t_717 = self.n_Conv_155(t_716)
610
+ t_718 = F.relu(t_717)
611
+ t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
612
+ t_719 = self.n_Conv_156(t_718_padded)
613
+ t_720 = F.relu(t_719)
614
+ t_721 = self.n_Conv_157(t_720)
615
+ t_722 = torch.add(t_721, t_716)
616
+ t_723 = F.relu(t_722)
617
+ t_724 = self.n_Conv_158(t_723)
618
+ t_725 = self.n_Conv_159(t_723)
619
+ t_726 = F.relu(t_725)
620
+ t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
621
+ t_727 = self.n_Conv_160(t_726_padded)
622
+ t_728 = F.relu(t_727)
623
+ t_729 = self.n_Conv_161(t_728)
624
+ t_730 = torch.add(t_729, t_724)
625
+ t_731 = F.relu(t_730)
626
+ t_732 = self.n_Conv_162(t_731)
627
+ t_733 = F.relu(t_732)
628
+ t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
629
+ t_734 = self.n_Conv_163(t_733_padded)
630
+ t_735 = F.relu(t_734)
631
+ t_736 = self.n_Conv_164(t_735)
632
+ t_737 = torch.add(t_736, t_731)
633
+ t_738 = F.relu(t_737)
634
+ t_739 = self.n_Conv_165(t_738)
635
+ t_740 = F.relu(t_739)
636
+ t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
637
+ t_741 = self.n_Conv_166(t_740_padded)
638
+ t_742 = F.relu(t_741)
639
+ t_743 = self.n_Conv_167(t_742)
640
+ t_744 = torch.add(t_743, t_738)
641
+ t_745 = F.relu(t_744)
642
+ t_746 = self.n_Conv_168(t_745)
643
+ t_747 = self.n_Conv_169(t_745)
644
+ t_748 = F.relu(t_747)
645
+ t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
646
+ t_749 = self.n_Conv_170(t_748_padded)
647
+ t_750 = F.relu(t_749)
648
+ t_751 = self.n_Conv_171(t_750)
649
+ t_752 = torch.add(t_751, t_746)
650
+ t_753 = F.relu(t_752)
651
+ t_754 = self.n_Conv_172(t_753)
652
+ t_755 = F.relu(t_754)
653
+ t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
654
+ t_756 = self.n_Conv_173(t_755_padded)
655
+ t_757 = F.relu(t_756)
656
+ t_758 = self.n_Conv_174(t_757)
657
+ t_759 = torch.add(t_758, t_753)
658
+ t_760 = F.relu(t_759)
659
+ t_761 = self.n_Conv_175(t_760)
660
+ t_762 = F.relu(t_761)
661
+ t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
662
+ t_763 = self.n_Conv_176(t_762_padded)
663
+ t_764 = F.relu(t_763)
664
+ t_765 = self.n_Conv_177(t_764)
665
+ t_766 = torch.add(t_765, t_760)
666
+ t_767 = F.relu(t_766)
667
+ t_768 = self.n_Conv_178(t_767)
668
+ t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
669
+ t_770 = torch.squeeze(t_769, 3)
670
+ t_770 = torch.squeeze(t_770, 2)
671
+ t_771 = torch.sigmoid(t_770)
672
+ return t_771
673
+
674
+ def load_state_dict(self, state_dict, **kwargs):
675
+ self.tags = state_dict.get('tags', [])
676
+
677
+ super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
678
+
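
For orientation, here is a minimal usage sketch of the `DeepDanbooruModel` re-added above. It is not part of the commit: the checkpoint filename, image path, 512x512 input size and 0.5 threshold are assumptions, and it is meant to run inside the webui environment so that `modules.devices` (referenced in `forward`) resolves.

```python
import numpy as np
import torch
from PIL import Image

from modules.deepbooru_model import DeepDanbooruModel

model = DeepDanbooruModel()
# TorchDeepDanbooru-style checkpoint: a state dict that also carries a 'tags' list (filename assumed)
model.load_state_dict(torch.load("model-resnet_custom_v3.pt", map_location="cpu"))
model.eval()

# forward() expects an NHWC float batch in [0, 1] and permutes to NCHW itself
pic = Image.open("example.jpg").convert("RGB").resize((512, 512))
x = torch.from_numpy(np.array(pic, dtype=np.float32) / 255.0).unsqueeze(0)

with torch.no_grad():
    probs = model(x)[0]  # one sigmoid score per entry of model.tags

predicted = [tag for tag, p in zip(model.tags, probs) if p.item() >= 0.5]
print(", ".join(predicted))
```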
sd/stable-diffusion-webui/modules/errors.py CHANGED
@@ -1,43 +1,43 @@
- import sys
- import traceback
-
-
- def print_error_explanation(message):
-     lines = message.strip().split("\n")
-     max_len = max([len(x) for x in lines])
-
-     print('=' * max_len, file=sys.stderr)
-     for line in lines:
-         print(line, file=sys.stderr)
-     print('=' * max_len, file=sys.stderr)
-
-
- def display(e: Exception, task):
-     print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
-     print(traceback.format_exc(), file=sys.stderr)
-
-     message = str(e)
-     if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
-         print_error_explanation("""
- The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
- See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
-         """)
-
-
- already_displayed = {}
-
-
- def display_once(e: Exception, task):
-     if task in already_displayed:
-         return
-
-     display(e, task)
-
-     already_displayed[task] = 1
-
-
- def run(code, task):
-     try:
-         code()
-     except Exception as e:
-         display(task, e)
 
+ import sys
+ import traceback
+
+
+ def print_error_explanation(message):
+     lines = message.strip().split("\n")
+     max_len = max([len(x) for x in lines])
+
+     print('=' * max_len, file=sys.stderr)
+     for line in lines:
+         print(line, file=sys.stderr)
+     print('=' * max_len, file=sys.stderr)
+
+
+ def display(e: Exception, task):
+     print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+     print(traceback.format_exc(), file=sys.stderr)
+
+     message = str(e)
+     if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+         print_error_explanation("""
+ The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
+ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+         """)
+
+
+ already_displayed = {}
+
+
+ def display_once(e: Exception, task):
+     if task in already_displayed:
+         return
+
+     display(e, task)
+
+     already_displayed[task] = 1
+
+
+ def run(code, task):
+     try:
+         code()
+     except Exception as e:
+         display(e, task)  # pass the exception first to match display()'s signature
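
As a quick illustration of how these helpers are meant to be called (a sketch, not part of the commit; the callable and task labels are made up):

```python
from modules import errors

def flaky_setup():
    raise RuntimeError("boom")

# run() executes the callable and, on failure, prints the traceback tagged with the task label
errors.run(flaky_setup, "initializing a hypothetical extension")

# display_once() reports a given task label at most once per process
try:
    flaky_setup()
except Exception as e:
    errors.display_once(e, "hypothetical extension setup")
```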
sd/stable-diffusion-webui/modules/esrgan_model.py CHANGED
@@ -1,233 +1,233 @@
1
- import os
2
-
3
- import numpy as np
4
- import torch
5
- from PIL import Image
6
- from basicsr.utils.download_util import load_file_from_url
7
-
8
- import modules.esrgan_model_arch as arch
9
- from modules import shared, modelloader, images, devices
10
- from modules.upscaler import Upscaler, UpscalerData
11
- from modules.shared import opts
12
-
13
-
14
-
15
- def mod2normal(state_dict):
16
- # this code is copied from https://github.com/victorca25/iNNfer
17
- if 'conv_first.weight' in state_dict:
18
- crt_net = {}
19
- items = []
20
- for k, v in state_dict.items():
21
- items.append(k)
22
-
23
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
24
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
25
-
26
- for k in items.copy():
27
- if 'RDB' in k:
28
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
29
- if '.weight' in k:
30
- ori_k = ori_k.replace('.weight', '.0.weight')
31
- elif '.bias' in k:
32
- ori_k = ori_k.replace('.bias', '.0.bias')
33
- crt_net[ori_k] = state_dict[k]
34
- items.remove(k)
35
-
36
- crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
37
- crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
38
- crt_net['model.3.weight'] = state_dict['upconv1.weight']
39
- crt_net['model.3.bias'] = state_dict['upconv1.bias']
40
- crt_net['model.6.weight'] = state_dict['upconv2.weight']
41
- crt_net['model.6.bias'] = state_dict['upconv2.bias']
42
- crt_net['model.8.weight'] = state_dict['HRconv.weight']
43
- crt_net['model.8.bias'] = state_dict['HRconv.bias']
44
- crt_net['model.10.weight'] = state_dict['conv_last.weight']
45
- crt_net['model.10.bias'] = state_dict['conv_last.bias']
46
- state_dict = crt_net
47
- return state_dict
48
-
49
-
50
- def resrgan2normal(state_dict, nb=23):
51
- # this code is copied from https://github.com/victorca25/iNNfer
52
- if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
53
- re8x = 0
54
- crt_net = {}
55
- items = []
56
- for k, v in state_dict.items():
57
- items.append(k)
58
-
59
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
60
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
61
-
62
- for k in items.copy():
63
- if "rdb" in k:
64
- ori_k = k.replace('body.', 'model.1.sub.')
65
- ori_k = ori_k.replace('.rdb', '.RDB')
66
- if '.weight' in k:
67
- ori_k = ori_k.replace('.weight', '.0.weight')
68
- elif '.bias' in k:
69
- ori_k = ori_k.replace('.bias', '.0.bias')
70
- crt_net[ori_k] = state_dict[k]
71
- items.remove(k)
72
-
73
- crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
74
- crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
75
- crt_net['model.3.weight'] = state_dict['conv_up1.weight']
76
- crt_net['model.3.bias'] = state_dict['conv_up1.bias']
77
- crt_net['model.6.weight'] = state_dict['conv_up2.weight']
78
- crt_net['model.6.bias'] = state_dict['conv_up2.bias']
79
-
80
- if 'conv_up3.weight' in state_dict:
81
- # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
82
- re8x = 3
83
- crt_net['model.9.weight'] = state_dict['conv_up3.weight']
84
- crt_net['model.9.bias'] = state_dict['conv_up3.bias']
85
-
86
- crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
87
- crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
88
- crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
89
- crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
90
-
91
- state_dict = crt_net
92
- return state_dict
93
-
94
-
95
- def infer_params(state_dict):
96
- # this code is copied from https://github.com/victorca25/iNNfer
97
- scale2x = 0
98
- scalemin = 6
99
- n_uplayer = 0
100
- plus = False
101
-
102
- for block in list(state_dict):
103
- parts = block.split(".")
104
- n_parts = len(parts)
105
- if n_parts == 5 and parts[2] == "sub":
106
- nb = int(parts[3])
107
- elif n_parts == 3:
108
- part_num = int(parts[1])
109
- if (part_num > scalemin
110
- and parts[0] == "model"
111
- and parts[2] == "weight"):
112
- scale2x += 1
113
- if part_num > n_uplayer:
114
- n_uplayer = part_num
115
- out_nc = state_dict[block].shape[0]
116
- if not plus and "conv1x1" in block:
117
- plus = True
118
-
119
- nf = state_dict["model.0.weight"].shape[0]
120
- in_nc = state_dict["model.0.weight"].shape[1]
121
- out_nc = out_nc
122
- scale = 2 ** scale2x
123
-
124
- return in_nc, out_nc, nf, nb, plus, scale
125
-
126
-
127
- class UpscalerESRGAN(Upscaler):
128
- def __init__(self, dirname):
129
- self.name = "ESRGAN"
130
- self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
131
- self.model_name = "ESRGAN_4x"
132
- self.scalers = []
133
- self.user_path = dirname
134
- super().__init__()
135
- model_paths = self.find_models(ext_filter=[".pt", ".pth"])
136
- scalers = []
137
- if len(model_paths) == 0:
138
- scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
139
- scalers.append(scaler_data)
140
- for file in model_paths:
141
- if "http" in file:
142
- name = self.model_name
143
- else:
144
- name = modelloader.friendly_name(file)
145
-
146
- scaler_data = UpscalerData(name, file, self, 4)
147
- self.scalers.append(scaler_data)
148
-
149
- def do_upscale(self, img, selected_model):
150
- model = self.load_model(selected_model)
151
- if model is None:
152
- return img
153
- model.to(devices.device_esrgan)
154
- img = esrgan_upscale(model, img)
155
- return img
156
-
157
- def load_model(self, path: str):
158
- if "http" in path:
159
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
160
- file_name="%s.pth" % self.model_name,
161
- progress=True)
162
- else:
163
- filename = path
164
- if not os.path.exists(filename) or filename is None:
165
- print("Unable to load %s from %s" % (self.model_path, filename))
166
- return None
167
-
168
- state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
169
-
170
- if "params_ema" in state_dict:
171
- state_dict = state_dict["params_ema"]
172
- elif "params" in state_dict:
173
- state_dict = state_dict["params"]
174
- num_conv = 16 if "realesr-animevideov3" in filename else 32
175
- model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
176
- model.load_state_dict(state_dict)
177
- model.eval()
178
- return model
179
-
180
- if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
181
- nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
182
- state_dict = resrgan2normal(state_dict, nb)
183
- elif "conv_first.weight" in state_dict:
184
- state_dict = mod2normal(state_dict)
185
- elif "model.0.weight" not in state_dict:
186
- raise Exception("The file is not a recognized ESRGAN model.")
187
-
188
- in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
189
-
190
- model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
191
- model.load_state_dict(state_dict)
192
- model.eval()
193
-
194
- return model
195
-
196
-
197
- def upscale_without_tiling(model, img):
198
- img = np.array(img)
199
- img = img[:, :, ::-1]
200
- img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
201
- img = torch.from_numpy(img).float()
202
- img = img.unsqueeze(0).to(devices.device_esrgan)
203
- with torch.no_grad():
204
- output = model(img)
205
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
206
- output = 255. * np.moveaxis(output, 0, 2)
207
- output = output.astype(np.uint8)
208
- output = output[:, :, ::-1]
209
- return Image.fromarray(output, 'RGB')
210
-
211
-
212
- def esrgan_upscale(model, img):
213
- if opts.ESRGAN_tile == 0:
214
- return upscale_without_tiling(model, img)
215
-
216
- grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
217
- newtiles = []
218
- scale_factor = 1
219
-
220
- for y, h, row in grid.tiles:
221
- newrow = []
222
- for tiledata in row:
223
- x, w, tile = tiledata
224
-
225
- output = upscale_without_tiling(model, tile)
226
- scale_factor = output.width // tile.width
227
-
228
- newrow.append([x * scale_factor, w * scale_factor, output])
229
- newtiles.append([y * scale_factor, h * scale_factor, newrow])
230
-
231
- newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
232
- output = images.combine_grid(newgrid)
233
- return output
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from basicsr.utils.download_util import load_file_from_url
7
+
8
+ import modules.esrgan_model_arch as arch
9
+ from modules import shared, modelloader, images, devices
10
+ from modules.upscaler import Upscaler, UpscalerData
11
+ from modules.shared import opts
12
+
13
+
14
+
15
+ def mod2normal(state_dict):
16
+ # this code is copied from https://github.com/victorca25/iNNfer
17
+ if 'conv_first.weight' in state_dict:
18
+ crt_net = {}
19
+ items = []
20
+ for k, v in state_dict.items():
21
+ items.append(k)
22
+
23
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
24
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
25
+
26
+ for k in items.copy():
27
+ if 'RDB' in k:
28
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
29
+ if '.weight' in k:
30
+ ori_k = ori_k.replace('.weight', '.0.weight')
31
+ elif '.bias' in k:
32
+ ori_k = ori_k.replace('.bias', '.0.bias')
33
+ crt_net[ori_k] = state_dict[k]
34
+ items.remove(k)
35
+
36
+ crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
37
+ crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
38
+ crt_net['model.3.weight'] = state_dict['upconv1.weight']
39
+ crt_net['model.3.bias'] = state_dict['upconv1.bias']
40
+ crt_net['model.6.weight'] = state_dict['upconv2.weight']
41
+ crt_net['model.6.bias'] = state_dict['upconv2.bias']
42
+ crt_net['model.8.weight'] = state_dict['HRconv.weight']
43
+ crt_net['model.8.bias'] = state_dict['HRconv.bias']
44
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
45
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
46
+ state_dict = crt_net
47
+ return state_dict
48
+
49
+
50
+ def resrgan2normal(state_dict, nb=23):
51
+ # this code is copied from https://github.com/victorca25/iNNfer
52
+ if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
53
+ re8x = 0
54
+ crt_net = {}
55
+ items = []
56
+ for k, v in state_dict.items():
57
+ items.append(k)
58
+
59
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
60
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
61
+
62
+ for k in items.copy():
63
+ if "rdb" in k:
64
+ ori_k = k.replace('body.', 'model.1.sub.')
65
+ ori_k = ori_k.replace('.rdb', '.RDB')
66
+ if '.weight' in k:
67
+ ori_k = ori_k.replace('.weight', '.0.weight')
68
+ elif '.bias' in k:
69
+ ori_k = ori_k.replace('.bias', '.0.bias')
70
+ crt_net[ori_k] = state_dict[k]
71
+ items.remove(k)
72
+
73
+ crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
74
+ crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
75
+ crt_net['model.3.weight'] = state_dict['conv_up1.weight']
76
+ crt_net['model.3.bias'] = state_dict['conv_up1.bias']
77
+ crt_net['model.6.weight'] = state_dict['conv_up2.weight']
78
+ crt_net['model.6.bias'] = state_dict['conv_up2.bias']
79
+
80
+ if 'conv_up3.weight' in state_dict:
81
+ # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
82
+ re8x = 3
83
+ crt_net['model.9.weight'] = state_dict['conv_up3.weight']
84
+ crt_net['model.9.bias'] = state_dict['conv_up3.bias']
85
+
86
+ crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
87
+ crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
88
+ crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
89
+ crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
90
+
91
+ state_dict = crt_net
92
+ return state_dict
93
+
94
+
95
+ def infer_params(state_dict):
96
+ # this code is copied from https://github.com/victorca25/iNNfer
97
+ scale2x = 0
98
+ scalemin = 6
99
+ n_uplayer = 0
100
+ plus = False
101
+
102
+ for block in list(state_dict):
103
+ parts = block.split(".")
104
+ n_parts = len(parts)
105
+ if n_parts == 5 and parts[2] == "sub":
106
+ nb = int(parts[3])
107
+ elif n_parts == 3:
108
+ part_num = int(parts[1])
109
+ if (part_num > scalemin
110
+ and parts[0] == "model"
111
+ and parts[2] == "weight"):
112
+ scale2x += 1
113
+ if part_num > n_uplayer:
114
+ n_uplayer = part_num
115
+ out_nc = state_dict[block].shape[0]
116
+ if not plus and "conv1x1" in block:
117
+ plus = True
118
+
119
+ nf = state_dict["model.0.weight"].shape[0]
120
+ in_nc = state_dict["model.0.weight"].shape[1]
121
+ out_nc = out_nc
122
+ scale = 2 ** scale2x
123
+
124
+ return in_nc, out_nc, nf, nb, plus, scale
125
+
126
+
127
+ class UpscalerESRGAN(Upscaler):
128
+ def __init__(self, dirname):
129
+ self.name = "ESRGAN"
130
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
131
+ self.model_name = "ESRGAN_4x"
132
+ self.scalers = []
133
+ self.user_path = dirname
134
+ super().__init__()
135
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
136
+ scalers = []
137
+ if len(model_paths) == 0:
138
+ scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
139
+ scalers.append(scaler_data)
140
+ for file in model_paths:
141
+ if "http" in file:
142
+ name = self.model_name
143
+ else:
144
+ name = modelloader.friendly_name(file)
145
+
146
+ scaler_data = UpscalerData(name, file, self, 4)
147
+ self.scalers.append(scaler_data)
148
+
149
+ def do_upscale(self, img, selected_model):
150
+ model = self.load_model(selected_model)
151
+ if model is None:
152
+ return img
153
+ model.to(devices.device_esrgan)
154
+ img = esrgan_upscale(model, img)
155
+ return img
156
+
157
+ def load_model(self, path: str):
158
+ if "http" in path:
159
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
160
+ file_name="%s.pth" % self.model_name,
161
+ progress=True)
162
+ else:
163
+ filename = path
164
+ if not os.path.exists(filename) or filename is None:
165
+ print("Unable to load %s from %s" % (self.model_path, filename))
166
+ return None
167
+
168
+ state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
169
+
170
+ if "params_ema" in state_dict:
171
+ state_dict = state_dict["params_ema"]
172
+ elif "params" in state_dict:
173
+ state_dict = state_dict["params"]
174
+ num_conv = 16 if "realesr-animevideov3" in filename else 32
175
+ model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
176
+ model.load_state_dict(state_dict)
177
+ model.eval()
178
+ return model
179
+
180
+ if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
181
+ nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
182
+ state_dict = resrgan2normal(state_dict, nb)
183
+ elif "conv_first.weight" in state_dict:
184
+ state_dict = mod2normal(state_dict)
185
+ elif "model.0.weight" not in state_dict:
186
+ raise Exception("The file is not a recognized ESRGAN model.")
187
+
188
+ in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
189
+
190
+ model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
191
+ model.load_state_dict(state_dict)
192
+ model.eval()
193
+
194
+ return model
195
+
196
+
197
+ def upscale_without_tiling(model, img):
198
+ img = np.array(img)
199
+ img = img[:, :, ::-1]
200
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
201
+ img = torch.from_numpy(img).float()
202
+ img = img.unsqueeze(0).to(devices.device_esrgan)
203
+ with torch.no_grad():
204
+ output = model(img)
205
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
206
+ output = 255. * np.moveaxis(output, 0, 2)
207
+ output = output.astype(np.uint8)
208
+ output = output[:, :, ::-1]
209
+ return Image.fromarray(output, 'RGB')
210
+
211
+
212
+ def esrgan_upscale(model, img):
213
+ if opts.ESRGAN_tile == 0:
214
+ return upscale_without_tiling(model, img)
215
+
216
+ grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
217
+ newtiles = []
218
+ scale_factor = 1
219
+
220
+ for y, h, row in grid.tiles:
221
+ newrow = []
222
+ for tiledata in row:
223
+ x, w, tile = tiledata
224
+
225
+ output = upscale_without_tiling(model, tile)
226
+ scale_factor = output.width // tile.width
227
+
228
+ newrow.append([x * scale_factor, w * scale_factor, output])
229
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
230
+
231
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
232
+ output = images.combine_grid(newgrid)
233
+ return output
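
A minimal standalone sketch of the coordinate bookkeeping esrgan_upscale performs above: each tile is upscaled independently, the scale factor is inferred from the first output tile, and every tile's position and size are multiplied by it before images.combine_grid reassembles them. The 192px tiles and 4x factor below are illustrative assumptions, not values taken from the code.

def scale_tile_layout(tiles, scale_factor):
    """tiles: list of (x, y, w, h) boxes in source-image coordinates."""
    return [(x * scale_factor, y * scale_factor, w * scale_factor, h * scale_factor)
            for x, y, w, h in tiles]

print(scale_tile_layout([(0, 0, 192, 192), (160, 0, 192, 192)], 4))
# [(0, 0, 768, 768), (640, 0, 768, 768)]
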
sd/stable-diffusion-webui/modules/esrgan_model_arch.py CHANGED
@@ -1,464 +1,464 @@
1
- # this file is adapted from https://github.com/victorca25/iNNfer
2
-
3
- from collections import OrderedDict
4
- import math
5
- import functools
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
-
11
- ####################
12
- # RRDBNet Generator
13
- ####################
14
-
15
- class RRDBNet(nn.Module):
16
- def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
17
- act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
18
- finalact=None, gaussian_noise=False, plus=False):
19
- super(RRDBNet, self).__init__()
20
- n_upscale = int(math.log(upscale, 2))
21
- if upscale == 3:
22
- n_upscale = 1
23
-
24
- self.resrgan_scale = 0
25
- if in_nc % 16 == 0:
26
- self.resrgan_scale = 1
27
- elif in_nc != 4 and in_nc % 4 == 0:
28
- self.resrgan_scale = 2
29
-
30
- fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
31
- rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
32
- norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
33
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
34
- LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
35
-
36
- if upsample_mode == 'upconv':
37
- upsample_block = upconv_block
38
- elif upsample_mode == 'pixelshuffle':
39
- upsample_block = pixelshuffle_block
40
- else:
41
- raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
42
- if upscale == 3:
43
- upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
44
- else:
45
- upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
46
- HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
47
- HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
48
-
49
- outact = act(finalact) if finalact else None
50
-
51
- self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
52
- *upsampler, HR_conv0, HR_conv1, outact)
53
-
54
- def forward(self, x, outm=None):
55
- if self.resrgan_scale == 1:
56
- feat = pixel_unshuffle(x, scale=4)
57
- elif self.resrgan_scale == 2:
58
- feat = pixel_unshuffle(x, scale=2)
59
- else:
60
- feat = x
61
-
62
- return self.model(feat)
63
-
64
-
65
- class RRDB(nn.Module):
66
- """
67
- Residual in Residual Dense Block
68
- (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
69
- """
70
-
71
- def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
72
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
73
- spectral_norm=False, gaussian_noise=False, plus=False):
74
- super(RRDB, self).__init__()
75
- # This is for backwards compatibility with existing models
76
- if nr == 3:
77
- self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
78
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
79
- gaussian_noise=gaussian_noise, plus=plus)
80
- self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
81
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
82
- gaussian_noise=gaussian_noise, plus=plus)
83
- self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
84
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
85
- gaussian_noise=gaussian_noise, plus=plus)
86
- else:
87
- RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
88
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
89
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
90
- self.RDBs = nn.Sequential(*RDB_list)
91
-
92
- def forward(self, x):
93
- if hasattr(self, 'RDB1'):
94
- out = self.RDB1(x)
95
- out = self.RDB2(out)
96
- out = self.RDB3(out)
97
- else:
98
- out = self.RDBs(x)
99
- return out * 0.2 + x
100
-
101
-
102
- class ResidualDenseBlock_5C(nn.Module):
103
- """
104
- Residual Dense Block
105
- The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
106
- Modified options that can be used:
107
- - "Partial Convolution based Padding" arXiv:1811.11718
108
- - "Spectral normalization" arXiv:1802.05957
109
- - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
110
- {Rakotonirina} and A. {Rasoanaivo}
111
- """
112
-
113
- def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
114
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
115
- spectral_norm=False, gaussian_noise=False, plus=False):
116
- super(ResidualDenseBlock_5C, self).__init__()
117
-
118
- self.noise = GaussianNoise() if gaussian_noise else None
119
- self.conv1x1 = conv1x1(nf, gc) if plus else None
120
-
121
- self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
122
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
123
- spectral_norm=spectral_norm)
124
- self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
125
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
126
- spectral_norm=spectral_norm)
127
- self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
128
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
129
- spectral_norm=spectral_norm)
130
- self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
131
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
132
- spectral_norm=spectral_norm)
133
- if mode == 'CNA':
134
- last_act = None
135
- else:
136
- last_act = act_type
137
- self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
138
- norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
139
- spectral_norm=spectral_norm)
140
-
141
- def forward(self, x):
142
- x1 = self.conv1(x)
143
- x2 = self.conv2(torch.cat((x, x1), 1))
144
- if self.conv1x1:
145
- x2 = x2 + self.conv1x1(x)
146
- x3 = self.conv3(torch.cat((x, x1, x2), 1))
147
- x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
148
- if self.conv1x1:
149
- x4 = x4 + x2
150
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
151
- if self.noise:
152
- return self.noise(x5.mul(0.2) + x)
153
- else:
154
- return x5 * 0.2 + x
155
-
156
-
157
- ####################
158
- # ESRGANplus
159
- ####################
160
-
161
- class GaussianNoise(nn.Module):
162
- def __init__(self, sigma=0.1, is_relative_detach=False):
163
- super().__init__()
164
- self.sigma = sigma
165
- self.is_relative_detach = is_relative_detach
166
- self.noise = torch.tensor(0, dtype=torch.float)
167
-
168
- def forward(self, x):
169
- if self.training and self.sigma != 0:
170
- self.noise = self.noise.to(x.device)
171
- scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
172
- sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
173
- x = x + sampled_noise
174
- return x
175
-
176
- def conv1x1(in_planes, out_planes, stride=1):
177
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
178
-
179
-
180
- ####################
181
- # SRVGGNetCompact
182
- ####################
183
-
184
- class SRVGGNetCompact(nn.Module):
185
- """A compact VGG-style network structure for super-resolution.
186
- This class is copied from https://github.com/xinntao/Real-ESRGAN
187
- """
188
-
189
- def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
190
- super(SRVGGNetCompact, self).__init__()
191
- self.num_in_ch = num_in_ch
192
- self.num_out_ch = num_out_ch
193
- self.num_feat = num_feat
194
- self.num_conv = num_conv
195
- self.upscale = upscale
196
- self.act_type = act_type
197
-
198
- self.body = nn.ModuleList()
199
- # the first conv
200
- self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
201
- # the first activation
202
- if act_type == 'relu':
203
- activation = nn.ReLU(inplace=True)
204
- elif act_type == 'prelu':
205
- activation = nn.PReLU(num_parameters=num_feat)
206
- elif act_type == 'leakyrelu':
207
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
208
- self.body.append(activation)
209
-
210
- # the body structure
211
- for _ in range(num_conv):
212
- self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
213
- # activation
214
- if act_type == 'relu':
215
- activation = nn.ReLU(inplace=True)
216
- elif act_type == 'prelu':
217
- activation = nn.PReLU(num_parameters=num_feat)
218
- elif act_type == 'leakyrelu':
219
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
220
- self.body.append(activation)
221
-
222
- # the last conv
223
- self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
224
- # upsample
225
- self.upsampler = nn.PixelShuffle(upscale)
226
-
227
- def forward(self, x):
228
- out = x
229
- for i in range(0, len(self.body)):
230
- out = self.body[i](out)
231
-
232
- out = self.upsampler(out)
233
- # add the nearest upsampled image, so that the network learns the residual
234
- base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
235
- out += base
236
- return out
237
-
238
-
239
- ####################
240
- # Upsampler
241
- ####################
242
-
243
- class Upsample(nn.Module):
244
- r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
245
- The input data is assumed to be of the form
246
- `minibatch x channels x [optional depth] x [optional height] x width`.
247
- """
248
-
249
- def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
250
- super(Upsample, self).__init__()
251
- if isinstance(scale_factor, tuple):
252
- self.scale_factor = tuple(float(factor) for factor in scale_factor)
253
- else:
254
- self.scale_factor = float(scale_factor) if scale_factor else None
255
- self.mode = mode
256
- self.size = size
257
- self.align_corners = align_corners
258
-
259
- def forward(self, x):
260
- return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
261
-
262
- def extra_repr(self):
263
- if self.scale_factor is not None:
264
- info = 'scale_factor=' + str(self.scale_factor)
265
- else:
266
- info = 'size=' + str(self.size)
267
- info += ', mode=' + self.mode
268
- return info
269
-
270
-
271
- def pixel_unshuffle(x, scale):
272
- """ Pixel unshuffle.
273
- Args:
274
- x (Tensor): Input feature with shape (b, c, hh, hw).
275
- scale (int): Downsample ratio.
276
- Returns:
277
- Tensor: the pixel unshuffled feature.
278
- """
279
- b, c, hh, hw = x.size()
280
- out_channel = c * (scale**2)
281
- assert hh % scale == 0 and hw % scale == 0
282
- h = hh // scale
283
- w = hw // scale
284
- x_view = x.view(b, c, h, scale, w, scale)
285
- return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
286
-
287
-
288
- def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
289
- pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
290
- """
291
- Pixel shuffle layer
292
- (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
293
- Neural Network, CVPR17)
294
- """
295
- conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
296
- pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
297
- pixel_shuffle = nn.PixelShuffle(upscale_factor)
298
-
299
- n = norm(norm_type, out_nc) if norm_type else None
300
- a = act(act_type) if act_type else None
301
- return sequential(conv, pixel_shuffle, n, a)
302
-
303
-
304
- def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
305
- pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
306
- """ Upconv layer """
307
- upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
308
- upsample = Upsample(scale_factor=upscale_factor, mode=mode)
309
- conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
310
- pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
311
- return sequential(upsample, conv)
312
-
313
-
314
-
315
-
316
-
317
-
318
-
319
-
320
- ####################
321
- # Basic blocks
322
- ####################
323
-
324
-
325
- def make_layer(basic_block, num_basic_block, **kwarg):
326
- """Make layers by stacking the same blocks.
327
- Args:
328
- basic_block (nn.module): nn.module class for basic block. (block)
329
- num_basic_block (int): number of blocks. (n_layers)
330
- Returns:
331
- nn.Sequential: Stacked blocks in nn.Sequential.
332
- """
333
- layers = []
334
- for _ in range(num_basic_block):
335
- layers.append(basic_block(**kwarg))
336
- return nn.Sequential(*layers)
337
-
338
-
339
- def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
340
- """ activation helper """
341
- act_type = act_type.lower()
342
- if act_type == 'relu':
343
- layer = nn.ReLU(inplace)
344
- elif act_type in ('leakyrelu', 'lrelu'):
345
- layer = nn.LeakyReLU(neg_slope, inplace)
346
- elif act_type == 'prelu':
347
- layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
348
- elif act_type == 'tanh': # [-1, 1] range output
349
- layer = nn.Tanh()
350
- elif act_type == 'sigmoid': # [0, 1] range output
351
- layer = nn.Sigmoid()
352
- else:
353
- raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
354
- return layer
355
-
356
-
357
- class Identity(nn.Module):
358
- def __init__(self, *kwargs):
359
- super(Identity, self).__init__()
360
-
361
- def forward(self, x, *kwargs):
362
- return x
363
-
364
-
365
- def norm(norm_type, nc):
366
- """ Return a normalization layer """
367
- norm_type = norm_type.lower()
368
- if norm_type == 'batch':
369
- layer = nn.BatchNorm2d(nc, affine=True)
370
- elif norm_type == 'instance':
371
- layer = nn.InstanceNorm2d(nc, affine=False)
372
- elif norm_type == 'none':
373
- def norm_layer(x): return Identity()
374
- else:
375
- raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
376
- return layer
377
-
378
-
379
- def pad(pad_type, padding):
380
- """ padding layer helper """
381
- pad_type = pad_type.lower()
382
- if padding == 0:
383
- return None
384
- if pad_type == 'reflect':
385
- layer = nn.ReflectionPad2d(padding)
386
- elif pad_type == 'replicate':
387
- layer = nn.ReplicationPad2d(padding)
388
- elif pad_type == 'zero':
389
- layer = nn.ZeroPad2d(padding)
390
- else:
391
- raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
392
- return layer
393
-
394
-
395
- def get_valid_padding(kernel_size, dilation):
396
- kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
397
- padding = (kernel_size - 1) // 2
398
- return padding
399
-
400
-
401
- class ShortcutBlock(nn.Module):
402
- """ Elementwise sum the output of a submodule to its input """
403
- def __init__(self, submodule):
404
- super(ShortcutBlock, self).__init__()
405
- self.sub = submodule
406
-
407
- def forward(self, x):
408
- output = x + self.sub(x)
409
- return output
410
-
411
- def __repr__(self):
412
- return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
413
-
414
-
415
- def sequential(*args):
416
- """ Flatten Sequential. It unwraps nn.Sequential. """
417
- if len(args) == 1:
418
- if isinstance(args[0], OrderedDict):
419
- raise NotImplementedError('sequential does not support OrderedDict input.')
420
- return args[0] # No sequential is needed.
421
- modules = []
422
- for module in args:
423
- if isinstance(module, nn.Sequential):
424
- for submodule in module.children():
425
- modules.append(submodule)
426
- elif isinstance(module, nn.Module):
427
- modules.append(module)
428
- return nn.Sequential(*modules)
429
-
430
-
431
- def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
432
- pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
433
- spectral_norm=False):
434
- """ Conv layer with padding, normalization, activation """
435
- assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
436
- padding = get_valid_padding(kernel_size, dilation)
437
- p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
438
- padding = padding if pad_type == 'zero' else 0
439
-
440
- if convtype=='PartialConv2D':
441
- c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
442
- dilation=dilation, bias=bias, groups=groups)
443
- elif convtype=='DeformConv2D':
444
- c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
445
- dilation=dilation, bias=bias, groups=groups)
446
- elif convtype=='Conv3D':
447
- c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
448
- dilation=dilation, bias=bias, groups=groups)
449
- else:
450
- c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
451
- dilation=dilation, bias=bias, groups=groups)
452
-
453
- if spectral_norm:
454
- c = nn.utils.spectral_norm(c)
455
-
456
- a = act(act_type) if act_type else None
457
- if 'CNA' in mode:
458
- n = norm(norm_type, out_nc) if norm_type else None
459
- return sequential(p, c, n, a)
460
- elif mode == 'NAC':
461
- if norm_type is None and act_type is not None:
462
- a = act(act_type, inplace=False)
463
- n = norm(norm_type, in_nc) if norm_type else None
464
- return sequential(n, a, p, c)
 
1
+ # this file is adapted from https://github.com/victorca25/iNNfer
2
+
3
+ from collections import OrderedDict
4
+ import math
5
+ import functools
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ ####################
12
+ # RRDBNet Generator
13
+ ####################
14
+
15
+ class RRDBNet(nn.Module):
16
+ def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
17
+ act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
18
+ finalact=None, gaussian_noise=False, plus=False):
19
+ super(RRDBNet, self).__init__()
20
+ n_upscale = int(math.log(upscale, 2))
21
+ if upscale == 3:
22
+ n_upscale = 1
23
+
24
+ self.resrgan_scale = 0
25
+ if in_nc % 16 == 0:
26
+ self.resrgan_scale = 1
27
+ elif in_nc != 4 and in_nc % 4 == 0:
28
+ self.resrgan_scale = 2
29
+
30
+ fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
31
+ rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
32
+ norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
33
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
34
+ LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
35
+
36
+ if upsample_mode == 'upconv':
37
+ upsample_block = upconv_block
38
+ elif upsample_mode == 'pixelshuffle':
39
+ upsample_block = pixelshuffle_block
40
+ else:
41
+ raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
42
+ if upscale == 3:
43
+ upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
44
+ else:
45
+ upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
46
+ HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
47
+ HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
48
+
49
+ outact = act(finalact) if finalact else None
50
+
51
+ self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
52
+ *upsampler, HR_conv0, HR_conv1, outact)
53
+
54
+ def forward(self, x, outm=None):
55
+ if self.resrgan_scale == 1:
56
+ feat = pixel_unshuffle(x, scale=4)
57
+ elif self.resrgan_scale == 2:
58
+ feat = pixel_unshuffle(x, scale=2)
59
+ else:
60
+ feat = x
61
+
62
+ return self.model(feat)
63
+
64
+
65
+ class RRDB(nn.Module):
66
+ """
67
+ Residual in Residual Dense Block
68
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
69
+ """
70
+
71
+ def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
72
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
73
+ spectral_norm=False, gaussian_noise=False, plus=False):
74
+ super(RRDB, self).__init__()
75
+ # This is for backwards compatibility with existing models
76
+ if nr == 3:
77
+ self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
78
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
79
+ gaussian_noise=gaussian_noise, plus=plus)
80
+ self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
81
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
82
+ gaussian_noise=gaussian_noise, plus=plus)
83
+ self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
84
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
85
+ gaussian_noise=gaussian_noise, plus=plus)
86
+ else:
87
+ RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
88
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
89
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
90
+ self.RDBs = nn.Sequential(*RDB_list)
91
+
92
+ def forward(self, x):
93
+ if hasattr(self, 'RDB1'):
94
+ out = self.RDB1(x)
95
+ out = self.RDB2(out)
96
+ out = self.RDB3(out)
97
+ else:
98
+ out = self.RDBs(x)
99
+ return out * 0.2 + x
100
+
101
+
102
+ class ResidualDenseBlock_5C(nn.Module):
103
+ """
104
+ Residual Dense Block
105
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
106
+ Modified options that can be used:
107
+ - "Partial Convolution based Padding" arXiv:1811.11718
108
+ - "Spectral normalization" arXiv:1802.05957
109
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
110
+ {Rakotonirina} and A. {Rasoanaivo}
111
+ """
112
+
113
+ def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
114
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
115
+ spectral_norm=False, gaussian_noise=False, plus=False):
116
+ super(ResidualDenseBlock_5C, self).__init__()
117
+
118
+ self.noise = GaussianNoise() if gaussian_noise else None
119
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
120
+
121
+ self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
122
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
123
+ spectral_norm=spectral_norm)
124
+ self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
125
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
126
+ spectral_norm=spectral_norm)
127
+ self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
128
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
129
+ spectral_norm=spectral_norm)
130
+ self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
131
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
132
+ spectral_norm=spectral_norm)
133
+ if mode == 'CNA':
134
+ last_act = None
135
+ else:
136
+ last_act = act_type
137
+ self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
138
+ norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
139
+ spectral_norm=spectral_norm)
140
+
141
+ def forward(self, x):
142
+ x1 = self.conv1(x)
143
+ x2 = self.conv2(torch.cat((x, x1), 1))
144
+ if self.conv1x1:
145
+ x2 = x2 + self.conv1x1(x)
146
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
147
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
148
+ if self.conv1x1:
149
+ x4 = x4 + x2
150
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
151
+ if self.noise:
152
+ return self.noise(x5.mul(0.2) + x)
153
+ else:
154
+ return x5 * 0.2 + x
155
+
156
+
157
+ ####################
158
+ # ESRGANplus
159
+ ####################
160
+
161
+ class GaussianNoise(nn.Module):
162
+ def __init__(self, sigma=0.1, is_relative_detach=False):
163
+ super().__init__()
164
+ self.sigma = sigma
165
+ self.is_relative_detach = is_relative_detach
166
+ self.noise = torch.tensor(0, dtype=torch.float)
167
+
168
+ def forward(self, x):
169
+ if self.training and self.sigma != 0:
170
+ self.noise = self.noise.to(x.device)
171
+ scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
172
+ sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
173
+ x = x + sampled_noise
174
+ return x
175
+
176
+ def conv1x1(in_planes, out_planes, stride=1):
177
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
178
+
179
+
180
+ ####################
181
+ # SRVGGNetCompact
182
+ ####################
183
+
184
+ class SRVGGNetCompact(nn.Module):
185
+ """A compact VGG-style network structure for super-resolution.
186
+ This class is copied from https://github.com/xinntao/Real-ESRGAN
187
+ """
188
+
189
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
190
+ super(SRVGGNetCompact, self).__init__()
191
+ self.num_in_ch = num_in_ch
192
+ self.num_out_ch = num_out_ch
193
+ self.num_feat = num_feat
194
+ self.num_conv = num_conv
195
+ self.upscale = upscale
196
+ self.act_type = act_type
197
+
198
+ self.body = nn.ModuleList()
199
+ # the first conv
200
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
201
+ # the first activation
202
+ if act_type == 'relu':
203
+ activation = nn.ReLU(inplace=True)
204
+ elif act_type == 'prelu':
205
+ activation = nn.PReLU(num_parameters=num_feat)
206
+ elif act_type == 'leakyrelu':
207
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
208
+ self.body.append(activation)
209
+
210
+ # the body structure
211
+ for _ in range(num_conv):
212
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
213
+ # activation
214
+ if act_type == 'relu':
215
+ activation = nn.ReLU(inplace=True)
216
+ elif act_type == 'prelu':
217
+ activation = nn.PReLU(num_parameters=num_feat)
218
+ elif act_type == 'leakyrelu':
219
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
220
+ self.body.append(activation)
221
+
222
+ # the last conv
223
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
224
+ # upsample
225
+ self.upsampler = nn.PixelShuffle(upscale)
226
+
227
+ def forward(self, x):
228
+ out = x
229
+ for i in range(0, len(self.body)):
230
+ out = self.body[i](out)
231
+
232
+ out = self.upsampler(out)
233
+ # add the nearest upsampled image, so that the network learns the residual
234
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
235
+ out += base
236
+ return out
237
+
238
+
239
+ ####################
240
+ # Upsampler
241
+ ####################
242
+
243
+ class Upsample(nn.Module):
244
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
245
+ The input data is assumed to be of the form
246
+ `minibatch x channels x [optional depth] x [optional height] x width`.
247
+ """
248
+
249
+ def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
250
+ super(Upsample, self).__init__()
251
+ if isinstance(scale_factor, tuple):
252
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
253
+ else:
254
+ self.scale_factor = float(scale_factor) if scale_factor else None
255
+ self.mode = mode
256
+ self.size = size
257
+ self.align_corners = align_corners
258
+
259
+ def forward(self, x):
260
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
261
+
262
+ def extra_repr(self):
263
+ if self.scale_factor is not None:
264
+ info = 'scale_factor=' + str(self.scale_factor)
265
+ else:
266
+ info = 'size=' + str(self.size)
267
+ info += ', mode=' + self.mode
268
+ return info
269
+
270
+
271
+ def pixel_unshuffle(x, scale):
272
+ """ Pixel unshuffle.
273
+ Args:
274
+ x (Tensor): Input feature with shape (b, c, hh, hw).
275
+ scale (int): Downsample ratio.
276
+ Returns:
277
+ Tensor: the pixel unshuffled feature.
278
+ """
279
+ b, c, hh, hw = x.size()
280
+ out_channel = c * (scale**2)
281
+ assert hh % scale == 0 and hw % scale == 0
282
+ h = hh // scale
283
+ w = hw // scale
284
+ x_view = x.view(b, c, h, scale, w, scale)
285
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
286
+
287
+
288
+ def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
289
+ pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
290
+ """
291
+ Pixel shuffle layer
292
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
293
+ Neural Network, CVPR17)
294
+ """
295
+ conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
296
+ pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
297
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
298
+
299
+ n = norm(norm_type, out_nc) if norm_type else None
300
+ a = act(act_type) if act_type else None
301
+ return sequential(conv, pixel_shuffle, n, a)
302
+
303
+
304
+ def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
305
+ pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
306
+ """ Upconv layer """
307
+ upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
308
+ upsample = Upsample(scale_factor=upscale_factor, mode=mode)
309
+ conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
310
+ pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
311
+ return sequential(upsample, conv)
312
+
313
+
314
+
315
+
316
+
317
+
318
+
319
+
320
+ ####################
321
+ # Basic blocks
322
+ ####################
323
+
324
+
325
+ def make_layer(basic_block, num_basic_block, **kwarg):
326
+ """Make layers by stacking the same blocks.
327
+ Args:
328
+ basic_block (nn.module): nn.module class for basic block. (block)
329
+ num_basic_block (int): number of blocks. (n_layers)
330
+ Returns:
331
+ nn.Sequential: Stacked blocks in nn.Sequential.
332
+ """
333
+ layers = []
334
+ for _ in range(num_basic_block):
335
+ layers.append(basic_block(**kwarg))
336
+ return nn.Sequential(*layers)
337
+
338
+
339
+ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
340
+ """ activation helper """
341
+ act_type = act_type.lower()
342
+ if act_type == 'relu':
343
+ layer = nn.ReLU(inplace)
344
+ elif act_type in ('leakyrelu', 'lrelu'):
345
+ layer = nn.LeakyReLU(neg_slope, inplace)
346
+ elif act_type == 'prelu':
347
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
348
+ elif act_type == 'tanh': # [-1, 1] range output
349
+ layer = nn.Tanh()
350
+ elif act_type == 'sigmoid': # [0, 1] range output
351
+ layer = nn.Sigmoid()
352
+ else:
353
+ raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
354
+ return layer
355
+
356
+
357
+ class Identity(nn.Module):
358
+ def __init__(self, *kwargs):
359
+ super(Identity, self).__init__()
360
+
361
+ def forward(self, x, *kwargs):
362
+ return x
363
+
364
+
365
+ def norm(norm_type, nc):
366
+ """ Return a normalization layer """
367
+ norm_type = norm_type.lower()
368
+ if norm_type == 'batch':
369
+ layer = nn.BatchNorm2d(nc, affine=True)
370
+ elif norm_type == 'instance':
371
+ layer = nn.InstanceNorm2d(nc, affine=False)
372
+ elif norm_type == 'none':
373
+ layer = Identity()  # 'none' -> no-op layer, so the function always returns a module
374
+ else:
375
+ raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
376
+ return layer
377
+
378
+
379
+ def pad(pad_type, padding):
380
+ """ padding layer helper """
381
+ pad_type = pad_type.lower()
382
+ if padding == 0:
383
+ return None
384
+ if pad_type == 'reflect':
385
+ layer = nn.ReflectionPad2d(padding)
386
+ elif pad_type == 'replicate':
387
+ layer = nn.ReplicationPad2d(padding)
388
+ elif pad_type == 'zero':
389
+ layer = nn.ZeroPad2d(padding)
390
+ else:
391
+ raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
392
+ return layer
393
+
394
+
395
+ def get_valid_padding(kernel_size, dilation):
396
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
397
+ padding = (kernel_size - 1) // 2
398
+ return padding
399
+
400
+
401
+ class ShortcutBlock(nn.Module):
402
+ """ Elementwise sum the output of a submodule to its input """
403
+ def __init__(self, submodule):
404
+ super(ShortcutBlock, self).__init__()
405
+ self.sub = submodule
406
+
407
+ def forward(self, x):
408
+ output = x + self.sub(x)
409
+ return output
410
+
411
+ def __repr__(self):
412
+ return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
413
+
414
+
415
+ def sequential(*args):
416
+ """ Flatten Sequential. It unwraps nn.Sequential. """
417
+ if len(args) == 1:
418
+ if isinstance(args[0], OrderedDict):
419
+ raise NotImplementedError('sequential does not support OrderedDict input.')
420
+ return args[0] # No sequential is needed.
421
+ modules = []
422
+ for module in args:
423
+ if isinstance(module, nn.Sequential):
424
+ for submodule in module.children():
425
+ modules.append(submodule)
426
+ elif isinstance(module, nn.Module):
427
+ modules.append(module)
428
+ return nn.Sequential(*modules)
429
+
430
+
431
+ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
432
+ pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
433
+ spectral_norm=False):
434
+ """ Conv layer with padding, normalization, activation """
435
+ assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
436
+ padding = get_valid_padding(kernel_size, dilation)
437
+ p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
438
+ padding = padding if pad_type == 'zero' else 0
439
+
440
+ if convtype=='PartialConv2D':
441
+ c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
442
+ dilation=dilation, bias=bias, groups=groups)
443
+ elif convtype=='DeformConv2D':
444
+ c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
445
+ dilation=dilation, bias=bias, groups=groups)
446
+ elif convtype=='Conv3D':
447
+ c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
448
+ dilation=dilation, bias=bias, groups=groups)
449
+ else:
450
+ c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
451
+ dilation=dilation, bias=bias, groups=groups)
452
+
453
+ if spectral_norm:
454
+ c = nn.utils.spectral_norm(c)
455
+
456
+ a = act(act_type) if act_type else None
457
+ if 'CNA' in mode:
458
+ n = norm(norm_type, out_nc) if norm_type else None
459
+ return sequential(p, c, n, a)
460
+ elif mode == 'NAC':
461
+ if norm_type is None and act_type is not None:
462
+ a = act(act_type, inplace=False)
463
+ n = norm(norm_type, in_nc) if norm_type else None
464
+ return sequential(n, a, p, c)
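
A small self-contained check of the pixel_unshuffle shape arithmetic used by RRDBNet.forward above; the input size is an arbitrary example. A (b, c, H, W) tensor becomes (b, c * scale**2, H/scale, W/scale):

import torch

x = torch.zeros(1, 3, 64, 64)  # (b, c, H, W)
scale = 2
b, c, hh, hw = x.size()
x_view = x.view(b, c, hh // scale, scale, hw // scale, scale)
out = x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, c * scale ** 2, hh // scale, hw // scale)
print(out.shape)  # torch.Size([1, 12, 32, 32]) - channels x4, spatial dims halved
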
sd/stable-diffusion-webui/modules/extensions.py CHANGED
@@ -1,107 +1,107 @@
1
- import os
2
- import sys
3
- import traceback
4
-
5
- import time
6
- import git
7
-
8
- from modules import paths, shared
9
-
10
- extensions = []
11
- extensions_dir = os.path.join(paths.data_path, "extensions")
12
- extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
13
-
14
- if not os.path.exists(extensions_dir):
15
- os.makedirs(extensions_dir)
16
-
17
- def active():
18
- return [x for x in extensions if x.enabled]
19
-
20
-
21
- class Extension:
22
- def __init__(self, name, path, enabled=True, is_builtin=False):
23
- self.name = name
24
- self.path = path
25
- self.enabled = enabled
26
- self.status = ''
27
- self.can_update = False
28
- self.is_builtin = is_builtin
29
- self.version = ''
30
-
31
- repo = None
32
- try:
33
- if os.path.exists(os.path.join(path, ".git")):
34
- repo = git.Repo(path)
35
- except Exception:
36
- print(f"Error reading github repository info from {path}:", file=sys.stderr)
37
- print(traceback.format_exc(), file=sys.stderr)
38
-
39
- if repo is None or repo.bare:
40
- self.remote = None
41
- else:
42
- try:
43
- self.remote = next(repo.remote().urls, None)
44
- self.status = 'unknown'
45
- head = repo.head.commit
46
- ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
47
- self.version = f'{head.hexsha[:8]} ({ts})'
48
-
49
- except Exception:
50
- self.remote = None
51
-
52
- def list_files(self, subdir, extension):
53
- from modules import scripts
54
-
55
- dirpath = os.path.join(self.path, subdir)
56
- if not os.path.isdir(dirpath):
57
- return []
58
-
59
- res = []
60
- for filename in sorted(os.listdir(dirpath)):
61
- res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
62
-
63
- res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
64
-
65
- return res
66
-
67
- def check_updates(self):
68
- repo = git.Repo(self.path)
69
- for fetch in repo.remote().fetch("--dry-run"):
70
- if fetch.flags != fetch.HEAD_UPTODATE:
71
- self.can_update = True
72
- self.status = "behind"
73
- return
74
-
75
- self.can_update = False
76
- self.status = "latest"
77
-
78
- def fetch_and_reset_hard(self):
79
- repo = git.Repo(self.path)
80
- # Fix: `error: Your local changes to the following files would be overwritten by merge`,
81
- # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
82
- repo.git.fetch('--all')
83
- repo.git.reset('--hard', 'origin')
84
-
85
-
86
- def list_extensions():
87
- extensions.clear()
88
-
89
- if not os.path.isdir(extensions_dir):
90
- return
91
-
92
- paths = []
93
- for dirname in [extensions_dir, extensions_builtin_dir]:
94
- if not os.path.isdir(dirname):
95
- return
96
-
97
- for extension_dirname in sorted(os.listdir(dirname)):
98
- path = os.path.join(dirname, extension_dirname)
99
- if not os.path.isdir(path):
100
- continue
101
-
102
- paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
103
-
104
- for dirname, path, is_builtin in paths:
105
- extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
106
- extensions.append(extension)
107
-
 
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import time
6
+ import git
7
+
8
+ from modules import paths, shared
9
+
10
+ extensions = []
11
+ extensions_dir = os.path.join(paths.data_path, "extensions")
12
+ extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
13
+
14
+ if not os.path.exists(extensions_dir):
15
+ os.makedirs(extensions_dir)
16
+
17
+ def active():
18
+ return [x for x in extensions if x.enabled]
19
+
20
+
21
+ class Extension:
22
+ def __init__(self, name, path, enabled=True, is_builtin=False):
23
+ self.name = name
24
+ self.path = path
25
+ self.enabled = enabled
26
+ self.status = ''
27
+ self.can_update = False
28
+ self.is_builtin = is_builtin
29
+ self.version = ''
30
+
31
+ repo = None
32
+ try:
33
+ if os.path.exists(os.path.join(path, ".git")):
34
+ repo = git.Repo(path)
35
+ except Exception:
36
+ print(f"Error reading github repository info from {path}:", file=sys.stderr)
37
+ print(traceback.format_exc(), file=sys.stderr)
38
+
39
+ if repo is None or repo.bare:
40
+ self.remote = None
41
+ else:
42
+ try:
43
+ self.remote = next(repo.remote().urls, None)
44
+ self.status = 'unknown'
45
+ head = repo.head.commit
46
+ ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
47
+ self.version = f'{head.hexsha[:8]} ({ts})'
48
+
49
+ except Exception:
50
+ self.remote = None
51
+
52
+ def list_files(self, subdir, extension):
53
+ from modules import scripts
54
+
55
+ dirpath = os.path.join(self.path, subdir)
56
+ if not os.path.isdir(dirpath):
57
+ return []
58
+
59
+ res = []
60
+ for filename in sorted(os.listdir(dirpath)):
61
+ res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
62
+
63
+ res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
64
+
65
+ return res
66
+
67
+ def check_updates(self):
68
+ repo = git.Repo(self.path)
69
+ for fetch in repo.remote().fetch("--dry-run"):
70
+ if fetch.flags != fetch.HEAD_UPTODATE:
71
+ self.can_update = True
72
+ self.status = "behind"
73
+ return
74
+
75
+ self.can_update = False
76
+ self.status = "latest"
77
+
78
+ def fetch_and_reset_hard(self):
79
+ repo = git.Repo(self.path)
80
+ # Fix: `error: Your local changes to the following files would be overwritten by merge`,
81
+ # because WSL2 Docker sets 755 file permissions instead of 644, which results in this error.
82
+ repo.git.fetch('--all')
83
+ repo.git.reset('--hard', 'origin')
84
+
85
+
86
+ def list_extensions():
87
+ extensions.clear()
88
+
89
+ if not os.path.isdir(extensions_dir):
90
+ return
91
+
92
+ paths = []
93
+ for dirname in [extensions_dir, extensions_builtin_dir]:
94
+ if not os.path.isdir(dirname):
95
+ return
96
+
97
+ for extension_dirname in sorted(os.listdir(dirname)):
98
+ path = os.path.join(dirname, extension_dirname)
99
+ if not os.path.isdir(path):
100
+ continue
101
+
102
+ paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
103
+
104
+ for dirname, path, is_builtin in paths:
105
+ extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
106
+ extensions.append(extension)
107
+
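
A minimal sketch of the version string Extension.__init__ builds above, assuming GitPython is available and `path` points at a checked-out extension; the helper name is made up for illustration:

import time
import git

def extension_version(path):
    repo = git.Repo(path)  # raises if path is not a git checkout
    head = repo.head.commit
    ts = time.asctime(time.gmtime(head.committed_date))
    return f"{head.hexsha[:8]} ({ts})"
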
sd/stable-diffusion-webui/modules/extra_networks.py CHANGED
@@ -1,147 +1,147 @@
1
- import re
2
- from collections import defaultdict
3
-
4
- from modules import errors
5
-
6
- extra_network_registry = {}
7
-
8
-
9
- def initialize():
10
- extra_network_registry.clear()
11
-
12
-
13
- def register_extra_network(extra_network):
14
- extra_network_registry[extra_network.name] = extra_network
15
-
16
-
17
- class ExtraNetworkParams:
18
- def __init__(self, items=None):
19
- self.items = items or []
20
-
21
-
22
- class ExtraNetwork:
23
- def __init__(self, name):
24
- self.name = name
25
-
26
- def activate(self, p, params_list):
27
- """
28
- Called by processing on every run. Whatever the extra network is meant to do should be activated here.
29
- Passes arguments related to this extra network in params_list.
30
- User passes arguments by specifying this in his prompt:
31
-
32
- <name:arg1:arg2:arg3>
33
-
34
- Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
35
- separated by colon.
36
-
37
- Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
38
- in this case, all effects of this extra networks should be disabled.
39
-
40
- Can be called multiple times before deactivate() - each new call should override the previous call completely.
41
-
42
- For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
43
-
44
- > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
45
-
46
- params_list will be:
47
-
48
- [
49
- ExtraNetworkParams(items=["agm", "1.1"]),
50
- ExtraNetworkParams(items=["ray"])
51
- ]
52
-
53
- """
54
- raise NotImplementedError
55
-
56
- def deactivate(self, p):
57
- """
58
- Called at the end of processing for housekeeping. No need to do anything here.
59
- """
60
-
61
- raise NotImplementedError
62
-
63
-
64
- def activate(p, extra_network_data):
65
- """call activate for extra networks in extra_network_data in specified order, then call
66
- activate for all remaining registered networks with an empty argument list"""
67
-
68
- for extra_network_name, extra_network_args in extra_network_data.items():
69
- extra_network = extra_network_registry.get(extra_network_name, None)
70
- if extra_network is None:
71
- print(f"Skipping unknown extra network: {extra_network_name}")
72
- continue
73
-
74
- try:
75
- extra_network.activate(p, extra_network_args)
76
- except Exception as e:
77
- errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
78
-
79
- for extra_network_name, extra_network in extra_network_registry.items():
80
- args = extra_network_data.get(extra_network_name, None)
81
- if args is not None:
82
- continue
83
-
84
- try:
85
- extra_network.activate(p, [])
86
- except Exception as e:
87
- errors.display(e, f"activating extra network {extra_network_name}")
88
-
89
-
90
- def deactivate(p, extra_network_data):
91
- """call deactivate for extra networks in extra_network_data in specified order, then call
92
- deactivate for all remaining registered networks"""
93
-
94
- for extra_network_name, extra_network_args in extra_network_data.items():
95
- extra_network = extra_network_registry.get(extra_network_name, None)
96
- if extra_network is None:
97
- continue
98
-
99
- try:
100
- extra_network.deactivate(p)
101
- except Exception as e:
102
- errors.display(e, f"deactivating extra network {extra_network_name}")
103
-
104
- for extra_network_name, extra_network in extra_network_registry.items():
105
- args = extra_network_data.get(extra_network_name, None)
106
- if args is not None:
107
- continue
108
-
109
- try:
110
- extra_network.deactivate(p)
111
- except Exception as e:
112
- errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
113
-
114
-
115
- re_extra_net = re.compile(r"<(\w+):([^>]+)>")
116
-
117
-
118
- def parse_prompt(prompt):
119
- res = defaultdict(list)
120
-
121
- def found(m):
122
- name = m.group(1)
123
- args = m.group(2)
124
-
125
- res[name].append(ExtraNetworkParams(items=args.split(":")))
126
-
127
- return ""
128
-
129
- prompt = re.sub(re_extra_net, found, prompt)
130
-
131
- return prompt, res
132
-
133
-
134
- def parse_prompts(prompts):
135
- res = []
136
- extra_data = None
137
-
138
- for prompt in prompts:
139
- updated_prompt, parsed_extra_data = parse_prompt(prompt)
140
-
141
- if extra_data is None:
142
- extra_data = parsed_extra_data
143
-
144
- res.append(updated_prompt)
145
-
146
- return res, extra_data
147
-
 
1
+ import re
2
+ from collections import defaultdict
3
+
4
+ from modules import errors
5
+
6
+ extra_network_registry = {}
7
+
8
+
9
+ def initialize():
10
+ extra_network_registry.clear()
11
+
12
+
13
+ def register_extra_network(extra_network):
14
+ extra_network_registry[extra_network.name] = extra_network
15
+
16
+
17
+ class ExtraNetworkParams:
18
+ def __init__(self, items=None):
19
+ self.items = items or []
20
+
21
+
22
+ class ExtraNetwork:
23
+ def __init__(self, name):
24
+ self.name = name
25
+
26
+ def activate(self, p, params_list):
27
+ """
28
+ Called by processing on every run. Whatever the extra network is meant to do should be activated here.
29
+ Passes arguments related to this extra network in params_list.
30
+ The user passes arguments by specifying this in their prompt:
31
+
32
+ <name:arg1:arg2:arg3>
33
+
34
+ Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any number of text arguments
35
+ separated by colons.
36
+
37
+ Even if the user does not mention this ExtraNetwork in their prompt, the call will still be made, with an empty params_list -
38
+ in this case, all effects of this extra network should be disabled.
39
+
40
+ Can be called multiple times before deactivate() - each new call should override the previous call completely.
41
+
42
+ For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
43
+
44
+ > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
45
+
46
+ params_list will be:
47
+
48
+ [
49
+ ExtraNetworkParams(items=["agm", "1.1"]),
50
+ ExtraNetworkParams(items=["ray"])
51
+ ]
52
+
53
+ """
54
+ raise NotImplementedError
55
+
56
+ def deactivate(self, p):
57
+ """
58
+ Called at the end of processing for housekeeping. No need to do anything here.
59
+ """
60
+
61
+ raise NotImplementedError
62
+
63
+
64
+ def activate(p, extra_network_data):
65
+ """call activate for extra networks in extra_network_data in specified order, then call
66
+ activate for all remaining registered networks with an empty argument list"""
67
+
68
+ for extra_network_name, extra_network_args in extra_network_data.items():
69
+ extra_network = extra_network_registry.get(extra_network_name, None)
70
+ if extra_network is None:
71
+ print(f"Skipping unknown extra network: {extra_network_name}")
72
+ continue
73
+
74
+ try:
75
+ extra_network.activate(p, extra_network_args)
76
+ except Exception as e:
77
+ errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
78
+
79
+ for extra_network_name, extra_network in extra_network_registry.items():
80
+ args = extra_network_data.get(extra_network_name, None)
81
+ if args is not None:
82
+ continue
83
+
84
+ try:
85
+ extra_network.activate(p, [])
86
+ except Exception as e:
87
+ errors.display(e, f"activating extra network {extra_network_name}")
88
+
89
+
90
+ def deactivate(p, extra_network_data):
91
+ """call deactivate for extra networks in extra_network_data in specified order, then call
92
+ deactivate for all remaining registered networks"""
93
+
94
+ for extra_network_name, extra_network_args in extra_network_data.items():
95
+ extra_network = extra_network_registry.get(extra_network_name, None)
96
+ if extra_network is None:
97
+ continue
98
+
99
+ try:
100
+ extra_network.deactivate(p)
101
+ except Exception as e:
102
+ errors.display(e, f"deactivating extra network {extra_network_name}")
103
+
104
+ for extra_network_name, extra_network in extra_network_registry.items():
105
+ args = extra_network_data.get(extra_network_name, None)
106
+ if args is not None:
107
+ continue
108
+
109
+ try:
110
+ extra_network.deactivate(p)
111
+ except Exception as e:
112
+ errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
113
+
114
+
115
+ re_extra_net = re.compile(r"<(\w+):([^>]+)>")
116
+
117
+
118
+ def parse_prompt(prompt):
119
+ res = defaultdict(list)
120
+
121
+ def found(m):
122
+ name = m.group(1)
123
+ args = m.group(2)
124
+
125
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
126
+
127
+ return ""
128
+
129
+ prompt = re.sub(re_extra_net, found, prompt)
130
+
131
+ return prompt, res
132
+
133
+
134
+ def parse_prompts(prompts):
135
+ res = []
136
+ extra_data = None
137
+
138
+ for prompt in prompts:
139
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
140
+
141
+ if extra_data is None:
142
+ extra_data = parsed_extra_data
143
+
144
+ res.append(updated_prompt)
145
+
146
+ return res, extra_data
147
+
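
A quick standalone illustration of what the re_extra_net pattern above strips out of a prompt; the prompt text is an arbitrary example, and in the module each split item list is wrapped in ExtraNetworkParams rather than kept as a plain list:

import re
from collections import defaultdict

re_extra_net = re.compile(r"<(\w+):([^>]+)>")
found = defaultdict(list)

def strip_tag(m):
    found[m.group(1)].append(m.group(2).split(":"))
    return ""

prompt = re_extra_net.sub(strip_tag, "1girl, <hypernet:agm:1.1> <hypernet:ray>")
print(prompt)       # "1girl,  "
print(dict(found))  # {'hypernet': [['agm', '1.1'], ['ray']]}
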
sd/stable-diffusion-webui/modules/extra_networks_hypernet.py CHANGED
@@ -1,27 +1,27 @@
1
- from modules import extra_networks, shared, extra_networks
2
- from modules.hypernetworks import hypernetwork
3
-
4
-
5
- class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
6
- def __init__(self):
7
- super().__init__('hypernet')
8
-
9
- def activate(self, p, params_list):
10
- additional = shared.opts.sd_hypernetwork
11
-
12
- if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
13
- p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
14
- params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
15
-
16
- names = []
17
- multipliers = []
18
- for params in params_list:
19
- assert len(params.items) > 0
20
-
21
- names.append(params.items[0])
22
- multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
23
-
24
- hypernetwork.load_hypernetworks(names, multipliers)
25
-
26
- def deactivate(self, p):
27
- pass
 
1
+ from modules import extra_networks, shared
2
+ from modules.hypernetworks import hypernetwork
3
+
4
+
5
+ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
6
+ def __init__(self):
7
+ super().__init__('hypernet')
8
+
9
+ def activate(self, p, params_list):
10
+ additional = shared.opts.sd_hypernetwork
11
+
12
+ if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
13
+ p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
14
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
15
+
16
+ names = []
17
+ multipliers = []
18
+ for params in params_list:
19
+ assert len(params.items) > 0
20
+
21
+ names.append(params.items[0])
22
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
23
+
24
+ hypernetwork.load_hypernetworks(names, multipliers)
25
+
26
+ def deactivate(self, p):
27
+ pass
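
A tiny sketch of how activate() above turns the parsed items into hypernetwork names and multipliers, defaulting the multiplier to 1.0 when none is given; the example items mirror the docstring in extra_networks.py:

items_per_call = [["agm", "1.1"], ["ray"]]
names = [items[0] for items in items_per_call]
multipliers = [float(items[1]) if len(items) > 1 else 1.0 for items in items_per_call]
print(names, multipliers)  # ['agm', 'ray'] [1.1, 1.0]
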
sd/stable-diffusion-webui/modules/extras.py CHANGED
@@ -1,258 +1,258 @@
1
- import os
2
- import re
3
- import shutil
4
-
5
-
6
- import torch
7
- import tqdm
8
-
9
- from modules import shared, images, sd_models, sd_vae, sd_models_config
10
- from modules.ui_common import plaintext_to_html
11
- import gradio as gr
12
- import safetensors.torch
13
-
14
-
15
- def run_pnginfo(image):
16
- if image is None:
17
- return '', '', ''
18
-
19
- geninfo, items = images.read_info_from_image(image)
20
- items = {**{'parameters': geninfo}, **items}
21
-
22
- info = ''
23
- for key, text in items.items():
24
- info += f"""
25
- <div>
26
- <p><b>{plaintext_to_html(str(key))}</b></p>
27
- <p>{plaintext_to_html(str(text))}</p>
28
- </div>
29
- """.strip()+"\n"
30
-
31
- if len(info) == 0:
32
- message = "Nothing found in the image."
33
- info = f"<div><p>{message}<p></div>"
34
-
35
- return '', geninfo, info
36
-
37
-
38
- def create_config(ckpt_result, config_source, a, b, c):
39
- def config(x):
40
- res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
41
- return res if res != shared.sd_default_config else None
42
-
43
- if config_source == 0:
44
- cfg = config(a) or config(b) or config(c)
45
- elif config_source == 1:
46
- cfg = config(b)
47
- elif config_source == 2:
48
- cfg = config(c)
49
- else:
50
- cfg = None
51
-
52
- if cfg is None:
53
- return
54
-
55
- filename, _ = os.path.splitext(ckpt_result)
56
- checkpoint_filename = filename + ".yaml"
57
-
58
- print("Copying config:")
59
- print(" from:", cfg)
60
- print(" to:", checkpoint_filename)
61
- shutil.copyfile(cfg, checkpoint_filename)
62
-
63
-
64
- checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
65
-
66
-
67
- def to_half(tensor, enable):
68
- if enable and tensor.dtype == torch.float:
69
- return tensor.half()
70
-
71
- return tensor
72
-
73
-
74
- def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
75
- shared.state.begin()
76
- shared.state.job = 'model-merge'
77
-
78
- def fail(message):
79
- shared.state.textinfo = message
80
- shared.state.end()
81
- return [*[gr.update() for _ in range(4)], message]
82
-
83
- def weighted_sum(theta0, theta1, alpha):
84
- return ((1 - alpha) * theta0) + (alpha * theta1)
85
-
86
- def get_difference(theta1, theta2):
87
- return theta1 - theta2
88
-
89
- def add_difference(theta0, theta1_2_diff, alpha):
90
- return theta0 + (alpha * theta1_2_diff)
91
-
92
- def filename_weighted_sum():
93
- a = primary_model_info.model_name
94
- b = secondary_model_info.model_name
95
- Ma = round(1 - multiplier, 2)
96
- Mb = round(multiplier, 2)
97
-
98
- return f"{Ma}({a}) + {Mb}({b})"
99
-
100
- def filename_add_difference():
101
- a = primary_model_info.model_name
102
- b = secondary_model_info.model_name
103
- c = tertiary_model_info.model_name
104
- M = round(multiplier, 2)
105
-
106
- return f"{a} + {M}({b} - {c})"
107
-
108
- def filename_nothing():
109
- return primary_model_info.model_name
110
-
111
- theta_funcs = {
112
- "Weighted sum": (filename_weighted_sum, None, weighted_sum),
113
- "Add difference": (filename_add_difference, get_difference, add_difference),
114
- "No interpolation": (filename_nothing, None, None),
115
- }
116
- filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
117
- shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
118
-
119
- if not primary_model_name:
120
- return fail("Failed: Merging requires a primary model.")
121
-
122
- primary_model_info = sd_models.checkpoints_list[primary_model_name]
123
-
124
- if theta_func2 and not secondary_model_name:
125
- return fail("Failed: Merging requires a secondary model.")
126
-
127
- secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
128
-
129
- if theta_func1 and not tertiary_model_name:
130
- return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
131
-
132
- tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
133
-
134
- result_is_inpainting_model = False
135
- result_is_instruct_pix2pix_model = False
136
-
137
- if theta_func2:
138
- shared.state.textinfo = f"Loading B"
139
- print(f"Loading {secondary_model_info.filename}...")
140
- theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
141
- else:
142
- theta_1 = None
143
-
144
- if theta_func1:
145
- shared.state.textinfo = f"Loading C"
146
- print(f"Loading {tertiary_model_info.filename}...")
147
- theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
148
-
149
- shared.state.textinfo = 'Merging B and C'
150
- shared.state.sampling_steps = len(theta_1.keys())
151
- for key in tqdm.tqdm(theta_1.keys()):
152
- if key in checkpoint_dict_skip_on_merge:
153
- continue
154
-
155
- if 'model' in key:
156
- if key in theta_2:
157
- t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
158
- theta_1[key] = theta_func1(theta_1[key], t2)
159
- else:
160
- theta_1[key] = torch.zeros_like(theta_1[key])
161
-
162
- shared.state.sampling_step += 1
163
- del theta_2
164
-
165
- shared.state.nextjob()
166
-
167
- shared.state.textinfo = f"Loading {primary_model_info.filename}..."
168
- print(f"Loading {primary_model_info.filename}...")
169
- theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
170
-
171
- print("Merging...")
172
- shared.state.textinfo = 'Merging A and B'
173
- shared.state.sampling_steps = len(theta_0.keys())
174
- for key in tqdm.tqdm(theta_0.keys()):
175
- if theta_1 and 'model' in key and key in theta_1:
176
-
177
- if key in checkpoint_dict_skip_on_merge:
178
- continue
179
-
180
- a = theta_0[key]
181
- b = theta_1[key]
182
-
183
- # this enables merging an inpainting model (A) with another one (B);
184
- # where a normal model would have 4 channels for latent space, an inpainting model would
185
- # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
186
- if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
187
- if a.shape[1] == 4 and b.shape[1] == 9:
188
- raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
189
- if a.shape[1] == 4 and b.shape[1] == 8:
190
- raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
191
-
192
- if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
193
- theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
194
- result_is_instruct_pix2pix_model = True
195
- else:
196
- assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
197
- theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
198
- result_is_inpainting_model = True
199
- else:
200
- theta_0[key] = theta_func2(a, b, multiplier)
201
-
202
- theta_0[key] = to_half(theta_0[key], save_as_half)
203
-
204
- shared.state.sampling_step += 1
205
-
206
- del theta_1
207
-
208
- bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
209
- if bake_in_vae_filename is not None:
210
- print(f"Baking in VAE from {bake_in_vae_filename}")
211
- shared.state.textinfo = 'Baking in VAE'
212
- vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
213
-
214
- for key in vae_dict.keys():
215
- theta_0_key = 'first_stage_model.' + key
216
- if theta_0_key in theta_0:
217
- theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
218
-
219
- del vae_dict
220
-
221
- if save_as_half and not theta_func2:
222
- for key in theta_0.keys():
223
- theta_0[key] = to_half(theta_0[key], save_as_half)
224
-
225
- if discard_weights:
226
- regex = re.compile(discard_weights)
227
- for key in list(theta_0):
228
- if re.search(regex, key):
229
- theta_0.pop(key, None)
230
-
231
- ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
232
-
233
- filename = filename_generator() if custom_name == '' else custom_name
234
- filename += ".inpainting" if result_is_inpainting_model else ""
235
- filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
236
- filename += "." + checkpoint_format
237
-
238
- output_modelname = os.path.join(ckpt_dir, filename)
239
-
240
- shared.state.nextjob()
241
- shared.state.textinfo = "Saving"
242
- print(f"Saving to {output_modelname}...")
243
-
244
- _, extension = os.path.splitext(output_modelname)
245
- if extension.lower() == ".safetensors":
246
- safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
247
- else:
248
- torch.save(theta_0, output_modelname)
249
-
250
- sd_models.list_models()
251
-
252
- create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
253
-
254
- print(f"Checkpoint saved to {output_modelname}.")
255
- shared.state.textinfo = "Checkpoint saved"
256
- shared.state.end()
257
-
258
- return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
 
1
+ import os
2
+ import re
3
+ import shutil
4
+
5
+
6
+ import torch
7
+ import tqdm
8
+
9
+ from modules import shared, images, sd_models, sd_vae, sd_models_config
10
+ from modules.ui_common import plaintext_to_html
11
+ import gradio as gr
12
+ import safetensors.torch
13
+
14
+
15
+ def run_pnginfo(image):
16
+ if image is None:
17
+ return '', '', ''
18
+
19
+ geninfo, items = images.read_info_from_image(image)
20
+ items = {**{'parameters': geninfo}, **items}
21
+
22
+ info = ''
23
+ for key, text in items.items():
24
+ info += f"""
25
+ <div>
26
+ <p><b>{plaintext_to_html(str(key))}</b></p>
27
+ <p>{plaintext_to_html(str(text))}</p>
28
+ </div>
29
+ """.strip()+"\n"
30
+
31
+ if len(info) == 0:
32
+ message = "Nothing found in the image."
33
+ info = f"<div><p>{message}<p></div>"
34
+
35
+ return '', geninfo, info
36
+
37
+
38
+ def create_config(ckpt_result, config_source, a, b, c):
39
+ def config(x):
40
+ res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
41
+ return res if res != shared.sd_default_config else None
42
+
43
+ if config_source == 0:
44
+ cfg = config(a) or config(b) or config(c)
45
+ elif config_source == 1:
46
+ cfg = config(b)
47
+ elif config_source == 2:
48
+ cfg = config(c)
49
+ else:
50
+ cfg = None
51
+
52
+ if cfg is None:
53
+ return
54
+
55
+ filename, _ = os.path.splitext(ckpt_result)
56
+ checkpoint_filename = filename + ".yaml"
57
+
58
+ print("Copying config:")
59
+ print(" from:", cfg)
60
+ print(" to:", checkpoint_filename)
61
+ shutil.copyfile(cfg, checkpoint_filename)
62
+
63
+
64
+ checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
65
+
66
+
67
+ def to_half(tensor, enable):
68
+ if enable and tensor.dtype == torch.float:
69
+ return tensor.half()
70
+
71
+ return tensor
72
+
73
+
74
+ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
75
+ shared.state.begin()
76
+ shared.state.job = 'model-merge'
77
+
78
+ def fail(message):
79
+ shared.state.textinfo = message
80
+ shared.state.end()
81
+ return [*[gr.update() for _ in range(4)], message]
82
+
83
+ def weighted_sum(theta0, theta1, alpha):
84
+ return ((1 - alpha) * theta0) + (alpha * theta1)
85
+
86
+ def get_difference(theta1, theta2):
87
+ return theta1 - theta2
88
+
89
+ def add_difference(theta0, theta1_2_diff, alpha):
90
+ return theta0 + (alpha * theta1_2_diff)
91
+
92
+ def filename_weighted_sum():
93
+ a = primary_model_info.model_name
94
+ b = secondary_model_info.model_name
95
+ Ma = round(1 - multiplier, 2)
96
+ Mb = round(multiplier, 2)
97
+
98
+ return f"{Ma}({a}) + {Mb}({b})"
99
+
100
+ def filename_add_difference():
101
+ a = primary_model_info.model_name
102
+ b = secondary_model_info.model_name
103
+ c = tertiary_model_info.model_name
104
+ M = round(multiplier, 2)
105
+
106
+ return f"{a} + {M}({b} - {c})"
107
+
108
+ def filename_nothing():
109
+ return primary_model_info.model_name
110
+
111
+ theta_funcs = {
112
+ "Weighted sum": (filename_weighted_sum, None, weighted_sum),
113
+ "Add difference": (filename_add_difference, get_difference, add_difference),
114
+ "No interpolation": (filename_nothing, None, None),
115
+ }
116
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
117
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
118
+
119
+ if not primary_model_name:
120
+ return fail("Failed: Merging requires a primary model.")
121
+
122
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
123
+
124
+ if theta_func2 and not secondary_model_name:
125
+ return fail("Failed: Merging requires a secondary model.")
126
+
127
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
128
+
129
+ if theta_func1 and not tertiary_model_name:
130
+ return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
131
+
132
+ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
133
+
134
+ result_is_inpainting_model = False
135
+ result_is_instruct_pix2pix_model = False
136
+
137
+ if theta_func2:
138
+ shared.state.textinfo = f"Loading B"
139
+ print(f"Loading {secondary_model_info.filename}...")
140
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
141
+ else:
142
+ theta_1 = None
143
+
144
+ if theta_func1:
145
+ shared.state.textinfo = f"Loading C"
146
+ print(f"Loading {tertiary_model_info.filename}...")
147
+ theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
148
+
149
+ shared.state.textinfo = 'Merging B and C'
150
+ shared.state.sampling_steps = len(theta_1.keys())
151
+ for key in tqdm.tqdm(theta_1.keys()):
152
+ if key in checkpoint_dict_skip_on_merge:
153
+ continue
154
+
155
+ if 'model' in key:
156
+ if key in theta_2:
157
+ t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
158
+ theta_1[key] = theta_func1(theta_1[key], t2)
159
+ else:
160
+ theta_1[key] = torch.zeros_like(theta_1[key])
161
+
162
+ shared.state.sampling_step += 1
163
+ del theta_2
164
+
165
+ shared.state.nextjob()
166
+
167
+ shared.state.textinfo = f"Loading {primary_model_info.filename}..."
168
+ print(f"Loading {primary_model_info.filename}...")
169
+ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
170
+
171
+ print("Merging...")
172
+ shared.state.textinfo = 'Merging A and B'
173
+ shared.state.sampling_steps = len(theta_0.keys())
174
+ for key in tqdm.tqdm(theta_0.keys()):
175
+ if theta_1 and 'model' in key and key in theta_1:
176
+
177
+ if key in checkpoint_dict_skip_on_merge:
178
+ continue
179
+
180
+ a = theta_0[key]
181
+ b = theta_1[key]
182
+
183
+ # this enables merging an inpainting model (A) with another one (B);
184
+ # where a normal model would have 4 channels for latent space, an inpainting model would
185
+ # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
186
+ if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
187
+ if a.shape[1] == 4 and b.shape[1] == 9:
188
+ raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
189
+ if a.shape[1] == 4 and b.shape[1] == 8:
190
+ raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
191
+
192
+ if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
193
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
194
+ result_is_instruct_pix2pix_model = True
195
+ else:
196
+ assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
197
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
198
+ result_is_inpainting_model = True
199
+ else:
200
+ theta_0[key] = theta_func2(a, b, multiplier)
201
+
202
+ theta_0[key] = to_half(theta_0[key], save_as_half)
203
+
204
+ shared.state.sampling_step += 1
205
+
206
+ del theta_1
207
+
208
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
209
+ if bake_in_vae_filename is not None:
210
+ print(f"Baking in VAE from {bake_in_vae_filename}")
211
+ shared.state.textinfo = 'Baking in VAE'
212
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
213
+
214
+ for key in vae_dict.keys():
215
+ theta_0_key = 'first_stage_model.' + key
216
+ if theta_0_key in theta_0:
217
+ theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
218
+
219
+ del vae_dict
220
+
221
+ if save_as_half and not theta_func2:
222
+ for key in theta_0.keys():
223
+ theta_0[key] = to_half(theta_0[key], save_as_half)
224
+
225
+ if discard_weights:
226
+ regex = re.compile(discard_weights)
227
+ for key in list(theta_0):
228
+ if re.search(regex, key):
229
+ theta_0.pop(key, None)
230
+
231
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
232
+
233
+ filename = filename_generator() if custom_name == '' else custom_name
234
+ filename += ".inpainting" if result_is_inpainting_model else ""
235
+ filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
236
+ filename += "." + checkpoint_format
237
+
238
+ output_modelname = os.path.join(ckpt_dir, filename)
239
+
240
+ shared.state.nextjob()
241
+ shared.state.textinfo = "Saving"
242
+ print(f"Saving to {output_modelname}...")
243
+
244
+ _, extension = os.path.splitext(output_modelname)
245
+ if extension.lower() == ".safetensors":
246
+ safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
247
+ else:
248
+ torch.save(theta_0, output_modelname)
249
+
250
+ sd_models.list_models()
251
+
252
+ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
253
+
254
+ print(f"Checkpoint saved to {output_modelname}.")
255
+ shared.state.textinfo = "Checkpoint saved"
256
+ shared.state.end()
257
+
258
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
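
The heart of `run_modelmerger()` above is the per-tensor interpolation selected by `interp_method`: either a weighted sum of A and B, or A plus a scaled (B − C) difference. A small worked example on toy tensors (the tensor values are illustrative):

```python
# Toy-tensor sketch of the two interpolation rules used by run_modelmerger().
import torch

def weighted_sum(theta0, theta1, alpha):
    return ((1 - alpha) * theta0) + (alpha * theta1)

def add_difference(theta0, theta1_2_diff, alpha):
    return theta0 + (alpha * theta1_2_diff)

a = torch.tensor([1.0, 2.0])   # weight from model A
b = torch.tensor([3.0, 6.0])   # weight from model B
c = torch.tensor([2.0, 4.0])   # weight from model C

print(weighted_sum(a, b, 0.5))        # tensor([2., 4.])
print(add_difference(a, b - c, 0.5))  # tensor([1.5000, 3.0000])
```

With `multiplier = 0.5`, "Weighted sum" lands halfway between A and B, while "Add difference" moves A by half of whatever separates B from C.
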
sd/stable-diffusion-webui/modules/face_restoration.py CHANGED
@@ -1,19 +1,19 @@
1
- from modules import shared
2
-
3
-
4
- class FaceRestoration:
5
- def name(self):
6
- return "None"
7
-
8
- def restore(self, np_image):
9
- return np_image
10
-
11
-
12
- def restore_faces(np_image):
13
- face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
14
- if len(face_restorers) == 0:
15
- return np_image
16
-
17
- face_restorer = face_restorers[0]
18
-
19
- return face_restorer.restore(np_image)
 
1
+ from modules import shared
2
+
3
+
4
+ class FaceRestoration:
5
+ def name(self):
6
+ return "None"
7
+
8
+ def restore(self, np_image):
9
+ return np_image
10
+
11
+
12
+ def restore_faces(np_image):
13
+ face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
14
+ if len(face_restorers) == 0:
15
+ return np_image
16
+
17
+ face_restorer = face_restorers[0]
18
+
19
+ return face_restorer.restore(np_image)
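
`face_restoration.py` defines a no-op base class plus a selector that picks the first registered restorer whose `name()` matches the configured model. A self-contained sketch of the same pattern without the `shared.opts` dependency (the `Brighten` restorer and `registry` list are made up for illustration):

```python
# Sketch of the restorer-selection pattern: subclass the base class,
# register an instance, and select it by name at call time.
import numpy as np

class FaceRestoration:
    def name(self):
        return "None"

    def restore(self, np_image):
        return np_image

class Brighten(FaceRestoration):
    def name(self):
        return "Brighten"

    def restore(self, np_image):
        return np.clip(np_image.astype(np.int16) + 10, 0, 255).astype(np.uint8)

registry = [FaceRestoration(), Brighten()]

def restore_faces(np_image, wanted="Brighten"):
    matches = [r for r in registry if r.name() == wanted]
    return matches[0].restore(np_image) if matches else np_image

print(restore_faces(np.zeros((2, 2, 3), dtype=np.uint8)).max())  # 10
```
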
sd/stable-diffusion-webui/modules/generation_parameters_copypaste.py CHANGED
@@ -1,402 +1,408 @@
1
- import base64
2
- import html
3
- import io
4
- import math
5
- import os
6
- import re
7
- from pathlib import Path
8
-
9
- import gradio as gr
10
- from modules.paths import data_path
11
- from modules import shared, ui_tempdir, script_callbacks
12
- import tempfile
13
- from PIL import Image
14
-
15
- re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
16
- re_param = re.compile(re_param_code)
17
- re_imagesize = re.compile(r"^(\d+)x(\d+)$")
18
- re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
19
- type_of_gr_update = type(gr.update())
20
-
21
- paste_fields = {}
22
- registered_param_bindings = []
23
-
24
-
25
- class ParamBinding:
26
- def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
27
- self.paste_button = paste_button
28
- self.tabname = tabname
29
- self.source_text_component = source_text_component
30
- self.source_image_component = source_image_component
31
- self.source_tabname = source_tabname
32
- self.override_settings_component = override_settings_component
33
-
34
-
35
- def reset():
36
- paste_fields.clear()
37
-
38
-
39
- def quote(text):
40
- if ',' not in str(text):
41
- return text
42
-
43
- text = str(text)
44
- text = text.replace('\\', '\\\\')
45
- text = text.replace('"', '\\"')
46
- return f'"{text}"'
47
-
48
-
49
- def image_from_url_text(filedata):
50
- if filedata is None:
51
- return None
52
-
53
- if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
54
- filedata = filedata[0]
55
-
56
- if type(filedata) == dict and filedata.get("is_file", False):
57
- filename = filedata["name"]
58
- is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
59
- assert is_in_right_dir, 'trying to open image file outside of allowed directories'
60
-
61
- return Image.open(filename)
62
-
63
- if type(filedata) == list:
64
- if len(filedata) == 0:
65
- return None
66
-
67
- filedata = filedata[0]
68
-
69
- if filedata.startswith("data:image/png;base64,"):
70
- filedata = filedata[len("data:image/png;base64,"):]
71
-
72
- filedata = base64.decodebytes(filedata.encode('utf-8'))
73
- image = Image.open(io.BytesIO(filedata))
74
- return image
75
-
76
-
77
- def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
78
- paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
79
-
80
- # backwards compatibility for existing extensions
81
- import modules.ui
82
- if tabname == 'txt2img':
83
- modules.ui.txt2img_paste_fields = fields
84
- elif tabname == 'img2img':
85
- modules.ui.img2img_paste_fields = fields
86
-
87
-
88
- def create_buttons(tabs_list):
89
- buttons = {}
90
- for tab in tabs_list:
91
- buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
92
- return buttons
93
-
94
-
95
- def bind_buttons(buttons, send_image, send_generate_info):
96
- """old function for backwards compatibility; do not use this, use register_paste_params_button"""
97
- for tabname, button in buttons.items():
98
- source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
99
- source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
100
-
101
- register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
102
-
103
-
104
- def register_paste_params_button(binding: ParamBinding):
105
- registered_param_bindings.append(binding)
106
-
107
-
108
- def connect_paste_params_buttons():
109
- binding: ParamBinding
110
- for binding in registered_param_bindings:
111
- destination_image_component = paste_fields[binding.tabname]["init_img"]
112
- fields = paste_fields[binding.tabname]["fields"]
113
- override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
114
-
115
- destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
116
- destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
117
-
118
- if binding.source_image_component and destination_image_component:
119
- if isinstance(binding.source_image_component, gr.Gallery):
120
- func = send_image_and_dimensions if destination_width_component else image_from_url_text
121
- jsfunc = "extract_image_from_gallery"
122
- else:
123
- func = send_image_and_dimensions if destination_width_component else lambda x: x
124
- jsfunc = None
125
-
126
- binding.paste_button.click(
127
- fn=func,
128
- _js=jsfunc,
129
- inputs=[binding.source_image_component],
130
- outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
131
- )
132
-
133
- if binding.source_text_component is not None and fields is not None:
134
- connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
135
-
136
- if binding.source_tabname is not None and fields is not None:
137
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
138
- binding.paste_button.click(
139
- fn=lambda *x: x,
140
- inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
141
- outputs=[field for field, name in fields if name in paste_field_names],
142
- )
143
-
144
- binding.paste_button.click(
145
- fn=None,
146
- _js=f"switch_to_{binding.tabname}",
147
- inputs=None,
148
- outputs=None,
149
- )
150
-
151
-
152
- def send_image_and_dimensions(x):
153
- if isinstance(x, Image.Image):
154
- img = x
155
- else:
156
- img = image_from_url_text(x)
157
-
158
- if shared.opts.send_size and isinstance(img, Image.Image):
159
- w = img.width
160
- h = img.height
161
- else:
162
- w = gr.update()
163
- h = gr.update()
164
-
165
- return img, w, h
166
-
167
-
168
-
169
- def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
170
- """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
171
-
172
- Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
173
- parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
174
-
175
- If the infotext has no hash, then a hypernet with the same name will be selected instead.
176
- """
177
- hypernet_name = hypernet_name.lower()
178
- if hypernet_hash is not None:
179
- # Try to match the hash in the name
180
- for hypernet_key in shared.hypernetworks.keys():
181
- result = re_hypernet_hash.search(hypernet_key)
182
- if result is not None and result[1] == hypernet_hash:
183
- return hypernet_key
184
- else:
185
- # Fall back to a hypernet with the same name
186
- for hypernet_key in shared.hypernetworks.keys():
187
- if hypernet_key.lower().startswith(hypernet_name):
188
- return hypernet_key
189
-
190
- return None
191
-
192
-
193
- def restore_old_hires_fix_params(res):
194
- """for infotexts that specify old First pass size parameter, convert it into
195
- width, height, and hr scale"""
196
-
197
- firstpass_width = res.get('First pass size-1', None)
198
- firstpass_height = res.get('First pass size-2', None)
199
-
200
- if shared.opts.use_old_hires_fix_width_height:
201
- hires_width = int(res.get("Hires resize-1", 0))
202
- hires_height = int(res.get("Hires resize-2", 0))
203
-
204
- if hires_width and hires_height:
205
- res['Size-1'] = hires_width
206
- res['Size-2'] = hires_height
207
- return
208
-
209
- if firstpass_width is None or firstpass_height is None:
210
- return
211
-
212
- firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
213
- width = int(res.get("Size-1", 512))
214
- height = int(res.get("Size-2", 512))
215
-
216
- if firstpass_width == 0 or firstpass_height == 0:
217
- from modules import processing
218
- firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
219
-
220
- res['Size-1'] = firstpass_width
221
- res['Size-2'] = firstpass_height
222
- res['Hires resize-1'] = width
223
- res['Hires resize-2'] = height
224
-
225
-
226
- def parse_generation_parameters(x: str):
227
- """parses generation parameters string, the one you see in text field under the picture in UI:
228
- ```
229
- girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
230
- Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
231
- Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
232
- ```
233
-
234
- returns a dict with field values
235
- """
236
-
237
- res = {}
238
-
239
- prompt = ""
240
- negative_prompt = ""
241
-
242
- done_with_prompt = False
243
-
244
- *lines, lastline = x.strip().split("\n")
245
- if len(re_param.findall(lastline)) < 3:
246
- lines.append(lastline)
247
- lastline = ''
248
-
249
- for i, line in enumerate(lines):
250
- line = line.strip()
251
- if line.startswith("Negative prompt:"):
252
- done_with_prompt = True
253
- line = line[16:].strip()
254
-
255
- if done_with_prompt:
256
- negative_prompt += ("" if negative_prompt == "" else "\n") + line
257
- else:
258
- prompt += ("" if prompt == "" else "\n") + line
259
-
260
- res["Prompt"] = prompt
261
- res["Negative prompt"] = negative_prompt
262
-
263
- for k, v in re_param.findall(lastline):
264
- v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
265
- m = re_imagesize.match(v)
266
- if m is not None:
267
- res[k+"-1"] = m.group(1)
268
- res[k+"-2"] = m.group(2)
269
- else:
270
- res[k] = v
271
-
272
- # Missing CLIP skip means it was set to 1 (the default)
273
- if "Clip skip" not in res:
274
- res["Clip skip"] = "1"
275
-
276
- hypernet = res.get("Hypernet", None)
277
- if hypernet is not None:
278
- res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
279
-
280
- if "Hires resize-1" not in res:
281
- res["Hires resize-1"] = 0
282
- res["Hires resize-2"] = 0
283
-
284
- restore_old_hires_fix_params(res)
285
-
286
- return res
287
-
288
-
289
- settings_map = {}
290
-
291
- infotext_to_setting_name_mapping = [
292
- ('Clip skip', 'CLIP_stop_at_last_layers', ),
293
- ('Conditional mask weight', 'inpainting_mask_weight'),
294
- ('Model hash', 'sd_model_checkpoint'),
295
- ('ENSD', 'eta_noise_seed_delta'),
296
- ('Noise multiplier', 'initial_noise_multiplier'),
297
- ('Eta', 'eta_ancestral'),
298
- ('Eta DDIM', 'eta_ddim'),
299
- ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
300
- ]
301
-
302
-
303
- def create_override_settings_dict(text_pairs):
304
- """creates processing's override_settings parameters from gradio's multiselect
305
-
306
- Example input:
307
- ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
308
-
309
- Example output:
310
- {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
311
- """
312
-
313
- res = {}
314
-
315
- params = {}
316
- for pair in text_pairs:
317
- k, v = pair.split(":", maxsplit=1)
318
-
319
- params[k] = v.strip()
320
-
321
- for param_name, setting_name in infotext_to_setting_name_mapping:
322
- value = params.get(param_name, None)
323
-
324
- if value is None:
325
- continue
326
-
327
- res[setting_name] = shared.opts.cast_value(setting_name, value)
328
-
329
- return res
330
-
331
-
332
- def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
333
- def paste_func(prompt):
334
- if not prompt and not shared.cmd_opts.hide_ui_dir_config:
335
- filename = os.path.join(data_path, "params.txt")
336
- if os.path.exists(filename):
337
- with open(filename, "r", encoding="utf8") as file:
338
- prompt = file.read()
339
-
340
- params = parse_generation_parameters(prompt)
341
- script_callbacks.infotext_pasted_callback(prompt, params)
342
- res = []
343
-
344
- for output, key in paste_fields:
345
- if callable(key):
346
- v = key(params)
347
- else:
348
- v = params.get(key, None)
349
-
350
- if v is None:
351
- res.append(gr.update())
352
- elif isinstance(v, type_of_gr_update):
353
- res.append(v)
354
- else:
355
- try:
356
- valtype = type(output.value)
357
-
358
- if valtype == bool and v == "False":
359
- val = False
360
- else:
361
- val = valtype(v)
362
-
363
- res.append(gr.update(value=val))
364
- except Exception:
365
- res.append(gr.update())
366
-
367
- return res
368
-
369
- if override_settings_component is not None:
370
- def paste_settings(params):
371
- vals = {}
372
-
373
- for param_name, setting_name in infotext_to_setting_name_mapping:
374
- v = params.get(param_name, None)
375
- if v is None:
376
- continue
377
-
378
- if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
379
- continue
380
-
381
- v = shared.opts.cast_value(setting_name, v)
382
- current_value = getattr(shared.opts, setting_name, None)
383
-
384
- if v == current_value:
385
- continue
386
-
387
- vals[param_name] = v
388
-
389
- vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
390
-
391
- return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
392
-
393
- paste_fields = paste_fields + [(override_settings_component, paste_settings)]
394
-
395
- button.click(
396
- fn=paste_func,
397
- _js=f"recalculate_prompts_{tabname}",
398
- inputs=[input_comp],
399
- outputs=[x[0] for x in paste_fields],
400
- )
401
-
402
-
 
 
 
 
 
 
 
1
+ import base64
2
+ import html
3
+ import io
4
+ import math
5
+ import os
6
+ import re
7
+ from pathlib import Path
8
+
9
+ import gradio as gr
10
+ from modules.paths import data_path
11
+ from modules import shared, ui_tempdir, script_callbacks
12
+ import tempfile
13
+ from PIL import Image
14
+
15
+ re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
16
+ re_param = re.compile(re_param_code)
17
+ re_imagesize = re.compile(r"^(\d+)x(\d+)$")
18
+ re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
19
+ type_of_gr_update = type(gr.update())
20
+
21
+ paste_fields = {}
22
+ registered_param_bindings = []
23
+
24
+
25
+ class ParamBinding:
26
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
27
+ self.paste_button = paste_button
28
+ self.tabname = tabname
29
+ self.source_text_component = source_text_component
30
+ self.source_image_component = source_image_component
31
+ self.source_tabname = source_tabname
32
+ self.override_settings_component = override_settings_component
33
+
34
+
35
+ def reset():
36
+ paste_fields.clear()
37
+
38
+
39
+ def quote(text):
40
+ if ',' not in str(text):
41
+ return text
42
+
43
+ text = str(text)
44
+ text = text.replace('\\', '\\\\')
45
+ text = text.replace('"', '\\"')
46
+ return f'"{text}"'
47
+
48
+
49
+ def image_from_url_text(filedata):
50
+ if filedata is None:
51
+ return None
52
+
53
+ if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
54
+ filedata = filedata[0]
55
+
56
+ if type(filedata) == dict and filedata.get("is_file", False):
57
+ filename = filedata["name"]
58
+ is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
59
+ assert is_in_right_dir, 'trying to open image file outside of allowed directories'
60
+
61
+ return Image.open(filename)
62
+
63
+ if type(filedata) == list:
64
+ if len(filedata) == 0:
65
+ return None
66
+
67
+ filedata = filedata[0]
68
+
69
+ if filedata.startswith("data:image/png;base64,"):
70
+ filedata = filedata[len("data:image/png;base64,"):]
71
+
72
+ filedata = base64.decodebytes(filedata.encode('utf-8'))
73
+ image = Image.open(io.BytesIO(filedata))
74
+ return image
75
+
76
+
77
+ def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
78
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
79
+
80
+ # backwards compatibility for existing extensions
81
+ import modules.ui
82
+ if tabname == 'txt2img':
83
+ modules.ui.txt2img_paste_fields = fields
84
+ elif tabname == 'img2img':
85
+ modules.ui.img2img_paste_fields = fields
86
+
87
+
88
+ def create_buttons(tabs_list):
89
+ buttons = {}
90
+ for tab in tabs_list:
91
+ buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
92
+ return buttons
93
+
94
+
95
+ def bind_buttons(buttons, send_image, send_generate_info):
96
+ """old function for backwards compatibility; do not use this, use register_paste_params_button"""
97
+ for tabname, button in buttons.items():
98
+ source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
99
+ source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
100
+
101
+ register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
102
+
103
+
104
+ def register_paste_params_button(binding: ParamBinding):
105
+ registered_param_bindings.append(binding)
106
+
107
+
108
+ def connect_paste_params_buttons():
109
+ binding: ParamBinding
110
+ for binding in registered_param_bindings:
111
+ destination_image_component = paste_fields[binding.tabname]["init_img"]
112
+ fields = paste_fields[binding.tabname]["fields"]
113
+ override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
114
+
115
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
116
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
117
+
118
+ if binding.source_image_component and destination_image_component:
119
+ if isinstance(binding.source_image_component, gr.Gallery):
120
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
121
+ jsfunc = "extract_image_from_gallery"
122
+ else:
123
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
124
+ jsfunc = None
125
+
126
+ binding.paste_button.click(
127
+ fn=func,
128
+ _js=jsfunc,
129
+ inputs=[binding.source_image_component],
130
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
131
+ )
132
+
133
+ if binding.source_text_component is not None and fields is not None:
134
+ connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
135
+
136
+ if binding.source_tabname is not None and fields is not None:
137
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
138
+ binding.paste_button.click(
139
+ fn=lambda *x: x,
140
+ inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
141
+ outputs=[field for field, name in fields if name in paste_field_names],
142
+ )
143
+
144
+ binding.paste_button.click(
145
+ fn=None,
146
+ _js=f"switch_to_{binding.tabname}",
147
+ inputs=None,
148
+ outputs=None,
149
+ )
150
+
151
+
152
+ def send_image_and_dimensions(x):
153
+ if isinstance(x, Image.Image):
154
+ img = x
155
+ else:
156
+ img = image_from_url_text(x)
157
+
158
+ if shared.opts.send_size and isinstance(img, Image.Image):
159
+ w = img.width
160
+ h = img.height
161
+ else:
162
+ w = gr.update()
163
+ h = gr.update()
164
+
165
+ return img, w, h
166
+
167
+
168
+
169
+ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
170
+ """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
171
+
172
+ Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
173
+ parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
174
+
175
+ If the infotext has no hash, then a hypernet with the same name will be selected instead.
176
+ """
177
+ hypernet_name = hypernet_name.lower()
178
+ if hypernet_hash is not None:
179
+ # Try to match the hash in the name
180
+ for hypernet_key in shared.hypernetworks.keys():
181
+ result = re_hypernet_hash.search(hypernet_key)
182
+ if result is not None and result[1] == hypernet_hash:
183
+ return hypernet_key
184
+ else:
185
+ # Fall back to a hypernet with the same name
186
+ for hypernet_key in shared.hypernetworks.keys():
187
+ if hypernet_key.lower().startswith(hypernet_name):
188
+ return hypernet_key
189
+
190
+ return None
191
+
192
+
193
+ def restore_old_hires_fix_params(res):
194
+ """for infotexts that specify old First pass size parameter, convert it into
195
+ width, height, and hr scale"""
196
+
197
+ firstpass_width = res.get('First pass size-1', None)
198
+ firstpass_height = res.get('First pass size-2', None)
199
+
200
+ if shared.opts.use_old_hires_fix_width_height:
201
+ hires_width = int(res.get("Hires resize-1", 0))
202
+ hires_height = int(res.get("Hires resize-2", 0))
203
+
204
+ if hires_width and hires_height:
205
+ res['Size-1'] = hires_width
206
+ res['Size-2'] = hires_height
207
+ return
208
+
209
+ if firstpass_width is None or firstpass_height is None:
210
+ return
211
+
212
+ firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
213
+ width = int(res.get("Size-1", 512))
214
+ height = int(res.get("Size-2", 512))
215
+
216
+ if firstpass_width == 0 or firstpass_height == 0:
217
+ from modules import processing
218
+ firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
219
+
220
+ res['Size-1'] = firstpass_width
221
+ res['Size-2'] = firstpass_height
222
+ res['Hires resize-1'] = width
223
+ res['Hires resize-2'] = height
224
+
225
+
226
+ def parse_generation_parameters(x: str):
227
+ """parses generation parameters string, the one you see in text field under the picture in UI:
228
+ ```
229
+ girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
230
+ Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
231
+ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
232
+ ```
233
+
234
+ returns a dict with field values
235
+ """
236
+
237
+ res = {}
238
+
239
+ prompt = ""
240
+ negative_prompt = ""
241
+
242
+ done_with_prompt = False
243
+
244
+ *lines, lastline = x.strip().split("\n")
245
+ if len(re_param.findall(lastline)) < 3:
246
+ lines.append(lastline)
247
+ lastline = ''
248
+
249
+ for i, line in enumerate(lines):
250
+ line = line.strip()
251
+ if line.startswith("Negative prompt:"):
252
+ done_with_prompt = True
253
+ line = line[16:].strip()
254
+
255
+ if done_with_prompt:
256
+ negative_prompt += ("" if negative_prompt == "" else "\n") + line
257
+ else:
258
+ prompt += ("" if prompt == "" else "\n") + line
259
+
260
+ res["Prompt"] = prompt
261
+ res["Negative prompt"] = negative_prompt
262
+
263
+ for k, v in re_param.findall(lastline):
264
+ v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
265
+ m = re_imagesize.match(v)
266
+ if m is not None:
267
+ res[k+"-1"] = m.group(1)
268
+ res[k+"-2"] = m.group(2)
269
+ else:
270
+ res[k] = v
271
+
272
+ # Missing CLIP skip means it was set to 1 (the default)
273
+ if "Clip skip" not in res:
274
+ res["Clip skip"] = "1"
275
+
276
+ hypernet = res.get("Hypernet", None)
277
+ if hypernet is not None:
278
+ res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
279
+
280
+ if "Hires resize-1" not in res:
281
+ res["Hires resize-1"] = 0
282
+ res["Hires resize-2"] = 0
283
+
284
+ restore_old_hires_fix_params(res)
285
+
286
+ return res
287
+
288
+
289
+ settings_map = {}
290
+
291
+
292
+
293
+ infotext_to_setting_name_mapping = [
294
+ ('Clip skip', 'CLIP_stop_at_last_layers', ),
295
+ ('Conditional mask weight', 'inpainting_mask_weight'),
296
+ ('Model hash', 'sd_model_checkpoint'),
297
+ ('ENSD', 'eta_noise_seed_delta'),
298
+ ('Noise multiplier', 'initial_noise_multiplier'),
299
+ ('Eta', 'eta_ancestral'),
300
+ ('Eta DDIM', 'eta_ddim'),
301
+ ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
302
+ ('UniPC variant', 'uni_pc_variant'),
303
+ ('UniPC skip type', 'uni_pc_skip_type'),
304
+ ('UniPC order', 'uni_pc_order'),
305
+ ('UniPC lower order final', 'uni_pc_lower_order_final'),
306
+ ]
307
+
308
+
309
+ def create_override_settings_dict(text_pairs):
310
+ """creates processing's override_settings parameters from gradio's multiselect
311
+
312
+ Example input:
313
+ ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
314
+
315
+ Example output:
316
+ {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
317
+ """
318
+
319
+ res = {}
320
+
321
+ params = {}
322
+ for pair in text_pairs:
323
+ k, v = pair.split(":", maxsplit=1)
324
+
325
+ params[k] = v.strip()
326
+
327
+ for param_name, setting_name in infotext_to_setting_name_mapping:
328
+ value = params.get(param_name, None)
329
+
330
+ if value is None:
331
+ continue
332
+
333
+ res[setting_name] = shared.opts.cast_value(setting_name, value)
334
+
335
+ return res
336
+
337
+
338
+ def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
339
+ def paste_func(prompt):
340
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config:
341
+ filename = os.path.join(data_path, "params.txt")
342
+ if os.path.exists(filename):
343
+ with open(filename, "r", encoding="utf8") as file:
344
+ prompt = file.read()
345
+
346
+ params = parse_generation_parameters(prompt)
347
+ script_callbacks.infotext_pasted_callback(prompt, params)
348
+ res = []
349
+
350
+ for output, key in paste_fields:
351
+ if callable(key):
352
+ v = key(params)
353
+ else:
354
+ v = params.get(key, None)
355
+
356
+ if v is None:
357
+ res.append(gr.update())
358
+ elif isinstance(v, type_of_gr_update):
359
+ res.append(v)
360
+ else:
361
+ try:
362
+ valtype = type(output.value)
363
+
364
+ if valtype == bool and v == "False":
365
+ val = False
366
+ else:
367
+ val = valtype(v)
368
+
369
+ res.append(gr.update(value=val))
370
+ except Exception:
371
+ res.append(gr.update())
372
+
373
+ return res
374
+
375
+ if override_settings_component is not None:
376
+ def paste_settings(params):
377
+ vals = {}
378
+
379
+ for param_name, setting_name in infotext_to_setting_name_mapping:
380
+ v = params.get(param_name, None)
381
+ if v is None:
382
+ continue
383
+
384
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
385
+ continue
386
+
387
+ v = shared.opts.cast_value(setting_name, v)
388
+ current_value = getattr(shared.opts, setting_name, None)
389
+
390
+ if v == current_value:
391
+ continue
392
+
393
+ vals[param_name] = v
394
+
395
+ vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
396
+
397
+ return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
398
+
399
+ paste_fields = paste_fields + [(override_settings_component, paste_settings)]
400
+
401
+ button.click(
402
+ fn=paste_func,
403
+ _js=f"recalculate_prompts_{tabname}",
404
+ inputs=[input_comp],
405
+ outputs=[x[0] for x in paste_fields],
406
+ )
407
+
408
+
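
The only substantive change in this file is the four UniPC entries ('UniPC variant', 'UniPC skip type', 'UniPC order', 'UniPC lower order final') appended to `infotext_to_setting_name_mapping`; the rest of the module is carried over unchanged. Since `parse_generation_parameters()` is what most extensions touch, here is a quick demonstration of its key/value grammar using the same two regexes copied from the module (the sample infotext line is illustrative):

```python
# Demonstrates how the last line of an infotext is split into fields,
# including the Size -> Size-1 / Size-2 expansion.
import re

re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)')
re_imagesize = re.compile(r"^(\d+)x(\d+)$")

lastline = "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b"

res = {}
for k, v in re_param.findall(lastline):
    m = re_imagesize.match(v)
    if m is not None:
        res[k + "-1"], res[k + "-2"] = m.group(1), m.group(2)
    else:
        res[k] = v

print(res["Steps"], res["Size-1"], res["Size-2"], res["Model hash"])
# 20 512 512 45dee52b
```

Size-style values are split into `-1`/`-2` pairs, which is how `Size: 512x512` becomes the separate width and height fields that the paste buttons fill in.
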
sd/stable-diffusion-webui/modules/gfpgan_model.py CHANGED
@@ -1,116 +1,116 @@
1
- import os
2
- import sys
3
- import traceback
4
-
5
- import facexlib
6
- import gfpgan
7
-
8
- import modules.face_restoration
9
- from modules import paths, shared, devices, modelloader
10
-
11
- model_dir = "GFPGAN"
12
- user_path = None
13
- model_path = os.path.join(paths.models_path, model_dir)
14
- model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
15
- have_gfpgan = False
16
- loaded_gfpgan_model = None
17
-
18
-
19
- def gfpgann():
20
- global loaded_gfpgan_model
21
- global model_path
22
- if loaded_gfpgan_model is not None:
23
- loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
24
- return loaded_gfpgan_model
25
-
26
- if gfpgan_constructor is None:
27
- return None
28
-
29
- models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
30
- if len(models) == 1 and "http" in models[0]:
31
- model_file = models[0]
32
- elif len(models) != 0:
33
- latest_file = max(models, key=os.path.getctime)
34
- model_file = latest_file
35
- else:
36
- print("Unable to load gfpgan model!")
37
- return None
38
- if hasattr(facexlib.detection.retinaface, 'device'):
39
- facexlib.detection.retinaface.device = devices.device_gfpgan
40
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
41
- loaded_gfpgan_model = model
42
-
43
- return model
44
-
45
-
46
- def send_model_to(model, device):
47
- model.gfpgan.to(device)
48
- model.face_helper.face_det.to(device)
49
- model.face_helper.face_parse.to(device)
50
-
51
-
52
- def gfpgan_fix_faces(np_image):
53
- model = gfpgann()
54
- if model is None:
55
- return np_image
56
-
57
- send_model_to(model, devices.device_gfpgan)
58
-
59
- np_image_bgr = np_image[:, :, ::-1]
60
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
61
- np_image = gfpgan_output_bgr[:, :, ::-1]
62
-
63
- model.face_helper.clean_all()
64
-
65
- if shared.opts.face_restoration_unload:
66
- send_model_to(model, devices.cpu)
67
-
68
- return np_image
69
-
70
-
71
- gfpgan_constructor = None
72
-
73
-
74
- def setup_model(dirname):
75
- global model_path
76
- if not os.path.exists(model_path):
77
- os.makedirs(model_path)
78
-
79
- try:
80
- from gfpgan import GFPGANer
81
- from facexlib import detection, parsing
82
- global user_path
83
- global have_gfpgan
84
- global gfpgan_constructor
85
-
86
- load_file_from_url_orig = gfpgan.utils.load_file_from_url
87
- facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
88
- facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
89
-
90
- def my_load_file_from_url(**kwargs):
91
- return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
92
-
93
- def facex_load_file_from_url(**kwargs):
94
- return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
95
-
96
- def facex_load_file_from_url2(**kwargs):
97
- return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
98
-
99
- gfpgan.utils.load_file_from_url = my_load_file_from_url
100
- facexlib.detection.load_file_from_url = facex_load_file_from_url
101
- facexlib.parsing.load_file_from_url = facex_load_file_from_url2
102
- user_path = dirname
103
- have_gfpgan = True
104
- gfpgan_constructor = GFPGANer
105
-
106
- class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
107
- def name(self):
108
- return "GFPGAN"
109
-
110
- def restore(self, np_image):
111
- return gfpgan_fix_faces(np_image)
112
-
113
- shared.face_restorers.append(FaceRestorerGFPGAN())
114
- except Exception:
115
- print("Error setting up GFPGAN:", file=sys.stderr)
116
- print(traceback.format_exc(), file=sys.stderr)
 
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import facexlib
6
+ import gfpgan
7
+
8
+ import modules.face_restoration
9
+ from modules import paths, shared, devices, modelloader
10
+
11
+ model_dir = "GFPGAN"
12
+ user_path = None
13
+ model_path = os.path.join(paths.models_path, model_dir)
14
+ model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
15
+ have_gfpgan = False
16
+ loaded_gfpgan_model = None
17
+
18
+
19
+ def gfpgann():
20
+ global loaded_gfpgan_model
21
+ global model_path
22
+ if loaded_gfpgan_model is not None:
23
+ loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
24
+ return loaded_gfpgan_model
25
+
26
+ if gfpgan_constructor is None:
27
+ return None
28
+
29
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
30
+ if len(models) == 1 and "http" in models[0]:
31
+ model_file = models[0]
32
+ elif len(models) != 0:
33
+ latest_file = max(models, key=os.path.getctime)
34
+ model_file = latest_file
35
+ else:
36
+ print("Unable to load gfpgan model!")
37
+ return None
38
+ if hasattr(facexlib.detection.retinaface, 'device'):
39
+ facexlib.detection.retinaface.device = devices.device_gfpgan
40
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
41
+ loaded_gfpgan_model = model
42
+
43
+ return model
44
+
45
+
46
+ def send_model_to(model, device):
47
+ model.gfpgan.to(device)
48
+ model.face_helper.face_det.to(device)
49
+ model.face_helper.face_parse.to(device)
50
+
51
+
52
+ def gfpgan_fix_faces(np_image):
53
+ model = gfpgann()
54
+ if model is None:
55
+ return np_image
56
+
57
+ send_model_to(model, devices.device_gfpgan)
58
+
59
+ np_image_bgr = np_image[:, :, ::-1]
60
+ cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
61
+ np_image = gfpgan_output_bgr[:, :, ::-1]
62
+
63
+ model.face_helper.clean_all()
64
+
65
+ if shared.opts.face_restoration_unload:
66
+ send_model_to(model, devices.cpu)
67
+
68
+ return np_image
69
+
70
+
71
+ gfpgan_constructor = None
72
+
73
+
74
+ def setup_model(dirname):
75
+ global model_path
76
+ if not os.path.exists(model_path):
77
+ os.makedirs(model_path)
78
+
79
+ try:
80
+ from gfpgan import GFPGANer
81
+ from facexlib import detection, parsing
82
+ global user_path
83
+ global have_gfpgan
84
+ global gfpgan_constructor
85
+
86
+ load_file_from_url_orig = gfpgan.utils.load_file_from_url
87
+ facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
88
+ facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
89
+
90
+ def my_load_file_from_url(**kwargs):
91
+ return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
92
+
93
+ def facex_load_file_from_url(**kwargs):
94
+ return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
95
+
96
+ def facex_load_file_from_url2(**kwargs):
97
+ return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
98
+
99
+ gfpgan.utils.load_file_from_url = my_load_file_from_url
100
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
101
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
102
+ user_path = dirname
103
+ have_gfpgan = True
104
+ gfpgan_constructor = GFPGANer
105
+
106
+ class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
107
+ def name(self):
108
+ return "GFPGAN"
109
+
110
+ def restore(self, np_image):
111
+ return gfpgan_fix_faces(np_image)
112
+
113
+ shared.face_restorers.append(FaceRestorerGFPGAN())
114
+ except Exception:
115
+ print("Error setting up GFPGAN:", file=sys.stderr)
116
+ print(traceback.format_exc(), file=sys.stderr)
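
`setup_model()` above redirects GFPGAN and facexlib weight downloads by wrapping their `load_file_from_url` helpers so the save directory is always the webui's own `models/GFPGAN` folder. A generic sketch of that keyword-override monkey-patch pattern; `download_file` and `MODELS_DIR` are made-up stand-ins, not the real gfpgan/facexlib API:

```python
# Keyword-override monkey-patch: keep the original callable, then wrap it
# so one keyword argument is forced regardless of what the caller passed.
MODELS_DIR = "/tmp/models/GFPGAN"

def download_file(url, model_dir=None):
    return f"would save {url} into {model_dir}"

download_file_orig = download_file

def download_file_patched(**kwargs):
    # dict(kwargs, model_dir=...) overrides any caller-supplied model_dir
    return download_file_orig(**dict(kwargs, model_dir=MODELS_DIR))

download_file = download_file_patched

print(download_file(url="https://example.com/GFPGANv1.4.pth", model_dir="somewhere/else"))
# would save https://example.com/GFPGANv1.4.pth into /tmp/models/GFPGAN
```
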
sd/stable-diffusion-webui/modules/hashes.py CHANGED
@@ -1,91 +1,91 @@
1
- import hashlib
2
- import json
3
- import os.path
4
-
5
- import filelock
6
-
7
- from modules import shared
8
- from modules.paths import data_path
9
-
10
-
11
- cache_filename = os.path.join(data_path, "cache.json")
12
- cache_data = None
13
-
14
-
15
- def dump_cache():
16
- with filelock.FileLock(cache_filename+".lock"):
17
- with open(cache_filename, "w", encoding="utf8") as file:
18
- json.dump(cache_data, file, indent=4)
19
-
20
-
21
- def cache(subsection):
22
- global cache_data
23
-
24
- if cache_data is None:
25
- with filelock.FileLock(cache_filename+".lock"):
26
- if not os.path.isfile(cache_filename):
27
- cache_data = {}
28
- else:
29
- with open(cache_filename, "r", encoding="utf8") as file:
30
- cache_data = json.load(file)
31
-
32
- s = cache_data.get(subsection, {})
33
- cache_data[subsection] = s
34
-
35
- return s
36
-
37
-
38
- def calculate_sha256(filename):
39
- hash_sha256 = hashlib.sha256()
40
- blksize = 1024 * 1024
41
-
42
- with open(filename, "rb") as f:
43
- for chunk in iter(lambda: f.read(blksize), b""):
44
- hash_sha256.update(chunk)
45
-
46
- return hash_sha256.hexdigest()
47
-
48
-
49
- def sha256_from_cache(filename, title):
50
- hashes = cache("hashes")
51
- ondisk_mtime = os.path.getmtime(filename)
52
-
53
- if title not in hashes:
54
- return None
55
-
56
- cached_sha256 = hashes[title].get("sha256", None)
57
- cached_mtime = hashes[title].get("mtime", 0)
58
-
59
- if ondisk_mtime > cached_mtime or cached_sha256 is None:
60
- return None
61
-
62
- return cached_sha256
63
-
64
-
65
- def sha256(filename, title):
66
- hashes = cache("hashes")
67
-
68
- sha256_value = sha256_from_cache(filename, title)
69
- if sha256_value is not None:
70
- return sha256_value
71
-
72
- if shared.cmd_opts.no_hashing:
73
- return None
74
-
75
- print(f"Calculating sha256 for {filename}: ", end='')
76
- sha256_value = calculate_sha256(filename)
77
- print(f"{sha256_value}")
78
-
79
- hashes[title] = {
80
- "mtime": os.path.getmtime(filename),
81
- "sha256": sha256_value,
82
- }
83
-
84
- dump_cache()
85
-
86
- return sha256_value
87
-
88
-
89
-
90
-
91
-
 
+ import hashlib
+ import json
+ import os.path
+ 
+ import filelock
+ 
+ from modules import shared
+ from modules.paths import data_path
+ 
+ 
+ cache_filename = os.path.join(data_path, "cache.json")
+ cache_data = None
+ 
+ 
+ def dump_cache():
+     with filelock.FileLock(cache_filename+".lock"):
+         with open(cache_filename, "w", encoding="utf8") as file:
+             json.dump(cache_data, file, indent=4)
+ 
+ 
+ def cache(subsection):
+     global cache_data
+ 
+     if cache_data is None:
+         with filelock.FileLock(cache_filename+".lock"):
+             if not os.path.isfile(cache_filename):
+                 cache_data = {}
+             else:
+                 with open(cache_filename, "r", encoding="utf8") as file:
+                     cache_data = json.load(file)
+ 
+     s = cache_data.get(subsection, {})
+     cache_data[subsection] = s
+ 
+     return s
+ 
+ 
+ def calculate_sha256(filename):
+     hash_sha256 = hashlib.sha256()
+     blksize = 1024 * 1024
+ 
+     with open(filename, "rb") as f:
+         for chunk in iter(lambda: f.read(blksize), b""):
+             hash_sha256.update(chunk)
+ 
+     return hash_sha256.hexdigest()
+ 
+ 
+ def sha256_from_cache(filename, title):
+     hashes = cache("hashes")
+     ondisk_mtime = os.path.getmtime(filename)
+ 
+     if title not in hashes:
+         return None
+ 
+     cached_sha256 = hashes[title].get("sha256", None)
+     cached_mtime = hashes[title].get("mtime", 0)
+ 
+     if ondisk_mtime > cached_mtime or cached_sha256 is None:
+         return None
+ 
+     return cached_sha256
+ 
+ 
+ def sha256(filename, title):
+     hashes = cache("hashes")
+ 
+     sha256_value = sha256_from_cache(filename, title)
+     if sha256_value is not None:
+         return sha256_value
+ 
+     if shared.cmd_opts.no_hashing:
+         return None
+ 
+     print(f"Calculating sha256 for {filename}: ", end='')
+     sha256_value = calculate_sha256(filename)
+     print(f"{sha256_value}")
+ 
+     hashes[title] = {
+         "mtime": os.path.getmtime(filename),
+         "sha256": sha256_value,
+     }
+ 
+     dump_cache()
+ 
+     return sha256_value
+ 
+ 
+ 
+ 
+ 
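The `hashes.py` diff above (the removed and re-added sides are identical) implements an mtime-validated SHA-256 cache: `sha256(filename, title)` first consults `sha256_from_cache`, which returns the stored digest only if the file's modification time is not newer than the cached one, and otherwise hashes the file in 1 MiB blocks and persists the result to `cache.json` under the given `title`. A minimal standalone sketch of the same idea follows; `CACHE_PATH` is an assumption for this example, and the webui's `filelock` guard and `no_hashing` option are omitted for brevity.

```python
import hashlib
import json
import os

# Standalone sketch of the mtime-validated hash cache from hashes.py above.
CACHE_PATH = "hash_cache.json"  # assumed location; the webui uses cache.json under data_path


def _load_cache():
    if os.path.isfile(CACHE_PATH):
        with open(CACHE_PATH, "r", encoding="utf8") as f:
            return json.load(f)
    return {}


def cached_sha256(filename, title):
    cache = _load_cache()
    entry = cache.get(title)
    mtime = os.path.getmtime(filename)

    # reuse the stored digest only if the file has not changed since it was hashed
    if entry and entry.get("sha256") and mtime <= entry.get("mtime", 0):
        return entry["sha256"]

    h = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):  # 1 MiB blocks, as above
            h.update(chunk)

    cache[title] = {"mtime": mtime, "sha256": h.hexdigest()}
    with open(CACHE_PATH, "w", encoding="utf8") as f:
        json.dump(cache, f, indent=4)

    return cache[title]["sha256"]


if __name__ == "__main__":
    # hypothetical usage: hash this script itself under an arbitrary cache key
    print(cached_sha256(__file__, "example/self"))
```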
sd/stable-diffusion-webui/modules/hypernetworks/hypernetwork.py CHANGED
@@ -1,811 +1,811 @@
1
- import csv
2
- import datetime
3
- import glob
4
- import html
5
- import os
6
- import sys
7
- import traceback
8
- import inspect
9
-
10
- import modules.textual_inversion.dataset
11
- import torch
12
- import tqdm
13
- from einops import rearrange, repeat
14
- from ldm.util import default
15
- from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
16
- from modules.textual_inversion import textual_inversion, logging
17
- from modules.textual_inversion.learn_schedule import LearnRateScheduler
18
- from torch import einsum
19
- from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
20
-
21
- from collections import defaultdict, deque
22
- from statistics import stdev, mean
23
-
24
-
25
- optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
26
-
27
- class HypernetworkModule(torch.nn.Module):
28
- activation_dict = {
29
- "linear": torch.nn.Identity,
30
- "relu": torch.nn.ReLU,
31
- "leakyrelu": torch.nn.LeakyReLU,
32
- "elu": torch.nn.ELU,
33
- "swish": torch.nn.Hardswish,
34
- "tanh": torch.nn.Tanh,
35
- "sigmoid": torch.nn.Sigmoid,
36
- }
37
- activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
38
-
39
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
40
- add_layer_norm=False, activate_output=False, dropout_structure=None):
41
- super().__init__()
42
-
43
- self.multiplier = 1.0
44
-
45
- assert layer_structure is not None, "layer_structure must not be None"
46
- assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
47
- assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
48
-
49
- linears = []
50
- for i in range(len(layer_structure) - 1):
51
-
52
- # Add a fully-connected layer
53
- linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
54
-
55
- # Add an activation func except last layer
56
- if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
57
- pass
58
- elif activation_func in self.activation_dict:
59
- linears.append(self.activation_dict[activation_func]())
60
- else:
61
- raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
62
-
63
- # Add layer normalization
64
- if add_layer_norm:
65
- linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
66
-
67
- # Everything should be now parsed into dropout structure, and applied here.
68
- # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
69
- if dropout_structure is not None and dropout_structure[i+1] > 0:
70
- assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
71
- linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
72
- # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
73
-
74
- self.linear = torch.nn.Sequential(*linears)
75
-
76
- if state_dict is not None:
77
- self.fix_old_state_dict(state_dict)
78
- self.load_state_dict(state_dict)
79
- else:
80
- for layer in self.linear:
81
- if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
82
- w, b = layer.weight.data, layer.bias.data
83
- if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
84
- normal_(w, mean=0.0, std=0.01)
85
- normal_(b, mean=0.0, std=0)
86
- elif weight_init == 'XavierUniform':
87
- xavier_uniform_(w)
88
- zeros_(b)
89
- elif weight_init == 'XavierNormal':
90
- xavier_normal_(w)
91
- zeros_(b)
92
- elif weight_init == 'KaimingUniform':
93
- kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
94
- zeros_(b)
95
- elif weight_init == 'KaimingNormal':
96
- kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
97
- zeros_(b)
98
- else:
99
- raise KeyError(f"Key {weight_init} is not defined as initialization!")
100
- self.to(devices.device)
101
-
102
- def fix_old_state_dict(self, state_dict):
103
- changes = {
104
- 'linear1.bias': 'linear.0.bias',
105
- 'linear1.weight': 'linear.0.weight',
106
- 'linear2.bias': 'linear.1.bias',
107
- 'linear2.weight': 'linear.1.weight',
108
- }
109
-
110
- for fr, to in changes.items():
111
- x = state_dict.get(fr, None)
112
- if x is None:
113
- continue
114
-
115
- del state_dict[fr]
116
- state_dict[to] = x
117
-
118
- def forward(self, x):
119
- return x + self.linear(x) * (self.multiplier if not self.training else 1)
120
-
121
- def trainables(self):
122
- layer_structure = []
123
- for layer in self.linear:
124
- if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
125
- layer_structure += [layer.weight, layer.bias]
126
- return layer_structure
127
-
128
-
129
- #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
130
- def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
131
- if layer_structure is None:
132
- layer_structure = [1, 2, 1]
133
- if not use_dropout:
134
- return [0] * len(layer_structure)
135
- dropout_values = [0]
136
- dropout_values.extend([0.3] * (len(layer_structure) - 3))
137
- if last_layer_dropout:
138
- dropout_values.append(0.3)
139
- else:
140
- dropout_values.append(0)
141
- dropout_values.append(0)
142
- return dropout_values
143
-
144
-
145
- class Hypernetwork:
146
- filename = None
147
- name = None
148
-
149
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
150
- self.filename = None
151
- self.name = name
152
- self.layers = {}
153
- self.step = 0
154
- self.sd_checkpoint = None
155
- self.sd_checkpoint_name = None
156
- self.layer_structure = layer_structure
157
- self.activation_func = activation_func
158
- self.weight_init = weight_init
159
- self.add_layer_norm = add_layer_norm
160
- self.use_dropout = use_dropout
161
- self.activate_output = activate_output
162
- self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
163
- self.dropout_structure = kwargs.get('dropout_structure', None)
164
- if self.dropout_structure is None:
165
- self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
166
- self.optimizer_name = None
167
- self.optimizer_state_dict = None
168
- self.optional_info = None
169
-
170
- for size in enable_sizes or []:
171
- self.layers[size] = (
172
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
173
- self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
174
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
175
- self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
176
- )
177
- self.eval()
178
-
179
- def weights(self):
180
- res = []
181
- for k, layers in self.layers.items():
182
- for layer in layers:
183
- res += layer.parameters()
184
- return res
185
-
186
- def train(self, mode=True):
187
- for k, layers in self.layers.items():
188
- for layer in layers:
189
- layer.train(mode=mode)
190
- for param in layer.parameters():
191
- param.requires_grad = mode
192
-
193
- def to(self, device):
194
- for k, layers in self.layers.items():
195
- for layer in layers:
196
- layer.to(device)
197
-
198
- return self
199
-
200
- def set_multiplier(self, multiplier):
201
- for k, layers in self.layers.items():
202
- for layer in layers:
203
- layer.multiplier = multiplier
204
-
205
- return self
206
-
207
- def eval(self):
208
- for k, layers in self.layers.items():
209
- for layer in layers:
210
- layer.eval()
211
- for param in layer.parameters():
212
- param.requires_grad = False
213
-
214
- def save(self, filename):
215
- state_dict = {}
216
- optimizer_saved_dict = {}
217
-
218
- for k, v in self.layers.items():
219
- state_dict[k] = (v[0].state_dict(), v[1].state_dict())
220
-
221
- state_dict['step'] = self.step
222
- state_dict['name'] = self.name
223
- state_dict['layer_structure'] = self.layer_structure
224
- state_dict['activation_func'] = self.activation_func
225
- state_dict['is_layer_norm'] = self.add_layer_norm
226
- state_dict['weight_initialization'] = self.weight_init
227
- state_dict['sd_checkpoint'] = self.sd_checkpoint
228
- state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
229
- state_dict['activate_output'] = self.activate_output
230
- state_dict['use_dropout'] = self.use_dropout
231
- state_dict['dropout_structure'] = self.dropout_structure
232
- state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
233
- state_dict['optional_info'] = self.optional_info if self.optional_info else None
234
-
235
- if self.optimizer_name is not None:
236
- optimizer_saved_dict['optimizer_name'] = self.optimizer_name
237
-
238
- torch.save(state_dict, filename)
239
- if shared.opts.save_optimizer_state and self.optimizer_state_dict:
240
- optimizer_saved_dict['hash'] = self.shorthash()
241
- optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
242
- torch.save(optimizer_saved_dict, filename + '.optim')
243
-
244
- def load(self, filename):
245
- self.filename = filename
246
- if self.name is None:
247
- self.name = os.path.splitext(os.path.basename(filename))[0]
248
-
249
- state_dict = torch.load(filename, map_location='cpu')
250
-
251
- self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
252
- self.optional_info = state_dict.get('optional_info', None)
253
- self.activation_func = state_dict.get('activation_func', None)
254
- self.weight_init = state_dict.get('weight_initialization', 'Normal')
255
- self.add_layer_norm = state_dict.get('is_layer_norm', False)
256
- self.dropout_structure = state_dict.get('dropout_structure', None)
257
- self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
258
- self.activate_output = state_dict.get('activate_output', True)
259
- self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
260
- # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
261
- if self.dropout_structure is None:
262
- self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
263
-
264
- if shared.opts.print_hypernet_extra:
265
- if self.optional_info is not None:
266
- print(f" INFO:\n {self.optional_info}\n")
267
-
268
- print(f" Layer structure: {self.layer_structure}")
269
- print(f" Activation function: {self.activation_func}")
270
- print(f" Weight initialization: {self.weight_init}")
271
- print(f" Layer norm: {self.add_layer_norm}")
272
- print(f" Dropout usage: {self.use_dropout}" )
273
- print(f" Activate last layer: {self.activate_output}")
274
- print(f" Dropout structure: {self.dropout_structure}")
275
-
276
- optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
277
-
278
- if self.shorthash() == optimizer_saved_dict.get('hash', None):
279
- self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
280
- else:
281
- self.optimizer_state_dict = None
282
- if self.optimizer_state_dict:
283
- self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
284
- if shared.opts.print_hypernet_extra:
285
- print("Loaded existing optimizer from checkpoint")
286
- print(f"Optimizer name is {self.optimizer_name}")
287
- else:
288
- self.optimizer_name = "AdamW"
289
- if shared.opts.print_hypernet_extra:
290
- print("No saved optimizer exists in checkpoint")
291
-
292
- for size, sd in state_dict.items():
293
- if type(size) == int:
294
- self.layers[size] = (
295
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
296
- self.add_layer_norm, self.activate_output, self.dropout_structure),
297
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
298
- self.add_layer_norm, self.activate_output, self.dropout_structure),
299
- )
300
-
301
- self.name = state_dict.get('name', self.name)
302
- self.step = state_dict.get('step', 0)
303
- self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
304
- self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
305
- self.eval()
306
-
307
- def shorthash(self):
308
- sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
309
-
310
- return sha256[0:10] if sha256 else None
311
-
312
-
313
- def list_hypernetworks(path):
314
- res = {}
315
- for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
316
- name = os.path.splitext(os.path.basename(filename))[0]
317
- # Prevent a hypothetical "None.pt" from being listed.
318
- if name != "None":
319
- res[name] = filename
320
- return res
321
-
322
-
323
- def load_hypernetwork(name):
324
- path = shared.hypernetworks.get(name, None)
325
-
326
- if path is None:
327
- return None
328
-
329
- hypernetwork = Hypernetwork()
330
-
331
- try:
332
- hypernetwork.load(path)
333
- except Exception:
334
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
335
- print(traceback.format_exc(), file=sys.stderr)
336
- return None
337
-
338
- return hypernetwork
339
-
340
-
341
- def load_hypernetworks(names, multipliers=None):
342
- already_loaded = {}
343
-
344
- for hypernetwork in shared.loaded_hypernetworks:
345
- if hypernetwork.name in names:
346
- already_loaded[hypernetwork.name] = hypernetwork
347
-
348
- shared.loaded_hypernetworks.clear()
349
-
350
- for i, name in enumerate(names):
351
- hypernetwork = already_loaded.get(name, None)
352
- if hypernetwork is None:
353
- hypernetwork = load_hypernetwork(name)
354
-
355
- if hypernetwork is None:
356
- continue
357
-
358
- hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
359
- shared.loaded_hypernetworks.append(hypernetwork)
360
-
361
-
362
- def find_closest_hypernetwork_name(search: str):
363
- if not search:
364
- return None
365
- search = search.lower()
366
- applicable = [name for name in shared.hypernetworks if search in name.lower()]
367
- if not applicable:
368
- return None
369
- applicable = sorted(applicable, key=lambda name: len(name))
370
- return applicable[0]
371
-
372
-
373
- def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
374
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
375
-
376
- if hypernetwork_layers is None:
377
- return context_k, context_v
378
-
379
- if layer is not None:
380
- layer.hyper_k = hypernetwork_layers[0]
381
- layer.hyper_v = hypernetwork_layers[1]
382
-
383
- context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
384
- context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
385
- return context_k, context_v
386
-
387
-
388
- def apply_hypernetworks(hypernetworks, context, layer=None):
389
- context_k = context
390
- context_v = context
391
- for hypernetwork in hypernetworks:
392
- context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
393
-
394
- return context_k, context_v
395
-
396
-
397
- def attention_CrossAttention_forward(self, x, context=None, mask=None):
398
- h = self.heads
399
-
400
- q = self.to_q(x)
401
- context = default(context, x)
402
-
403
- context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
404
- k = self.to_k(context_k)
405
- v = self.to_v(context_v)
406
-
407
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
408
-
409
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
410
-
411
- if mask is not None:
412
- mask = rearrange(mask, 'b ... -> b (...)')
413
- max_neg_value = -torch.finfo(sim.dtype).max
414
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
415
- sim.masked_fill_(~mask, max_neg_value)
416
-
417
- # attention, what we cannot get enough of
418
- attn = sim.softmax(dim=-1)
419
-
420
- out = einsum('b i j, b j d -> b i d', attn, v)
421
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
422
- return self.to_out(out)
423
-
424
-
425
- def stack_conds(conds):
426
- if len(conds) == 1:
427
- return torch.stack(conds)
428
-
429
- # same as in reconstruct_multicond_batch
430
- token_count = max([x.shape[0] for x in conds])
431
- for i in range(len(conds)):
432
- if conds[i].shape[0] != token_count:
433
- last_vector = conds[i][-1:]
434
- last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
435
- conds[i] = torch.vstack([conds[i], last_vector_repeated])
436
-
437
- return torch.stack(conds)
438
-
439
-
440
- def statistics(data):
441
- if len(data) < 2:
442
- std = 0
443
- else:
444
- std = stdev(data)
445
- total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
446
- recent_data = data[-32:]
447
- if len(recent_data) < 2:
448
- std = 0
449
- else:
450
- std = stdev(recent_data)
451
- recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
452
- return total_information, recent_information
453
-
454
-
455
- def report_statistics(loss_info:dict):
456
- keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
457
- for key in keys:
458
- try:
459
- print("Loss statistics for file " + key)
460
- info, recent = statistics(list(loss_info[key]))
461
- print(info)
462
- print(recent)
463
- except Exception as e:
464
- print(e)
465
-
466
-
467
- def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
468
- # Remove illegal characters from name.
469
- name = "".join( x for x in name if (x.isalnum() or x in "._- "))
470
- assert name, "Name cannot be empty!"
471
-
472
- fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
473
- if not overwrite_old:
474
- assert not os.path.exists(fn), f"file {fn} already exists"
475
-
476
- if type(layer_structure) == str:
477
- layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
478
-
479
- if use_dropout and dropout_structure and type(dropout_structure) == str:
480
- dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
481
- else:
482
- dropout_structure = [0] * len(layer_structure)
483
-
484
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
485
- name=name,
486
- enable_sizes=[int(x) for x in enable_sizes],
487
- layer_structure=layer_structure,
488
- activation_func=activation_func,
489
- weight_init=weight_init,
490
- add_layer_norm=add_layer_norm,
491
- use_dropout=use_dropout,
492
- dropout_structure=dropout_structure
493
- )
494
- hypernet.save(fn)
495
-
496
- shared.reload_hypernetworks()
497
-
498
-
499
- def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
500
- # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
501
- from modules import images
502
-
503
- save_hypernetwork_every = save_hypernetwork_every or 0
504
- create_image_every = create_image_every or 0
505
- template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
506
- textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
507
- template_file = template_file.path
508
-
509
- path = shared.hypernetworks.get(hypernetwork_name, None)
510
- hypernetwork = Hypernetwork()
511
- hypernetwork.load(path)
512
- shared.loaded_hypernetworks = [hypernetwork]
513
-
514
- shared.state.job = "train-hypernetwork"
515
- shared.state.textinfo = "Initializing hypernetwork training..."
516
- shared.state.job_count = steps
517
-
518
- hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
519
- filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
520
-
521
- log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
522
- unload = shared.opts.unload_models_when_training
523
-
524
- if save_hypernetwork_every > 0:
525
- hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
526
- os.makedirs(hypernetwork_dir, exist_ok=True)
527
- else:
528
- hypernetwork_dir = None
529
-
530
- if create_image_every > 0:
531
- images_dir = os.path.join(log_directory, "images")
532
- os.makedirs(images_dir, exist_ok=True)
533
- else:
534
- images_dir = None
535
-
536
- checkpoint = sd_models.select_checkpoint()
537
-
538
- initial_step = hypernetwork.step or 0
539
- if initial_step >= steps:
540
- shared.state.textinfo = "Model has already been trained beyond specified max steps"
541
- return hypernetwork, filename
542
-
543
- scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
544
-
545
- clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
546
- if clip_grad:
547
- clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
548
-
549
- if shared.opts.training_enable_tensorboard:
550
- tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
551
-
552
- # dataset loading may take a while, so input validations and early returns should be done before this
553
- shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
554
-
555
- pin_memory = shared.opts.pin_memory
556
-
557
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
558
-
559
- if shared.opts.save_training_settings_to_txt:
560
- saved_params = dict(
561
- model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
562
- **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
563
- )
564
- logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
565
-
566
- latent_sampling_method = ds.latent_sampling_method
567
-
568
- dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
569
-
570
- old_parallel_processing_allowed = shared.parallel_processing_allowed
571
-
572
- if unload:
573
- shared.parallel_processing_allowed = False
574
- shared.sd_model.cond_stage_model.to(devices.cpu)
575
- shared.sd_model.first_stage_model.to(devices.cpu)
576
-
577
- weights = hypernetwork.weights()
578
- hypernetwork.train()
579
-
580
- # Here we use optimizer from saved HN, or we can specify as UI option.
581
- if hypernetwork.optimizer_name in optimizer_dict:
582
- optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
583
- optimizer_name = hypernetwork.optimizer_name
584
- else:
585
- print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
586
- optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
587
- optimizer_name = 'AdamW'
588
-
589
- if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
590
- try:
591
- optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
592
- except RuntimeError as e:
593
- print("Cannot resume from saved optimizer!")
594
- print(e)
595
-
596
- scaler = torch.cuda.amp.GradScaler()
597
-
598
- batch_size = ds.batch_size
599
- gradient_step = ds.gradient_step
600
- # n steps = batch_size * gradient_step * n image processed
601
- steps_per_epoch = len(ds) // batch_size // gradient_step
602
- max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
603
- loss_step = 0
604
- _loss_step = 0 #internal
605
- # size = len(ds.indexes)
606
- # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
607
- loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
608
- # losses = torch.zeros((size,))
609
- # previous_mean_losses = [0]
610
- # previous_mean_loss = 0
611
- # print("Mean loss of {} elements".format(size))
612
-
613
- steps_without_grad = 0
614
-
615
- last_saved_file = "<none>"
616
- last_saved_image = "<none>"
617
- forced_filename = "<none>"
618
-
619
- pbar = tqdm.tqdm(total=steps - initial_step)
620
- try:
621
- sd_hijack_checkpoint.add()
622
-
623
- for i in range((steps-initial_step) * gradient_step):
624
- if scheduler.finished:
625
- break
626
- if shared.state.interrupted:
627
- break
628
- for j, batch in enumerate(dl):
629
- # works as a drop_last=True for gradient accumulation
630
- if j == max_steps_per_epoch:
631
- break
632
- scheduler.apply(optimizer, hypernetwork.step)
633
- if scheduler.finished:
634
- break
635
- if shared.state.interrupted:
636
- break
637
-
638
- if clip_grad:
639
- clip_grad_sched.step(hypernetwork.step)
640
-
641
- with devices.autocast():
642
- x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
643
- if use_weight:
644
- w = batch.weight.to(devices.device, non_blocking=pin_memory)
645
- if tag_drop_out != 0 or shuffle_tags:
646
- shared.sd_model.cond_stage_model.to(devices.device)
647
- c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
648
- shared.sd_model.cond_stage_model.to(devices.cpu)
649
- else:
650
- c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
651
- if use_weight:
652
- loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
653
- del w
654
- else:
655
- loss = shared.sd_model.forward(x, c)[0] / gradient_step
656
- del x
657
- del c
658
-
659
- _loss_step += loss.item()
660
- scaler.scale(loss).backward()
661
-
662
- # go back until we reach gradient accumulation steps
663
- if (j + 1) % gradient_step != 0:
664
- continue
665
- loss_logging.append(_loss_step)
666
- if clip_grad:
667
- clip_grad(weights, clip_grad_sched.learn_rate)
668
-
669
- scaler.step(optimizer)
670
- scaler.update()
671
- hypernetwork.step += 1
672
- pbar.update()
673
- optimizer.zero_grad(set_to_none=True)
674
- loss_step = _loss_step
675
- _loss_step = 0
676
-
677
- steps_done = hypernetwork.step + 1
678
-
679
- epoch_num = hypernetwork.step // steps_per_epoch
680
- epoch_step = hypernetwork.step % steps_per_epoch
681
-
682
- description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
683
- pbar.set_description(description)
684
- if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
685
- # Before saving, change name to match current checkpoint.
686
- hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
687
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
688
- hypernetwork.optimizer_name = optimizer_name
689
- if shared.opts.save_optimizer_state:
690
- hypernetwork.optimizer_state_dict = optimizer.state_dict()
691
- save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
692
- hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
693
-
694
-
695
-
696
- if shared.opts.training_enable_tensorboard:
697
- epoch_num = hypernetwork.step // len(ds)
698
- epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
699
- mean_loss = sum(loss_logging) / len(loss_logging)
700
- textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
701
-
702
- textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
703
- "loss": f"{loss_step:.7f}",
704
- "learn_rate": scheduler.learn_rate
705
- })
706
-
707
- if images_dir is not None and steps_done % create_image_every == 0:
708
- forced_filename = f'{hypernetwork_name}-{steps_done}'
709
- last_saved_image = os.path.join(images_dir, forced_filename)
710
- hypernetwork.eval()
711
- rng_state = torch.get_rng_state()
712
- cuda_rng_state = None
713
- if torch.cuda.is_available():
714
- cuda_rng_state = torch.cuda.get_rng_state_all()
715
- shared.sd_model.cond_stage_model.to(devices.device)
716
- shared.sd_model.first_stage_model.to(devices.device)
717
-
718
- p = processing.StableDiffusionProcessingTxt2Img(
719
- sd_model=shared.sd_model,
720
- do_not_save_grid=True,
721
- do_not_save_samples=True,
722
- )
723
-
724
- p.disable_extra_networks = True
725
-
726
- if preview_from_txt2img:
727
- p.prompt = preview_prompt
728
- p.negative_prompt = preview_negative_prompt
729
- p.steps = preview_steps
730
- p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
731
- p.cfg_scale = preview_cfg_scale
732
- p.seed = preview_seed
733
- p.width = preview_width
734
- p.height = preview_height
735
- else:
736
- p.prompt = batch.cond_text[0]
737
- p.steps = 20
738
- p.width = training_width
739
- p.height = training_height
740
-
741
- preview_text = p.prompt
742
-
743
- processed = processing.process_images(p)
744
- image = processed.images[0] if len(processed.images) > 0 else None
745
-
746
- if unload:
747
- shared.sd_model.cond_stage_model.to(devices.cpu)
748
- shared.sd_model.first_stage_model.to(devices.cpu)
749
- torch.set_rng_state(rng_state)
750
- if torch.cuda.is_available():
751
- torch.cuda.set_rng_state_all(cuda_rng_state)
752
- hypernetwork.train()
753
- if image is not None:
754
- shared.state.assign_current_image(image)
755
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
756
- textual_inversion.tensorboard_add_image(tensorboard_writer,
757
- f"Validation at epoch {epoch_num}", image,
758
- hypernetwork.step)
759
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
760
- last_saved_image += f", prompt: {preview_text}"
761
-
762
- shared.state.job_no = hypernetwork.step
763
-
764
- shared.state.textinfo = f"""
765
- <p>
766
- Loss: {loss_step:.7f}<br/>
767
- Step: {steps_done}<br/>
768
- Last prompt: {html.escape(batch.cond_text[0])}<br/>
769
- Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
770
- Last saved image: {html.escape(last_saved_image)}<br/>
771
- </p>
772
- """
773
- except Exception:
774
- print(traceback.format_exc(), file=sys.stderr)
775
- finally:
776
- pbar.leave = False
777
- pbar.close()
778
- hypernetwork.eval()
779
- #report_statistics(loss_dict)
780
- sd_hijack_checkpoint.remove()
781
-
782
-
783
-
784
- filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
785
- hypernetwork.optimizer_name = optimizer_name
786
- if shared.opts.save_optimizer_state:
787
- hypernetwork.optimizer_state_dict = optimizer.state_dict()
788
- save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
789
-
790
- del optimizer
791
- hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
792
- shared.sd_model.cond_stage_model.to(devices.device)
793
- shared.sd_model.first_stage_model.to(devices.device)
794
- shared.parallel_processing_allowed = old_parallel_processing_allowed
795
-
796
- return hypernetwork, filename
797
-
798
- def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
799
- old_hypernetwork_name = hypernetwork.name
800
- old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
801
- old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
802
- try:
803
- hypernetwork.sd_checkpoint = checkpoint.shorthash
804
- hypernetwork.sd_checkpoint_name = checkpoint.model_name
805
- hypernetwork.name = hypernetwork_name
806
- hypernetwork.save(filename)
807
- except:
808
- hypernetwork.sd_checkpoint = old_sd_checkpoint
809
- hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
810
- hypernetwork.name = old_hypernetwork_name
811
- raise
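Before the identical re-added copy of `hypernetwork.py` below, one detail from the file is worth illustrating concretely: `parse_dropout_structure` expands a layer structure and two flags into a per-layer list of dropout probabilities, which `HypernetworkModule` then uses to decide where to insert `torch.nn.Dropout` layers. The sketch below simply copies that function from the file above and prints its output for a few arbitrary example inputs.

```python
# Self-contained illustration of parse_dropout_structure from hypernetwork.py above.

def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
    if layer_structure is None:
        layer_structure = [1, 2, 1]
    if not use_dropout:
        return [0] * len(layer_structure)
    dropout_values = [0]
    dropout_values.extend([0.3] * (len(layer_structure) - 3))
    if last_layer_dropout:
        dropout_values.append(0.3)
    else:
        dropout_values.append(0)
    dropout_values.append(0)
    return dropout_values


print(parse_dropout_structure([1, 2, 1], False, False))    # [0, 0, 0] -- dropout disabled
print(parse_dropout_structure([1, 2, 2, 1], True, True))   # [0, 0.3, 0.3, 0]
print(parse_dropout_structure([1, 2, 2, 1], True, False))  # [0, 0.3, 0, 0]
```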
 
1
+ import csv
2
+ import datetime
3
+ import glob
4
+ import html
5
+ import os
6
+ import sys
7
+ import traceback
8
+ import inspect
9
+
10
+ import modules.textual_inversion.dataset
11
+ import torch
12
+ import tqdm
13
+ from einops import rearrange, repeat
14
+ from ldm.util import default
15
+ from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
16
+ from modules.textual_inversion import textual_inversion, logging
17
+ from modules.textual_inversion.learn_schedule import LearnRateScheduler
18
+ from torch import einsum
19
+ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
20
+
21
+ from collections import defaultdict, deque
22
+ from statistics import stdev, mean
23
+
24
+
25
+ optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
26
+
27
+ class HypernetworkModule(torch.nn.Module):
28
+ activation_dict = {
29
+ "linear": torch.nn.Identity,
30
+ "relu": torch.nn.ReLU,
31
+ "leakyrelu": torch.nn.LeakyReLU,
32
+ "elu": torch.nn.ELU,
33
+ "swish": torch.nn.Hardswish,
34
+ "tanh": torch.nn.Tanh,
35
+ "sigmoid": torch.nn.Sigmoid,
36
+ }
37
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
38
+
39
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
40
+ add_layer_norm=False, activate_output=False, dropout_structure=None):
41
+ super().__init__()
42
+
43
+ self.multiplier = 1.0
44
+
45
+ assert layer_structure is not None, "layer_structure must not be None"
46
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
47
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
48
+
49
+ linears = []
50
+ for i in range(len(layer_structure) - 1):
51
+
52
+ # Add a fully-connected layer
53
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
54
+
55
+ # Add an activation func except last layer
56
+ if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
57
+ pass
58
+ elif activation_func in self.activation_dict:
59
+ linears.append(self.activation_dict[activation_func]())
60
+ else:
61
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
62
+
63
+ # Add layer normalization
64
+ if add_layer_norm:
65
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
66
+
67
+ # Everything should be now parsed into dropout structure, and applied here.
68
+ # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
69
+ if dropout_structure is not None and dropout_structure[i+1] > 0:
70
+ assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
71
+ linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
72
+ # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
73
+
74
+ self.linear = torch.nn.Sequential(*linears)
75
+
76
+ if state_dict is not None:
77
+ self.fix_old_state_dict(state_dict)
78
+ self.load_state_dict(state_dict)
79
+ else:
80
+ for layer in self.linear:
81
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
82
+ w, b = layer.weight.data, layer.bias.data
83
+ if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
84
+ normal_(w, mean=0.0, std=0.01)
85
+ normal_(b, mean=0.0, std=0)
86
+ elif weight_init == 'XavierUniform':
87
+ xavier_uniform_(w)
88
+ zeros_(b)
89
+ elif weight_init == 'XavierNormal':
90
+ xavier_normal_(w)
91
+ zeros_(b)
92
+ elif weight_init == 'KaimingUniform':
93
+ kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
94
+ zeros_(b)
95
+ elif weight_init == 'KaimingNormal':
96
+ kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
97
+ zeros_(b)
98
+ else:
99
+ raise KeyError(f"Key {weight_init} is not defined as initialization!")
100
+ self.to(devices.device)
101
+
102
+ def fix_old_state_dict(self, state_dict):
103
+ changes = {
104
+ 'linear1.bias': 'linear.0.bias',
105
+ 'linear1.weight': 'linear.0.weight',
106
+ 'linear2.bias': 'linear.1.bias',
107
+ 'linear2.weight': 'linear.1.weight',
108
+ }
109
+
110
+ for fr, to in changes.items():
111
+ x = state_dict.get(fr, None)
112
+ if x is None:
113
+ continue
114
+
115
+ del state_dict[fr]
116
+ state_dict[to] = x
117
+
118
+ def forward(self, x):
119
+ return x + self.linear(x) * (self.multiplier if not self.training else 1)
120
+
121
+ def trainables(self):
122
+ layer_structure = []
123
+ for layer in self.linear:
124
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
125
+ layer_structure += [layer.weight, layer.bias]
126
+ return layer_structure
127
+
128
+
129
+ #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
130
+ def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
131
+ if layer_structure is None:
132
+ layer_structure = [1, 2, 1]
133
+ if not use_dropout:
134
+ return [0] * len(layer_structure)
135
+ dropout_values = [0]
136
+ dropout_values.extend([0.3] * (len(layer_structure) - 3))
137
+ if last_layer_dropout:
138
+ dropout_values.append(0.3)
139
+ else:
140
+ dropout_values.append(0)
141
+ dropout_values.append(0)
142
+ return dropout_values
143
+
144
+
145
+ class Hypernetwork:
146
+ filename = None
147
+ name = None
148
+
149
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
150
+ self.filename = None
151
+ self.name = name
152
+ self.layers = {}
153
+ self.step = 0
154
+ self.sd_checkpoint = None
155
+ self.sd_checkpoint_name = None
156
+ self.layer_structure = layer_structure
157
+ self.activation_func = activation_func
158
+ self.weight_init = weight_init
159
+ self.add_layer_norm = add_layer_norm
160
+ self.use_dropout = use_dropout
161
+ self.activate_output = activate_output
162
+ self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
163
+ self.dropout_structure = kwargs.get('dropout_structure', None)
164
+ if self.dropout_structure is None:
165
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
166
+ self.optimizer_name = None
167
+ self.optimizer_state_dict = None
168
+ self.optional_info = None
169
+
170
+ for size in enable_sizes or []:
171
+ self.layers[size] = (
172
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
173
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
174
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
175
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
176
+ )
177
+ self.eval()
178
+
179
+ def weights(self):
180
+ res = []
181
+ for k, layers in self.layers.items():
182
+ for layer in layers:
183
+ res += layer.parameters()
184
+ return res
185
+
186
+ def train(self, mode=True):
187
+ for k, layers in self.layers.items():
188
+ for layer in layers:
189
+ layer.train(mode=mode)
190
+ for param in layer.parameters():
191
+ param.requires_grad = mode
192
+
193
+ def to(self, device):
194
+ for k, layers in self.layers.items():
195
+ for layer in layers:
196
+ layer.to(device)
197
+
198
+ return self
199
+
200
+ def set_multiplier(self, multiplier):
201
+ for k, layers in self.layers.items():
202
+ for layer in layers:
203
+ layer.multiplier = multiplier
204
+
205
+ return self
206
+
207
+ def eval(self):
208
+ for k, layers in self.layers.items():
209
+ for layer in layers:
210
+ layer.eval()
211
+ for param in layer.parameters():
212
+ param.requires_grad = False
213
+
214
+ def save(self, filename):
215
+ state_dict = {}
216
+ optimizer_saved_dict = {}
217
+
218
+ for k, v in self.layers.items():
219
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
220
+
221
+ state_dict['step'] = self.step
222
+ state_dict['name'] = self.name
223
+ state_dict['layer_structure'] = self.layer_structure
224
+ state_dict['activation_func'] = self.activation_func
225
+ state_dict['is_layer_norm'] = self.add_layer_norm
226
+ state_dict['weight_initialization'] = self.weight_init
227
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
228
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
229
+ state_dict['activate_output'] = self.activate_output
230
+ state_dict['use_dropout'] = self.use_dropout
231
+ state_dict['dropout_structure'] = self.dropout_structure
232
+ state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
233
+ state_dict['optional_info'] = self.optional_info if self.optional_info else None
234
+
235
+ if self.optimizer_name is not None:
236
+ optimizer_saved_dict['optimizer_name'] = self.optimizer_name
237
+
238
+ torch.save(state_dict, filename)
239
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict:
240
+ optimizer_saved_dict['hash'] = self.shorthash()
241
+ optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
242
+ torch.save(optimizer_saved_dict, filename + '.optim')
243
+
244
+ def load(self, filename):
245
+ self.filename = filename
246
+ if self.name is None:
247
+ self.name = os.path.splitext(os.path.basename(filename))[0]
248
+
249
+ state_dict = torch.load(filename, map_location='cpu')
250
+
251
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
252
+ self.optional_info = state_dict.get('optional_info', None)
253
+ self.activation_func = state_dict.get('activation_func', None)
254
+ self.weight_init = state_dict.get('weight_initialization', 'Normal')
255
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
256
+ self.dropout_structure = state_dict.get('dropout_structure', None)
257
+ self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
258
+ self.activate_output = state_dict.get('activate_output', True)
259
+ self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
260
+ # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
261
+ if self.dropout_structure is None:
262
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
263
+
264
+ if shared.opts.print_hypernet_extra:
265
+ if self.optional_info is not None:
266
+ print(f" INFO:\n {self.optional_info}\n")
267
+
268
+ print(f" Layer structure: {self.layer_structure}")
269
+ print(f" Activation function: {self.activation_func}")
270
+ print(f" Weight initialization: {self.weight_init}")
271
+ print(f" Layer norm: {self.add_layer_norm}")
272
+ print(f" Dropout usage: {self.use_dropout}" )
273
+ print(f" Activate last layer: {self.activate_output}")
274
+ print(f" Dropout structure: {self.dropout_structure}")
275
+
276
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
277
+
278
+ if self.shorthash() == optimizer_saved_dict.get('hash', None):
279
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
280
+ else:
281
+ self.optimizer_state_dict = None
282
+ if self.optimizer_state_dict:
283
+ self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
284
+ if shared.opts.print_hypernet_extra:
285
+ print("Loaded existing optimizer from checkpoint")
286
+ print(f"Optimizer name is {self.optimizer_name}")
287
+ else:
288
+ self.optimizer_name = "AdamW"
289
+ if shared.opts.print_hypernet_extra:
290
+ print("No saved optimizer exists in checkpoint")
291
+
292
+ for size, sd in state_dict.items():
293
+ if type(size) == int:
294
+ self.layers[size] = (
295
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
296
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
297
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
298
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
299
+ )
300
+
301
+ self.name = state_dict.get('name', self.name)
302
+ self.step = state_dict.get('step', 0)
303
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
304
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
305
+ self.eval()
306
+
307
+ def shorthash(self):
308
+ sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
309
+
310
+ return sha256[0:10] if sha256 else None
311
+
312
+
313
+ def list_hypernetworks(path):
314
+ res = {}
315
+ for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
316
+ name = os.path.splitext(os.path.basename(filename))[0]
317
+ # Prevent a hypothetical "None.pt" from being listed.
318
+ if name != "None":
319
+ res[name] = filename
320
+ return res
321
+
322
+
323
+ def load_hypernetwork(name):
324
+ path = shared.hypernetworks.get(name, None)
325
+
326
+ if path is None:
327
+ return None
328
+
329
+ hypernetwork = Hypernetwork()
330
+
331
+ try:
332
+ hypernetwork.load(path)
333
+ except Exception:
334
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
335
+ print(traceback.format_exc(), file=sys.stderr)
336
+ return None
337
+
338
+ return hypernetwork
339
+
340
+
341
+ def load_hypernetworks(names, multipliers=None):
342
+ already_loaded = {}
343
+
344
+ for hypernetwork in shared.loaded_hypernetworks:
345
+ if hypernetwork.name in names:
346
+ already_loaded[hypernetwork.name] = hypernetwork
347
+
348
+ shared.loaded_hypernetworks.clear()
349
+
350
+ for i, name in enumerate(names):
351
+ hypernetwork = already_loaded.get(name, None)
352
+ if hypernetwork is None:
353
+ hypernetwork = load_hypernetwork(name)
354
+
355
+ if hypernetwork is None:
356
+ continue
357
+
358
+ hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
359
+ shared.loaded_hypernetworks.append(hypernetwork)
360
+
361
+
362
+ def find_closest_hypernetwork_name(search: str):
363
+ if not search:
364
+ return None
365
+ search = search.lower()
366
+ applicable = [name for name in shared.hypernetworks if search in name.lower()]
367
+ if not applicable:
368
+ return None
369
+ applicable = sorted(applicable, key=lambda name: len(name))
370
+ return applicable[0]
371
+
372
+
373
+ def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
374
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
375
+
376
+ if hypernetwork_layers is None:
377
+ return context_k, context_v
378
+
379
+ if layer is not None:
380
+ layer.hyper_k = hypernetwork_layers[0]
381
+ layer.hyper_v = hypernetwork_layers[1]
382
+
383
+ context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
384
+ context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
385
+ return context_k, context_v
386
+
387
+
388
+ def apply_hypernetworks(hypernetworks, context, layer=None):
389
+ context_k = context
390
+ context_v = context
391
+ for hypernetwork in hypernetworks:
392
+ context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
393
+
394
+ return context_k, context_v
395
+
396
+
397
+ def attention_CrossAttention_forward(self, x, context=None, mask=None):
398
+ h = self.heads
399
+
400
+ q = self.to_q(x)
401
+ context = default(context, x)
402
+
403
+ context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
404
+ k = self.to_k(context_k)
405
+ v = self.to_v(context_v)
406
+
407
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
408
+
409
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
410
+
411
+ if mask is not None:
412
+ mask = rearrange(mask, 'b ... -> b (...)')
413
+ max_neg_value = -torch.finfo(sim.dtype).max
414
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
415
+ sim.masked_fill_(~mask, max_neg_value)
416
+
417
+ # attention, what we cannot get enough of
418
+ attn = sim.softmax(dim=-1)
419
+
420
+ out = einsum('b i j, b j d -> b i d', attn, v)
421
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
422
+ return self.to_out(out)
423
+
424
+
425
+ def stack_conds(conds):
426
+ if len(conds) == 1:
427
+ return torch.stack(conds)
428
+
429
+ # same as in reconstruct_multicond_batch
430
+ token_count = max([x.shape[0] for x in conds])
431
+ for i in range(len(conds)):
432
+ if conds[i].shape[0] != token_count:
433
+ last_vector = conds[i][-1:]
434
+ last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
435
+ conds[i] = torch.vstack([conds[i], last_vector_repeated])
436
+
437
+ return torch.stack(conds)
438
+
439
+
440
+ def statistics(data):
441
+ if len(data) < 2:
442
+ std = 0
443
+ else:
444
+ std = stdev(data)
445
+ total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
446
+ recent_data = data[-32:]
447
+ if len(recent_data) < 2:
448
+ std = 0
449
+ else:
450
+ std = stdev(recent_data)
451
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
452
+ return total_information, recent_information
453
+
454
+
455
+ def report_statistics(loss_info:dict):
456
+ keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
457
+ for key in keys:
458
+ try:
459
+ print("Loss statistics for file " + key)
460
+ info, recent = statistics(list(loss_info[key]))
461
+ print(info)
462
+ print(recent)
463
+ except Exception as e:
464
+ print(e)
465
+
466
+
467
+ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
468
+ # Remove illegal characters from name.
469
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
470
+ assert name, "Name cannot be empty!"
471
+
472
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
473
+ if not overwrite_old:
474
+ assert not os.path.exists(fn), f"file {fn} already exists"
475
+
476
+ if type(layer_structure) == str:
477
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
478
+
479
+ if use_dropout and dropout_structure and type(dropout_structure) == str:
480
+ dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
481
+ else:
482
+ dropout_structure = [0] * len(layer_structure)
483
+
484
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
485
+ name=name,
486
+ enable_sizes=[int(x) for x in enable_sizes],
487
+ layer_structure=layer_structure,
488
+ activation_func=activation_func,
489
+ weight_init=weight_init,
490
+ add_layer_norm=add_layer_norm,
491
+ use_dropout=use_dropout,
492
+ dropout_structure=dropout_structure
493
+ )
494
+ hypernet.save(fn)
495
+
496
+ shared.reload_hypernetworks()
497
+
498
+
499
+ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
500
+ # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
501
+ from modules import images
502
+
503
+ save_hypernetwork_every = save_hypernetwork_every or 0
504
+ create_image_every = create_image_every or 0
505
+ template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
506
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
507
+ template_file = template_file.path
508
+
509
+ path = shared.hypernetworks.get(hypernetwork_name, None)
510
+ hypernetwork = Hypernetwork()
511
+ hypernetwork.load(path)
512
+ shared.loaded_hypernetworks = [hypernetwork]
513
+
514
+ shared.state.job = "train-hypernetwork"
515
+ shared.state.textinfo = "Initializing hypernetwork training..."
516
+ shared.state.job_count = steps
517
+
518
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
519
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
520
+
521
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
522
+ unload = shared.opts.unload_models_when_training
523
+
524
+ if save_hypernetwork_every > 0:
525
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
526
+ os.makedirs(hypernetwork_dir, exist_ok=True)
527
+ else:
528
+ hypernetwork_dir = None
529
+
530
+ if create_image_every > 0:
531
+ images_dir = os.path.join(log_directory, "images")
532
+ os.makedirs(images_dir, exist_ok=True)
533
+ else:
534
+ images_dir = None
535
+
536
+ checkpoint = sd_models.select_checkpoint()
537
+
538
+ initial_step = hypernetwork.step or 0
539
+ if initial_step >= steps:
540
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
541
+ return hypernetwork, filename
542
+
543
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
544
+
545
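+ # pick the gradient clipping function (clip by value or by norm); None disables clipping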
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
546
+ if clip_grad:
547
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
548
+
549
+ if shared.opts.training_enable_tensorboard:
550
+ tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
551
+
552
+ # dataset loading may take a while, so input validations and early returns should be done before this
553
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
554
+
555
+ pin_memory = shared.opts.pin_memory
556
+
557
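+ # build the training dataset from the images in data_root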
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
558
+
559
+ if shared.opts.save_training_settings_to_txt:
560
+ saved_params = dict(
561
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
562
+ **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
563
+ )
564
+ logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
565
+
566
+ latent_sampling_method = ds.latent_sampling_method
567
+
568
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
569
+
570
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
571
+
572
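+ # optionally move the text encoder and VAE to the CPU to free VRAM during training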
+ if unload:
573
+ shared.parallel_processing_allowed = False
574
+ shared.sd_model.cond_stage_model.to(devices.cpu)
575
+ shared.sd_model.first_stage_model.to(devices.cpu)
576
+
577
+ weights = hypernetwork.weights()
578
+ hypernetwork.train()
579
+
580
+ # Use the optimizer type stored in the saved hypernetwork; it could also be specified as a UI option.
581
+ if hypernetwork.optimizer_name in optimizer_dict:
582
+ optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
583
+ optimizer_name = hypernetwork.optimizer_name
584
+ else:
585
+ print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
586
+ optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
587
+ optimizer_name = 'AdamW'
588
+
589
+ if hypernetwork.optimizer_state_dict: # this must be changed if the optimizer type can differ from the one that was saved
590
+ try:
591
+ optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
592
+ except RuntimeError as e:
593
+ print("Cannot resume from saved optimizer!")
594
+ print(e)
595
+
596
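+ # gradient scaler for mixed-precision (AMP) training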
+ scaler = torch.cuda.amp.GradScaler()
597
+
598
+ batch_size = ds.batch_size
599
+ gradient_step = ds.gradient_step
600
+ # one optimizer step processes batch_size * gradient_step images
601
+ steps_per_epoch = len(ds) // batch_size // gradient_step
602
+ max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
603
+ loss_step = 0
604
+ _loss_step = 0 #internal
605
+ # size = len(ds.indexes)
606
+ # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
607
+ loss_logging = deque(maxlen=len(ds) * 3) # this should be a configurable parameter; this is 3 * epoch (dataset size)
608
+ # losses = torch.zeros((size,))
609
+ # previous_mean_losses = [0]
610
+ # previous_mean_loss = 0
611
+ # print("Mean loss of {} elements".format(size))
612
+
613
+ steps_without_grad = 0
614
+
615
+ last_saved_file = "<none>"
616
+ last_saved_image = "<none>"
617
+ forced_filename = "<none>"
618
+
619
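+ # main training loop: each optimizer step accumulates gradients over gradient_step batches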
+ pbar = tqdm.tqdm(total=steps - initial_step)
620
+ try:
621
+ sd_hijack_checkpoint.add()
622
+
623
+ for i in range((steps-initial_step) * gradient_step):
624
+ if scheduler.finished:
625
+ break
626
+ if shared.state.interrupted:
627
+ break
628
+ for j, batch in enumerate(dl):
629
+ # acts like drop_last=True for gradient accumulation
630
+ if j == max_steps_per_epoch:
631
+ break
632
+ scheduler.apply(optimizer, hypernetwork.step)
633
+ if scheduler.finished:
634
+ break
635
+ if shared.state.interrupted:
636
+ break
637
+
638
+ if clip_grad:
639
+ clip_grad_sched.step(hypernetwork.step)
640
+
641
+ with devices.autocast():
642
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
643
+ if use_weight:
644
+ w = batch.weight.to(devices.device, non_blocking=pin_memory)
645
+ if tag_drop_out != 0 or shuffle_tags:
646
+ shared.sd_model.cond_stage_model.to(devices.device)
647
+ c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
648
+ shared.sd_model.cond_stage_model.to(devices.cpu)
649
+ else:
650
+ c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
651
+ if use_weight:
652
+ loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
653
+ del w
654
+ else:
655
+ loss = shared.sd_model.forward(x, c)[0] / gradient_step
656
+ del x
657
+ del c
658
+
659
+ _loss_step += loss.item()
660
+ scaler.scale(loss).backward()
661
+
662
+ # keep accumulating until gradient_step batches have been processed
663
+ if (j + 1) % gradient_step != 0:
664
+ continue
665
+ loss_logging.append(_loss_step)
666
+ if clip_grad:
667
+ clip_grad(weights, clip_grad_sched.learn_rate)
668
+
669
+ scaler.step(optimizer)
670
+ scaler.update()
671
+ hypernetwork.step += 1
672
+ pbar.update()
673
+ optimizer.zero_grad(set_to_none=True)
674
+ loss_step = _loss_step
675
+ _loss_step = 0
676
+
677
+ steps_done = hypernetwork.step + 1
678
+
679
+ epoch_num = hypernetwork.step // steps_per_epoch
680
+ epoch_step = hypernetwork.step % steps_per_epoch
681
+
682
+ description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
683
+ pbar.set_description(description)
684
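+ # periodically save an intermediate copy of the hypernetwork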
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
685
+ # Before saving, change name to match current checkpoint.
686
+ hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
687
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
688
+ hypernetwork.optimizer_name = optimizer_name
689
+ if shared.opts.save_optimizer_state:
690
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
691
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
692
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
693
+
694
+
695
+
696
+ if shared.opts.training_enable_tensorboard:
697
+ epoch_num = hypernetwork.step // len(ds)
698
+ epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
699
+ mean_loss = sum(loss_logging) / len(loss_logging)
700
+ textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
701
+
702
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
703
+ "loss": f"{loss_step:.7f}",
704
+ "learn_rate": scheduler.learn_rate
705
+ })
706
+
707
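+ # periodically generate a preview image with the current hypernetwork to monitor progress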
+ if images_dir is not None and steps_done % create_image_every == 0:
708
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
709
+ last_saved_image = os.path.join(images_dir, forced_filename)
710
+ hypernetwork.eval()
711
+ rng_state = torch.get_rng_state()
712
+ cuda_rng_state = None
713
+ if torch.cuda.is_available():
714
+ cuda_rng_state = torch.cuda.get_rng_state_all()
715
+ shared.sd_model.cond_stage_model.to(devices.device)
716
+ shared.sd_model.first_stage_model.to(devices.device)
717
+
718
+ p = processing.StableDiffusionProcessingTxt2Img(
719
+ sd_model=shared.sd_model,
720
+ do_not_save_grid=True,
721
+ do_not_save_samples=True,
722
+ )
723
+
724
+ p.disable_extra_networks = True
725
+
726
+ if preview_from_txt2img:
727
+ p.prompt = preview_prompt
728
+ p.negative_prompt = preview_negative_prompt
729
+ p.steps = preview_steps
730
+ p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
731
+ p.cfg_scale = preview_cfg_scale
732
+ p.seed = preview_seed
733
+ p.width = preview_width
734
+ p.height = preview_height
735
+ else:
736
+ p.prompt = batch.cond_text[0]
737
+ p.steps = 20
738
+ p.width = training_width
739
+ p.height = training_height
740
+
741
+ preview_text = p.prompt
742
+
743
+ processed = processing.process_images(p)
744
+ image = processed.images[0] if len(processed.images) > 0 else None
745
+
746
+ if unload:
747
+ shared.sd_model.cond_stage_model.to(devices.cpu)
748
+ shared.sd_model.first_stage_model.to(devices.cpu)
749
+ torch.set_rng_state(rng_state)
750
+ if torch.cuda.is_available():
751
+ torch.cuda.set_rng_state_all(cuda_rng_state)
752
+ hypernetwork.train()
753
+ if image is not None:
754
+ shared.state.assign_current_image(image)
755
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
756
+ textual_inversion.tensorboard_add_image(tensorboard_writer,
757
+ f"Validation at epoch {epoch_num}", image,
758
+ hypernetwork.step)
759
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
760
+ last_saved_image += f", prompt: {preview_text}"
761
+
762
+ shared.state.job_no = hypernetwork.step
763
+
764
+ shared.state.textinfo = f"""
765
+ <p>
766
+ Loss: {loss_step:.7f}<br/>
767
+ Step: {steps_done}<br/>
768
+ Last prompt: {html.escape(batch.cond_text[0])}<br/>
769
+ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
770
+ Last saved image: {html.escape(last_saved_image)}<br/>
771
+ </p>
772
+ """
773
+ except Exception:
774
+ print(traceback.format_exc(), file=sys.stderr)
775
+ finally:
776
+ pbar.leave = False
777
+ pbar.close()
778
+ hypernetwork.eval()
779
+ #report_statistics(loss_dict)
780
+ sd_hijack_checkpoint.remove()
781
+
782
+
783
+
784
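+ # after training: save the final hypernetwork, free the optimizer, and move the model components back to the GPU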
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
785
+ hypernetwork.optimizer_name = optimizer_name
786
+ if shared.opts.save_optimizer_state:
787
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
788
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
789
+
790
+ del optimizer
791
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
792
+ shared.sd_model.cond_stage_model.to(devices.device)
793
+ shared.sd_model.first_stage_model.to(devices.device)
794
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
795
+
796
+ return hypernetwork, filename
797
+
798
+ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
799
+ old_hypernetwork_name = hypernetwork.name
800
+ old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
801
+ old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
802
+ try:
803
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
804
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
805
+ hypernetwork.name = hypernetwork_name
806
+ hypernetwork.save(filename)
807
+ except:
808
+ hypernetwork.sd_checkpoint = old_sd_checkpoint
809
+ hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
810
+ hypernetwork.name = old_hypernetwork_name
811
+ raise
sd/stable-diffusion-webui/modules/hypernetworks/ui.py CHANGED
@@ -1,40 +1,40 @@
1
- import html
2
- import os
3
- import re
4
-
5
- import gradio as gr
6
- import modules.hypernetworks.hypernetwork
7
- from modules import devices, sd_hijack, shared
8
-
9
- not_available = ["hardswish", "multiheadattention"]
10
- keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
11
-
12
-
13
- def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
14
- filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
15
-
16
- return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
17
-
18
-
19
- def train_hypernetwork(*args):
20
- shared.loaded_hypernetworks = []
21
-
22
- assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
23
-
24
- try:
25
- sd_hijack.undo_optimizations()
26
-
27
- hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
28
-
29
- res = f"""
30
- Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
31
- Hypernetwork saved to {html.escape(filename)}
32
- """
33
- return res, ""
34
- except Exception:
35
- raise
36
- finally:
37
- shared.sd_model.cond_stage_model.to(devices.device)
38
- shared.sd_model.first_stage_model.to(devices.device)
39
- sd_hijack.apply_optimizations()
40
-
 
1
+ import html
2
+ import os
3
+ import re
4
+
5
+ import gradio as gr
6
+ import modules.hypernetworks.hypernetwork
7
+ from modules import devices, sd_hijack, shared
8
+
9
+ not_available = ["hardswish", "multiheadattention"]
10
+ keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
11
+
12
+
13
+ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
14
+ filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
15
+
16
+ return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
17
+
18
+
19
+ def train_hypernetwork(*args):
20
+ shared.loaded_hypernetworks = []
21
+
22
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
23
+
24
+ try:
25
+ sd_hijack.undo_optimizations()
26
+
27
+ hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
28
+
29
+ res = f"""
30
+ Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
31
+ Hypernetwork saved to {html.escape(filename)}
32
+ """
33
+ return res, ""
34
+ except Exception:
35
+ raise
36
+ finally:
37
+ shared.sd_model.cond_stage_model.to(devices.device)
38
+ shared.sd_model.first_stage_model.to(devices.device)
39
+ sd_hijack.apply_optimizations()
40
+
sd/stable-diffusion-webui/modules/images.py CHANGED
@@ -1,669 +1,669 @@
1
- import datetime
2
- import sys
3
- import traceback
4
-
5
- import pytz
6
- import io
7
- import math
8
- import os
9
- from collections import namedtuple
10
- import re
11
-
12
- import numpy as np
13
- import piexif
14
- import piexif.helper
15
- from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
16
- from fonts.ttf import Roboto
17
- import string
18
- import json
19
- import hashlib
20
-
21
- from modules import sd_samplers, shared, script_callbacks, errors
22
- from modules.shared import opts, cmd_opts
23
-
24
- LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
25
-
26
-
27
- def image_grid(imgs, batch_size=1, rows=None):
28
- if rows is None:
29
- if opts.n_rows > 0:
30
- rows = opts.n_rows
31
- elif opts.n_rows == 0:
32
- rows = batch_size
33
- elif opts.grid_prevent_empty_spots:
34
- rows = math.floor(math.sqrt(len(imgs)))
35
- while len(imgs) % rows != 0:
36
- rows -= 1
37
- else:
38
- rows = math.sqrt(len(imgs))
39
- rows = round(rows)
40
- if rows > len(imgs):
41
- rows = len(imgs)
42
-
43
- cols = math.ceil(len(imgs) / rows)
44
-
45
- params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
46
- script_callbacks.image_grid_callback(params)
47
-
48
- w, h = imgs[0].size
49
- grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
50
-
51
- for i, img in enumerate(params.imgs):
52
- grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
53
-
54
- return grid
55
-
56
-
57
- Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
58
-
59
-
60
- def split_grid(image, tile_w=512, tile_h=512, overlap=64):
61
- w = image.width
62
- h = image.height
63
-
64
- non_overlap_width = tile_w - overlap
65
- non_overlap_height = tile_h - overlap
66
-
67
- cols = math.ceil((w - overlap) / non_overlap_width)
68
- rows = math.ceil((h - overlap) / non_overlap_height)
69
-
70
- dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
71
- dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
72
-
73
- grid = Grid([], tile_w, tile_h, w, h, overlap)
74
- for row in range(rows):
75
- row_images = []
76
-
77
- y = int(row * dy)
78
-
79
- if y + tile_h >= h:
80
- y = h - tile_h
81
-
82
- for col in range(cols):
83
- x = int(col * dx)
84
-
85
- if x + tile_w >= w:
86
- x = w - tile_w
87
-
88
- tile = image.crop((x, y, x + tile_w, y + tile_h))
89
-
90
- row_images.append([x, tile_w, tile])
91
-
92
- grid.tiles.append([y, tile_h, row_images])
93
-
94
- return grid
95
-
96
-
97
- def combine_grid(grid):
98
- def make_mask_image(r):
99
- r = r * 255 / grid.overlap
100
- r = r.astype(np.uint8)
101
- return Image.fromarray(r, 'L')
102
-
103
- mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
104
- mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
105
-
106
- combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
107
- for y, h, row in grid.tiles:
108
- combined_row = Image.new("RGB", (grid.image_w, h))
109
- for x, w, tile in row:
110
- if x == 0:
111
- combined_row.paste(tile, (0, 0))
112
- continue
113
-
114
- combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
115
- combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
116
-
117
- if y == 0:
118
- combined_image.paste(combined_row, (0, 0))
119
- continue
120
-
121
- combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
122
- combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
123
-
124
- return combined_image
125
-
126
-
127
- class GridAnnotation:
128
- def __init__(self, text='', is_active=True):
129
- self.text = text
130
- self.is_active = is_active
131
- self.size = None
132
-
133
-
134
- def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
135
- def wrap(drawing, text, font, line_length):
136
- lines = ['']
137
- for word in text.split():
138
- line = f'{lines[-1]} {word}'.strip()
139
- if drawing.textlength(line, font=font) <= line_length:
140
- lines[-1] = line
141
- else:
142
- lines.append(word)
143
- return lines
144
-
145
- def get_font(fontsize):
146
- try:
147
- return ImageFont.truetype(opts.font or Roboto, fontsize)
148
- except Exception:
149
- return ImageFont.truetype(Roboto, fontsize)
150
-
151
- def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
152
- for i, line in enumerate(lines):
153
- fnt = initial_fnt
154
- fontsize = initial_fontsize
155
- while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
156
- fontsize -= 1
157
- fnt = get_font(fontsize)
158
- drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
159
-
160
- if not line.is_active:
161
- drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
162
-
163
- draw_y += line.size[1] + line_spacing
164
-
165
- fontsize = (width + height) // 25
166
- line_spacing = fontsize // 2
167
-
168
- fnt = get_font(fontsize)
169
-
170
- color_active = (0, 0, 0)
171
- color_inactive = (153, 153, 153)
172
-
173
- pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
174
-
175
- cols = im.width // width
176
- rows = im.height // height
177
-
178
- assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
179
- assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
180
-
181
- calc_img = Image.new("RGB", (1, 1), "white")
182
- calc_d = ImageDraw.Draw(calc_img)
183
-
184
- for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
185
- items = [] + texts
186
- texts.clear()
187
-
188
- for line in items:
189
- wrapped = wrap(calc_d, line.text, fnt, allowed_width)
190
- texts += [GridAnnotation(x, line.is_active) for x in wrapped]
191
-
192
- for line in texts:
193
- bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
194
- line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
195
- line.allowed_width = allowed_width
196
-
197
- hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
198
- ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
199
-
200
- pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
201
-
202
- result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
203
-
204
- for row in range(rows):
205
- for col in range(cols):
206
- cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
207
- result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
208
-
209
- d = ImageDraw.Draw(result)
210
-
211
- for col in range(cols):
212
- x = pad_left + (width + margin) * col + width / 2
213
- y = pad_top / 2 - hor_text_heights[col] / 2
214
-
215
- draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
216
-
217
- for row in range(rows):
218
- x = pad_left / 2
219
- y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
220
-
221
- draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
222
-
223
- return result
224
-
225
-
226
- def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
227
- prompts = all_prompts[1:]
228
- boundary = math.ceil(len(prompts) / 2)
229
-
230
- prompts_horiz = prompts[:boundary]
231
- prompts_vert = prompts[boundary:]
232
-
233
- hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
234
- ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
235
-
236
- return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
237
-
238
-
239
- def resize_image(resize_mode, im, width, height, upscaler_name=None):
240
- """
241
- Resizes an image with the specified resize_mode, width, and height.
242
-
243
- Args:
244
- resize_mode: The mode to use when resizing the image.
245
- 0: Resize the image to the specified width and height.
246
- 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
247
- 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data from the image.
248
- im: The image to resize.
249
- width: The width to resize the image to.
250
- height: The height to resize the image to.
251
- upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
252
- """
253
-
254
- upscaler_name = upscaler_name or opts.upscaler_for_img2img
255
-
256
- def resize(im, w, h):
257
- if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
258
- return im.resize((w, h), resample=LANCZOS)
259
-
260
- scale = max(w / im.width, h / im.height)
261
-
262
- if scale > 1.0:
263
- upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
264
- assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
265
-
266
- upscaler = upscalers[0]
267
- im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
268
-
269
- if im.width != w or im.height != h:
270
- im = im.resize((w, h), resample=LANCZOS)
271
-
272
- return im
273
-
274
- if resize_mode == 0:
275
- res = resize(im, width, height)
276
-
277
- elif resize_mode == 1:
278
- ratio = width / height
279
- src_ratio = im.width / im.height
280
-
281
- src_w = width if ratio > src_ratio else im.width * height // im.height
282
- src_h = height if ratio <= src_ratio else im.height * width // im.width
283
-
284
- resized = resize(im, src_w, src_h)
285
- res = Image.new("RGB", (width, height))
286
- res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
287
-
288
- else:
289
- ratio = width / height
290
- src_ratio = im.width / im.height
291
-
292
- src_w = width if ratio < src_ratio else im.width * height // im.height
293
- src_h = height if ratio >= src_ratio else im.height * width // im.width
294
-
295
- resized = resize(im, src_w, src_h)
296
- res = Image.new("RGB", (width, height))
297
- res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
298
-
299
- if ratio < src_ratio:
300
- fill_height = height // 2 - src_h // 2
301
- res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
302
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
303
- elif ratio > src_ratio:
304
- fill_width = width // 2 - src_w // 2
305
- res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
306
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
307
-
308
- return res
309
-
310
-
311
- invalid_filename_chars = '<>:"/\\|?*\n'
312
- invalid_filename_prefix = ' '
313
- invalid_filename_postfix = ' .'
314
- re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
315
- re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
316
- re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
317
- max_filename_part_length = 128
318
-
319
-
320
- def sanitize_filename_part(text, replace_spaces=True):
321
- if text is None:
322
- return None
323
-
324
- if replace_spaces:
325
- text = text.replace(' ', '_')
326
-
327
- text = text.translate({ord(x): '_' for x in invalid_filename_chars})
328
- text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
329
- text = text.rstrip(invalid_filename_postfix)
330
- return text
331
-
332
-
333
- class FilenameGenerator:
334
- replacements = {
335
- 'seed': lambda self: self.seed if self.seed is not None else '',
336
- 'steps': lambda self: self.p and self.p.steps,
337
- 'cfg': lambda self: self.p and self.p.cfg_scale,
338
- 'width': lambda self: self.image.width,
339
- 'height': lambda self: self.image.height,
340
- 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
341
- 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
342
- 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
343
- 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
344
- 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
345
- 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
346
- 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
347
- 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
348
- 'prompt': lambda self: sanitize_filename_part(self.prompt),
349
- 'prompt_no_styles': lambda self: self.prompt_no_style(),
350
- 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
351
- 'prompt_words': lambda self: self.prompt_words(),
352
- }
353
- default_time_format = '%Y%m%d%H%M%S'
354
-
355
- def __init__(self, p, seed, prompt, image):
356
- self.p = p
357
- self.seed = seed
358
- self.prompt = prompt
359
- self.image = image
360
-
361
- def prompt_no_style(self):
362
- if self.p is None or self.prompt is None:
363
- return None
364
-
365
- prompt_no_style = self.prompt
366
- for style in shared.prompt_styles.get_style_prompts(self.p.styles):
367
- if len(style) > 0:
368
- for part in style.split("{prompt}"):
369
- prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
370
-
371
- prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
372
-
373
- return sanitize_filename_part(prompt_no_style, replace_spaces=False)
374
-
375
- def prompt_words(self):
376
- words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
377
- if len(words) == 0:
378
- words = ["empty"]
379
- return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
380
-
381
- def datetime(self, *args):
382
- time_datetime = datetime.datetime.now()
383
-
384
- time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
385
- try:
386
- time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
387
- except pytz.exceptions.UnknownTimeZoneError as _:
388
- time_zone = None
389
-
390
- time_zone_time = time_datetime.astimezone(time_zone)
391
- try:
392
- formatted_time = time_zone_time.strftime(time_format)
393
- except (ValueError, TypeError) as _:
394
- formatted_time = time_zone_time.strftime(self.default_time_format)
395
-
396
- return sanitize_filename_part(formatted_time, replace_spaces=False)
397
-
398
- def apply(self, x):
399
- res = ''
400
-
401
- for m in re_pattern.finditer(x):
402
- text, pattern = m.groups()
403
- res += text
404
-
405
- if pattern is None:
406
- continue
407
-
408
- pattern_args = []
409
- while True:
410
- m = re_pattern_arg.match(pattern)
411
- if m is None:
412
- break
413
-
414
- pattern, arg = m.groups()
415
- pattern_args.insert(0, arg)
416
-
417
- fun = self.replacements.get(pattern.lower())
418
- if fun is not None:
419
- try:
420
- replacement = fun(self, *pattern_args)
421
- except Exception:
422
- replacement = None
423
- print(f"Error adding [{pattern}] to filename", file=sys.stderr)
424
- print(traceback.format_exc(), file=sys.stderr)
425
-
426
- if replacement is not None:
427
- res += str(replacement)
428
- continue
429
-
430
- res += f'[{pattern}]'
431
-
432
- return res
433
-
434
-
435
- def get_next_sequence_number(path, basename):
436
- """
437
- Determines and returns the next sequence number to use when saving an image in the specified directory.
438
-
439
- The sequence starts at 0.
440
- """
441
- result = -1
442
- if basename != '':
443
- basename = basename + "-"
444
-
445
- prefix_length = len(basename)
446
- for p in os.listdir(path):
447
- if p.startswith(basename):
448
- l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
449
- try:
450
- result = max(int(l[0]), result)
451
- except ValueError:
452
- pass
453
-
454
- return result + 1
455
-
456
-
457
- def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
458
- """Save an image.
459
-
460
- Args:
461
- image (`PIL.Image`):
462
- The image to be saved.
463
- path (`str`):
464
- The directory to save the image in. Note: the option `save_to_dirs` will cause the image to be saved into a subdirectory.
465
- basename (`str`):
466
- The base filename which will be applied to `filename pattern`.
467
- seed, prompt, short_filename,
468
- extension (`str`):
469
- Image file extension, default is `png`.
470
- pnginfo_section_name (`str`):
471
- Specify the name of the section which `info` will be saved in.
472
- info (`str` or `PngImagePlugin.iTXt`):
473
- PNG info chunks.
474
- existing_info (`dict`):
475
- Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
476
- no_prompt:
477
- TODO I don't know its meaning.
478
- p (`StableDiffusionProcessing`)
479
- forced_filename (`str`):
480
- If specified, `basename` and filename pattern will be ignored.
481
- save_to_dirs (bool):
482
- If true, the image will be saved into a subdirectory of `path`.
483
-
484
- Returns: (fullfn, txt_fullfn)
485
- fullfn (`str`):
486
- The full path of the saved image.
487
- txt_fullfn (`str` or None):
488
- If a text file is saved for this image, this will be its full path. Otherwise None.
489
- """
490
- namegen = FilenameGenerator(p, seed, prompt, image)
491
-
492
- if save_to_dirs is None:
493
- save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
494
-
495
- if save_to_dirs:
496
- dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
497
- path = os.path.join(path, dirname)
498
-
499
- os.makedirs(path, exist_ok=True)
500
-
501
- if forced_filename is None:
502
- if short_filename or seed is None:
503
- file_decoration = ""
504
- elif opts.save_to_dirs:
505
- file_decoration = opts.samples_filename_pattern or "[seed]"
506
- else:
507
- file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
508
-
509
- add_number = opts.save_images_add_number or file_decoration == ''
510
-
511
- if file_decoration != "" and add_number:
512
- file_decoration = "-" + file_decoration
513
-
514
- file_decoration = namegen.apply(file_decoration) + suffix
515
-
516
- if add_number:
517
- basecount = get_next_sequence_number(path, basename)
518
- fullfn = None
519
- for i in range(500):
520
- fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
521
- fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
522
- if not os.path.exists(fullfn):
523
- break
524
- else:
525
- fullfn = os.path.join(path, f"{file_decoration}.{extension}")
526
- else:
527
- fullfn = os.path.join(path, f"{forced_filename}.{extension}")
528
-
529
- pnginfo = existing_info or {}
530
- if info is not None:
531
- pnginfo[pnginfo_section_name] = info
532
-
533
- params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
534
- script_callbacks.before_image_saved_callback(params)
535
-
536
- image = params.image
537
- fullfn = params.filename
538
- info = params.pnginfo.get(pnginfo_section_name, None)
539
-
540
- def _atomically_save_image(image_to_save, filename_without_extension, extension):
541
- # save the image with a .tmp extension to avoid a race condition when another process detects a new image in the directory
542
- temp_file_path = filename_without_extension + ".tmp"
543
- image_format = Image.registered_extensions()[extension]
544
-
545
- if extension.lower() == '.png':
546
- pnginfo_data = PngImagePlugin.PngInfo()
547
- if opts.enable_pnginfo:
548
- for k, v in params.pnginfo.items():
549
- pnginfo_data.add_text(k, str(v))
550
-
551
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
552
-
553
- elif extension.lower() in (".jpg", ".jpeg", ".webp"):
554
- if image_to_save.mode == 'RGBA':
555
- image_to_save = image_to_save.convert("RGB")
556
- elif image_to_save.mode == 'I;16':
557
- image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
558
-
559
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
560
-
561
- if opts.enable_pnginfo and info is not None:
562
- exif_bytes = piexif.dump({
563
- "Exif": {
564
- piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
565
- },
566
- })
567
-
568
- piexif.insert(exif_bytes, temp_file_path)
569
- else:
570
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
571
-
572
- # atomically rename the file with the correct extension
573
- os.replace(temp_file_path, filename_without_extension + extension)
574
-
575
- fullfn_without_extension, extension = os.path.splitext(params.filename)
576
- _atomically_save_image(image, fullfn_without_extension, extension)
577
-
578
- image.already_saved_as = fullfn
579
-
580
- oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
581
- if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
582
- ratio = image.width / image.height
583
-
584
- if oversize and ratio > 1:
585
- image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
586
- elif oversize:
587
- image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
588
-
589
- try:
590
- _atomically_save_image(image, fullfn_without_extension, ".jpg")
591
- except Exception as e:
592
- errors.display(e, "saving image as downscaled JPG")
593
-
594
- if opts.save_txt and info is not None:
595
- txt_fullfn = f"{fullfn_without_extension}.txt"
596
- with open(txt_fullfn, "w", encoding="utf8") as file:
597
- file.write(info + "\n")
598
- else:
599
- txt_fullfn = None
600
-
601
- script_callbacks.image_saved_callback(params)
602
-
603
- return fullfn, txt_fullfn
604
-
605
-
606
- def read_info_from_image(image):
607
- items = image.info or {}
608
-
609
- geninfo = items.pop('parameters', None)
610
-
611
- if "exif" in items:
612
- exif = piexif.load(items["exif"])
613
- exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
614
- try:
615
- exif_comment = piexif.helper.UserComment.load(exif_comment)
616
- except ValueError:
617
- exif_comment = exif_comment.decode('utf8', errors="ignore")
618
-
619
- if exif_comment:
620
- items['exif comment'] = exif_comment
621
- geninfo = exif_comment
622
-
623
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
624
- 'loop', 'background', 'timestamp', 'duration']:
625
- items.pop(field, None)
626
-
627
- if items.get("Software", None) == "NovelAI":
628
- try:
629
- json_info = json.loads(items["Comment"])
630
- sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
631
-
632
- geninfo = f"""{items["Description"]}
633
- Negative prompt: {json_info["uc"]}
634
- Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
635
- except Exception:
636
- print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
637
- print(traceback.format_exc(), file=sys.stderr)
638
-
639
- return geninfo, items
640
-
641
-
642
- def image_data(data):
643
- try:
644
- image = Image.open(io.BytesIO(data))
645
- textinfo, _ = read_info_from_image(image)
646
- return textinfo, None
647
- except Exception:
648
- pass
649
-
650
- try:
651
- text = data.decode('utf8')
652
- assert len(text) < 10000
653
- return text, None
654
-
655
- except Exception:
656
- pass
657
-
658
- return '', None
659
-
660
-
661
- def flatten(img, bgcolor):
662
- """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
663
-
664
- if img.mode == "RGBA":
665
- background = Image.new('RGBA', img.size, bgcolor)
666
- background.paste(img, mask=img)
667
- img = background
668
-
669
- return img.convert('RGB')
 
1
+ import datetime
2
+ import sys
3
+ import traceback
4
+
5
+ import pytz
6
+ import io
7
+ import math
8
+ import os
9
+ from collections import namedtuple
10
+ import re
11
+
12
+ import numpy as np
13
+ import piexif
14
+ import piexif.helper
15
+ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
16
+ from fonts.ttf import Roboto
17
+ import string
18
+ import json
19
+ import hashlib
20
+
21
+ from modules import sd_samplers, shared, script_callbacks, errors
22
+ from modules.shared import opts, cmd_opts
23
+
24
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
25
+
26
+
27
+ def image_grid(imgs, batch_size=1, rows=None):
28
+ if rows is None:
29
+ if opts.n_rows > 0:
30
+ rows = opts.n_rows
31
+ elif opts.n_rows == 0:
32
+ rows = batch_size
33
+ elif opts.grid_prevent_empty_spots:
34
+ rows = math.floor(math.sqrt(len(imgs)))
35
+ while len(imgs) % rows != 0:
36
+ rows -= 1
37
+ else:
38
+ rows = math.sqrt(len(imgs))
39
+ rows = round(rows)
40
+ if rows > len(imgs):
41
+ rows = len(imgs)
42
+
43
+ cols = math.ceil(len(imgs) / rows)
44
+
45
+ params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
46
+ script_callbacks.image_grid_callback(params)
47
+
48
+ w, h = imgs[0].size
49
+ grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
50
+
51
+ for i, img in enumerate(params.imgs):
52
+ grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
53
+
54
+ return grid
55
+
56
+
57
+ Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
58
+
59
+
60
+ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
61
+ w = image.width
62
+ h = image.height
63
+
64
+ non_overlap_width = tile_w - overlap
65
+ non_overlap_height = tile_h - overlap
66
+
67
+ cols = math.ceil((w - overlap) / non_overlap_width)
68
+ rows = math.ceil((h - overlap) / non_overlap_height)
69
+
70
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
71
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
72
+
73
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
74
+ for row in range(rows):
75
+ row_images = []
76
+
77
+ y = int(row * dy)
78
+
79
+ if y + tile_h >= h:
80
+ y = h - tile_h
81
+
82
+ for col in range(cols):
83
+ x = int(col * dx)
84
+
85
+ if x + tile_w >= w:
86
+ x = w - tile_w
87
+
88
+ tile = image.crop((x, y, x + tile_w, y + tile_h))
89
+
90
+ row_images.append([x, tile_w, tile])
91
+
92
+ grid.tiles.append([y, tile_h, row_images])
93
+
94
+ return grid
95
+
96
+
97
+ def combine_grid(grid):
98
+ def make_mask_image(r):
99
+ r = r * 255 / grid.overlap
100
+ r = r.astype(np.uint8)
101
+ return Image.fromarray(r, 'L')
102
+
103
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
104
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
105
+
106
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
107
+ for y, h, row in grid.tiles:
108
+ combined_row = Image.new("RGB", (grid.image_w, h))
109
+ for x, w, tile in row:
110
+ if x == 0:
111
+ combined_row.paste(tile, (0, 0))
112
+ continue
113
+
114
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
115
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
116
+
117
+ if y == 0:
118
+ combined_image.paste(combined_row, (0, 0))
119
+ continue
120
+
121
+ combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
122
+ combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
123
+
124
+ return combined_image
125
+
126
+
127
+ class GridAnnotation:
128
+ def __init__(self, text='', is_active=True):
129
+ self.text = text
130
+ self.is_active = is_active
131
+ self.size = None
132
+
133
+
134
+ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
135
+ def wrap(drawing, text, font, line_length):
136
+ lines = ['']
137
+ for word in text.split():
138
+ line = f'{lines[-1]} {word}'.strip()
139
+ if drawing.textlength(line, font=font) <= line_length:
140
+ lines[-1] = line
141
+ else:
142
+ lines.append(word)
143
+ return lines
144
+
145
+ def get_font(fontsize):
146
+ try:
147
+ return ImageFont.truetype(opts.font or Roboto, fontsize)
148
+ except Exception:
149
+ return ImageFont.truetype(Roboto, fontsize)
150
+
151
+ def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
152
+ for i, line in enumerate(lines):
153
+ fnt = initial_fnt
154
+ fontsize = initial_fontsize
155
+ while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
156
+ fontsize -= 1
157
+ fnt = get_font(fontsize)
158
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
159
+
160
+ if not line.is_active:
161
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
162
+
163
+ draw_y += line.size[1] + line_spacing
164
+
165
+ fontsize = (width + height) // 25
166
+ line_spacing = fontsize // 2
167
+
168
+ fnt = get_font(fontsize)
169
+
170
+ color_active = (0, 0, 0)
171
+ color_inactive = (153, 153, 153)
172
+
173
+ pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
174
+
175
+ cols = im.width // width
176
+ rows = im.height // height
177
+
178
+ assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
179
+ assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
180
+
181
+ calc_img = Image.new("RGB", (1, 1), "white")
182
+ calc_d = ImageDraw.Draw(calc_img)
183
+
184
+ for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
185
+ items = [] + texts
186
+ texts.clear()
187
+
188
+ for line in items:
189
+ wrapped = wrap(calc_d, line.text, fnt, allowed_width)
190
+ texts += [GridAnnotation(x, line.is_active) for x in wrapped]
191
+
192
+ for line in texts:
193
+ bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
194
+ line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
195
+ line.allowed_width = allowed_width
196
+
197
+ hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
198
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
199
+
200
+ pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
201
+
202
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
203
+
204
+ for row in range(rows):
205
+ for col in range(cols):
206
+ cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
207
+ result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
208
+
209
+ d = ImageDraw.Draw(result)
210
+
211
+ for col in range(cols):
212
+ x = pad_left + (width + margin) * col + width / 2
213
+ y = pad_top / 2 - hor_text_heights[col] / 2
214
+
215
+ draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
216
+
217
+ for row in range(rows):
218
+ x = pad_left / 2
219
+ y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
220
+
221
+ draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
222
+
223
+ return result
224
+
225
+
226
+ def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
227
+ prompts = all_prompts[1:]
228
+ boundary = math.ceil(len(prompts) / 2)
229
+
230
+ prompts_horiz = prompts[:boundary]
231
+ prompts_vert = prompts[boundary:]
232
+
233
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
234
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
235
+
236
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
237
+
238
+
239
+ def resize_image(resize_mode, im, width, height, upscaler_name=None):
240
+ """
241
+ Resizes an image with the specified resize_mode, width, and height.
242
+
243
+ Args:
244
+ resize_mode: The mode to use when resizing the image.
245
+ 0: Resize the image to the specified width and height.
246
+ 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
247
+ 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data from the image.
248
+ im: The image to resize.
249
+ width: The width to resize the image to.
250
+ height: The height to resize the image to.
251
+ upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
252
+ """
253
+
254
+ upscaler_name = upscaler_name or opts.upscaler_for_img2img
255
+
256
+ def resize(im, w, h):
257
+ if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
258
+ return im.resize((w, h), resample=LANCZOS)
259
+
260
+ scale = max(w / im.width, h / im.height)
261
+
262
+ if scale > 1.0:
263
+ upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
264
+ assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
265
+
266
+ upscaler = upscalers[0]
267
+ im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
268
+
269
+ if im.width != w or im.height != h:
270
+ im = im.resize((w, h), resample=LANCZOS)
271
+
272
+ return im
273
+
274
+ if resize_mode == 0:
275
+ res = resize(im, width, height)
276
+
277
+ elif resize_mode == 1:
278
+ ratio = width / height
279
+ src_ratio = im.width / im.height
280
+
281
+ src_w = width if ratio > src_ratio else im.width * height // im.height
282
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
283
+
284
+ resized = resize(im, src_w, src_h)
285
+ res = Image.new("RGB", (width, height))
286
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
287
+
288
+ else:
289
+ ratio = width / height
290
+ src_ratio = im.width / im.height
291
+
292
+ src_w = width if ratio < src_ratio else im.width * height // im.height
293
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
294
+
295
+ resized = resize(im, src_w, src_h)
296
+ res = Image.new("RGB", (width, height))
297
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
298
+
299
+ if ratio < src_ratio:
300
+ fill_height = height // 2 - src_h // 2
301
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
302
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
303
+ elif ratio > src_ratio:
304
+ fill_width = width // 2 - src_w // 2
305
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
306
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
307
+
308
+ return res
309
+
310
+
311
+ invalid_filename_chars = '<>:"/\\|?*\n'
312
+ invalid_filename_prefix = ' '
313
+ invalid_filename_postfix = ' .'
314
+ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
315
+ re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
316
+ re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
317
+ max_filename_part_length = 128
318
+
319
+
320
+ def sanitize_filename_part(text, replace_spaces=True):
321
+ if text is None:
322
+ return None
323
+
324
+ if replace_spaces:
325
+ text = text.replace(' ', '_')
326
+
327
+ text = text.translate({ord(x): '_' for x in invalid_filename_chars})
328
+ text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
329
+ text = text.rstrip(invalid_filename_postfix)
330
+ return text
331
+
332
+
333
+ class FilenameGenerator:
334
+ replacements = {
335
+ 'seed': lambda self: self.seed if self.seed is not None else '',
336
+ 'steps': lambda self: self.p and self.p.steps,
337
+ 'cfg': lambda self: self.p and self.p.cfg_scale,
338
+ 'width': lambda self: self.image.width,
339
+ 'height': lambda self: self.image.height,
340
+ 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
341
+ 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
342
+ 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
343
+ 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
344
+ 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
345
+ 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
346
+ 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
347
+ 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
348
+ 'prompt': lambda self: sanitize_filename_part(self.prompt),
349
+ 'prompt_no_styles': lambda self: self.prompt_no_style(),
350
+ 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
351
+ 'prompt_words': lambda self: self.prompt_words(),
352
+ }
353
+ default_time_format = '%Y%m%d%H%M%S'
354
+
355
+ def __init__(self, p, seed, prompt, image):
356
+ self.p = p
357
+ self.seed = seed
358
+ self.prompt = prompt
359
+ self.image = image
360
+
361
+ def prompt_no_style(self):
362
+ if self.p is None or self.prompt is None:
363
+ return None
364
+
365
+ prompt_no_style = self.prompt
366
+ for style in shared.prompt_styles.get_style_prompts(self.p.styles):
367
+ if len(style) > 0:
368
+ for part in style.split("{prompt}"):
369
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
370
+
371
+ prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
372
+
373
+ return sanitize_filename_part(prompt_no_style, replace_spaces=False)
374
+
375
+ def prompt_words(self):
376
+ words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
377
+ if len(words) == 0:
378
+ words = ["empty"]
379
+ return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
380
+
381
+ def datetime(self, *args):
382
+ time_datetime = datetime.datetime.now()
383
+
384
+ time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
385
+ try:
386
+ time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
387
+ except pytz.exceptions.UnknownTimeZoneError as _:
388
+ time_zone = None
389
+
390
+ time_zone_time = time_datetime.astimezone(time_zone)
391
+ try:
392
+ formatted_time = time_zone_time.strftime(time_format)
393
+ except (ValueError, TypeError) as _:
394
+ formatted_time = time_zone_time.strftime(self.default_time_format)
395
+
396
+ return sanitize_filename_part(formatted_time, replace_spaces=False)
397
+
398
+ def apply(self, x):
399
+ res = ''
400
+
401
+ for m in re_pattern.finditer(x):
402
+ text, pattern = m.groups()
403
+ res += text
404
+
405
+ if pattern is None:
406
+ continue
407
+
408
+ pattern_args = []
409
+ while True:
410
+ m = re_pattern_arg.match(pattern)
411
+ if m is None:
412
+ break
413
+
414
+ pattern, arg = m.groups()
415
+ pattern_args.insert(0, arg)
416
+
417
+ fun = self.replacements.get(pattern.lower())
418
+ if fun is not None:
419
+ try:
420
+ replacement = fun(self, *pattern_args)
421
+ except Exception:
422
+ replacement = None
423
+ print(f"Error adding [{pattern}] to filename", file=sys.stderr)
424
+ print(traceback.format_exc(), file=sys.stderr)
425
+
426
+ if replacement is not None:
427
+ res += str(replacement)
428
+ continue
429
+
430
+ res += f'[{pattern}]'
431
+
432
+ return res
433
+
434
+
435
+ def get_next_sequence_number(path, basename):
436
+ """
437
+ Determines and returns the next sequence number to use when saving an image in the specified directory.
438
+
439
+ The sequence starts at 0.
440
+ """
441
+ result = -1
442
+ if basename != '':
443
+ basename = basename + "-"
444
+
445
+ prefix_length = len(basename)
446
+ for p in os.listdir(path):
447
+ if p.startswith(basename):
448
+ l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
449
+ try:
450
+ result = max(int(l[0]), result)
451
+ except ValueError:
452
+ pass
453
+
454
+ return result + 1
455
+
456
+
457
+ def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
458
+ """Save an image.
459
+
460
+ Args:
461
+ image (`PIL.Image`):
462
+ The image to be saved.
463
+ path (`str`):
464
+ The directory to save the image. Note that the option `save_to_dirs` will cause the image to be saved into a subdirectory.
465
+ basename (`str`):
466
+ The base filename which will be applied to `filename pattern`.
467
+ seed, prompt, short_filename,
468
+ extension (`str`):
469
+ Image file extension, default is `png`.
470
+ pnginfo_section_name (`str`):
471
+ Specify the name of the section which `info` will be saved in.
472
+ info (`str` or `PngImagePlugin.iTXt`):
473
+ PNG info chunks.
474
+ existing_info (`dict`):
475
+ Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
476
+ no_prompt:
477
+ TODO I don't know its meaning.
478
+ p (`StableDiffusionProcessing`)
479
+ forced_filename (`str`):
480
+ If specified, `basename` and filename pattern will be ignored.
481
+ save_to_dirs (bool):
482
+ If true, the image will be saved into a subdirectory of `path`.
483
+
484
+ Returns: (fullfn, txt_fullfn)
485
+ fullfn (`str`):
486
+ The full path of the saved image.
487
+ txt_fullfn (`str` or None):
488
+ If a text file is saved for this image, this will be its full path. Otherwise None.
489
+ """
490
+ namegen = FilenameGenerator(p, seed, prompt, image)
491
+
492
+ if save_to_dirs is None:
493
+ save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
494
+
495
+ if save_to_dirs:
496
+ dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
497
+ path = os.path.join(path, dirname)
498
+
499
+ os.makedirs(path, exist_ok=True)
500
+
501
+ if forced_filename is None:
502
+ if short_filename or seed is None:
503
+ file_decoration = ""
504
+ elif opts.save_to_dirs:
505
+ file_decoration = opts.samples_filename_pattern or "[seed]"
506
+ else:
507
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
508
+
509
+ add_number = opts.save_images_add_number or file_decoration == ''
510
+
511
+ if file_decoration != "" and add_number:
512
+ file_decoration = "-" + file_decoration
513
+
514
+ file_decoration = namegen.apply(file_decoration) + suffix
515
+
516
+ if add_number:
517
+ basecount = get_next_sequence_number(path, basename)
518
+ fullfn = None
519
+ for i in range(500):
520
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
521
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
522
+ if not os.path.exists(fullfn):
523
+ break
524
+ else:
525
+ fullfn = os.path.join(path, f"{file_decoration}.{extension}")
526
+ else:
527
+ fullfn = os.path.join(path, f"{forced_filename}.{extension}")
528
+
529
+ pnginfo = existing_info or {}
530
+ if info is not None:
531
+ pnginfo[pnginfo_section_name] = info
532
+
533
+ params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
534
+ script_callbacks.before_image_saved_callback(params)
535
+
536
+ image = params.image
537
+ fullfn = params.filename
538
+ info = params.pnginfo.get(pnginfo_section_name, None)
539
+
540
+ def _atomically_save_image(image_to_save, filename_without_extension, extension):
541
+ # save image with .tmp extension to avoid race condition when another process detects new image in the directory
542
+ temp_file_path = filename_without_extension + ".tmp"
543
+ image_format = Image.registered_extensions()[extension]
544
+
545
+ if extension.lower() == '.png':
546
+ pnginfo_data = PngImagePlugin.PngInfo()
547
+ if opts.enable_pnginfo:
548
+ for k, v in params.pnginfo.items():
549
+ pnginfo_data.add_text(k, str(v))
550
+
551
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
552
+
553
+ elif extension.lower() in (".jpg", ".jpeg", ".webp"):
554
+ if image_to_save.mode == 'RGBA':
555
+ image_to_save = image_to_save.convert("RGB")
556
+ elif image_to_save.mode == 'I;16':
557
+ image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
558
+
559
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
560
+
561
+ if opts.enable_pnginfo and info is not None:
562
+ exif_bytes = piexif.dump({
563
+ "Exif": {
564
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
565
+ },
566
+ })
567
+
568
+ piexif.insert(exif_bytes, temp_file_path)
569
+ else:
570
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
571
+
572
+ # atomically rename the file with correct extension
573
+ os.replace(temp_file_path, filename_without_extension + extension)
574
+
575
+ fullfn_without_extension, extension = os.path.splitext(params.filename)
576
+ _atomically_save_image(image, fullfn_without_extension, extension)
577
+
578
+ image.already_saved_as = fullfn
579
+
580
+ oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
581
+ if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
582
+ ratio = image.width / image.height
583
+
584
+ if oversize and ratio > 1:
585
+ image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
586
+ elif oversize:
587
+ image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
588
+
589
+ try:
590
+ _atomically_save_image(image, fullfn_without_extension, ".jpg")
591
+ except Exception as e:
592
+ errors.display(e, "saving image as downscaled JPG")
593
+
594
+ if opts.save_txt and info is not None:
595
+ txt_fullfn = f"{fullfn_without_extension}.txt"
596
+ with open(txt_fullfn, "w", encoding="utf8") as file:
597
+ file.write(info + "\n")
598
+ else:
599
+ txt_fullfn = None
600
+
601
+ script_callbacks.image_saved_callback(params)
602
+
603
+ return fullfn, txt_fullfn
604
+
605
+
606
+ def read_info_from_image(image):
607
+ items = image.info or {}
608
+
609
+ geninfo = items.pop('parameters', None)
610
+
611
+ if "exif" in items:
612
+ exif = piexif.load(items["exif"])
613
+ exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
614
+ try:
615
+ exif_comment = piexif.helper.UserComment.load(exif_comment)
616
+ except ValueError:
617
+ exif_comment = exif_comment.decode('utf8', errors="ignore")
618
+
619
+ if exif_comment:
620
+ items['exif comment'] = exif_comment
621
+ geninfo = exif_comment
622
+
623
+ for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
624
+ 'loop', 'background', 'timestamp', 'duration']:
625
+ items.pop(field, None)
626
+
627
+ if items.get("Software", None) == "NovelAI":
628
+ try:
629
+ json_info = json.loads(items["Comment"])
630
+ sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
631
+
632
+ geninfo = f"""{items["Description"]}
633
+ Negative prompt: {json_info["uc"]}
634
+ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
635
+ except Exception:
636
+ print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
637
+ print(traceback.format_exc(), file=sys.stderr)
638
+
639
+ return geninfo, items
640
+
641
+
642
+ def image_data(data):
643
+ try:
644
+ image = Image.open(io.BytesIO(data))
645
+ textinfo, _ = read_info_from_image(image)
646
+ return textinfo, None
647
+ except Exception:
648
+ pass
649
+
650
+ try:
651
+ text = data.decode('utf8')
652
+ assert len(text) < 10000
653
+ return text, None
654
+
655
+ except Exception:
656
+ pass
657
+
658
+ return '', None
659
+
660
+
661
+ def flatten(img, bgcolor):
662
+ """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
663
+
664
+ if img.mode == "RGBA":
665
+ background = Image.new('RGBA', img.size, bgcolor)
666
+ background.paste(img, mask=img)
667
+ img = background
668
+
669
+ return img.convert('RGB')
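
For reference, below is a minimal, self-contained sketch of the bracket-pattern expansion idea used by FilenameGenerator.apply above. The regex and the replacement table are simplified stand-ins, not the module's own re_pattern / replacements, and the seed and prompt values are hypothetical.

import datetime
import re

# simplified stand-ins for FilenameGenerator.replacements; values are hypothetical
replacements = {
    "seed": lambda: "12345",
    "date": lambda: datetime.datetime.now().strftime("%Y-%m-%d"),
    "prompt_words": lambda: "a_cat_on_a_mat",
}

def expand(pattern: str) -> str:
    # replace every [name] with the matching callable's value; unknown names are kept verbatim,
    # just as apply() leaves unrecognized patterns as [pattern]
    def repl(m: re.Match) -> str:
        fun = replacements.get(m.group(1).lower())
        return str(fun()) if fun is not None else m.group(0)
    return re.sub(r"\[([^\[\]]+)\]", repl, pattern)

print(expand("[date]-[seed]-[prompt_words]"))  # e.g. 2024-05-01-12345-a_cat_on_a_mat
print(expand("[unknown]"))                     # left as [unknown]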
sd/stable-diffusion-webui/modules/img2img.py CHANGED
@@ -1,184 +1,184 @@
1
- import math
2
- import os
3
- import sys
4
- import traceback
5
-
6
- import numpy as np
7
- from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
8
-
9
- from modules import devices, sd_samplers
10
- from modules.generation_parameters_copypaste import create_override_settings_dict
11
- from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
12
- from modules.shared import opts, state
13
- import modules.shared as shared
14
- import modules.processing as processing
15
- from modules.ui import plaintext_to_html
16
- import modules.images as images
17
- import modules.scripts
18
-
19
-
20
- def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
21
- processing.fix_seed(p)
22
-
23
- images = shared.listfiles(input_dir)
24
-
25
- is_inpaint_batch = False
26
- if inpaint_mask_dir:
27
- inpaint_masks = shared.listfiles(inpaint_mask_dir)
28
- is_inpaint_batch = len(inpaint_masks) > 0
29
- if is_inpaint_batch:
30
- print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
31
-
32
- print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
33
-
34
- save_normally = output_dir == ''
35
-
36
- p.do_not_save_grid = True
37
- p.do_not_save_samples = not save_normally
38
-
39
- state.job_count = len(images) * p.n_iter
40
-
41
- for i, image in enumerate(images):
42
- state.job = f"{i+1} out of {len(images)}"
43
- if state.skipped:
44
- state.skipped = False
45
-
46
- if state.interrupted:
47
- break
48
-
49
- img = Image.open(image)
50
- # Use the EXIF orientation of photos taken by smartphones.
51
- img = ImageOps.exif_transpose(img)
52
- p.init_images = [img] * p.batch_size
53
-
54
- if is_inpaint_batch:
55
- # try to find corresponding mask for an image using simple filename matching
56
- mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
57
- # if not found use first one ("same mask for all images" use-case)
58
- if not mask_image_path in inpaint_masks:
59
- mask_image_path = inpaint_masks[0]
60
- mask_image = Image.open(mask_image_path)
61
- p.image_mask = mask_image
62
-
63
- proc = modules.scripts.scripts_img2img.run(p, *args)
64
- if proc is None:
65
- proc = process_images(p)
66
-
67
- for n, processed_image in enumerate(proc.images):
68
- filename = os.path.basename(image)
69
-
70
- if n > 0:
71
- left, right = os.path.splitext(filename)
72
- filename = f"{left}-{n}{right}"
73
-
74
- if not save_normally:
75
- os.makedirs(output_dir, exist_ok=True)
76
- if processed_image.mode == 'RGBA':
77
- processed_image = processed_image.convert("RGB")
78
- processed_image.save(os.path.join(output_dir, filename))
79
-
80
-
81
- def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
82
- override_settings = create_override_settings_dict(override_settings_texts)
83
-
84
- is_batch = mode == 5
85
-
86
- if mode == 0: # img2img
87
- image = init_img.convert("RGB")
88
- mask = None
89
- elif mode == 1: # img2img sketch
90
- image = sketch.convert("RGB")
91
- mask = None
92
- elif mode == 2: # inpaint
93
- image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
94
- alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
95
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
96
- image = image.convert("RGB")
97
- elif mode == 3: # inpaint sketch
98
- image = inpaint_color_sketch
99
- orig = inpaint_color_sketch_orig or inpaint_color_sketch
100
- pred = np.any(np.array(image) != np.array(orig), axis=-1)
101
- mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
102
- mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
103
- blur = ImageFilter.GaussianBlur(mask_blur)
104
- image = Image.composite(image.filter(blur), orig, mask.filter(blur))
105
- image = image.convert("RGB")
106
- elif mode == 4: # inpaint upload mask
107
- image = init_img_inpaint
108
- mask = init_mask_inpaint
109
- else:
110
- image = None
111
- mask = None
112
-
113
- # Use the EXIF orientation of photos taken by smartphones.
114
- if image is not None:
115
- image = ImageOps.exif_transpose(image)
116
-
117
- assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
118
-
119
- p = StableDiffusionProcessingImg2Img(
120
- sd_model=shared.sd_model,
121
- outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
122
- outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
123
- prompt=prompt,
124
- negative_prompt=negative_prompt,
125
- styles=prompt_styles,
126
- seed=seed,
127
- subseed=subseed,
128
- subseed_strength=subseed_strength,
129
- seed_resize_from_h=seed_resize_from_h,
130
- seed_resize_from_w=seed_resize_from_w,
131
- seed_enable_extras=seed_enable_extras,
132
- sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
133
- batch_size=batch_size,
134
- n_iter=n_iter,
135
- steps=steps,
136
- cfg_scale=cfg_scale,
137
- width=width,
138
- height=height,
139
- restore_faces=restore_faces,
140
- tiling=tiling,
141
- init_images=[image],
142
- mask=mask,
143
- mask_blur=mask_blur,
144
- inpainting_fill=inpainting_fill,
145
- resize_mode=resize_mode,
146
- denoising_strength=denoising_strength,
147
- image_cfg_scale=image_cfg_scale,
148
- inpaint_full_res=inpaint_full_res,
149
- inpaint_full_res_padding=inpaint_full_res_padding,
150
- inpainting_mask_invert=inpainting_mask_invert,
151
- override_settings=override_settings,
152
- )
153
-
154
- p.scripts = modules.scripts.scripts_txt2img
155
- p.script_args = args
156
-
157
- if shared.cmd_opts.enable_console_prompts:
158
- print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
159
-
160
- p.extra_generation_params["Mask blur"] = mask_blur
161
-
162
- if is_batch:
163
- assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
164
-
165
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
166
-
167
- processed = Processed(p, [], p.seed, "")
168
- else:
169
- processed = modules.scripts.scripts_img2img.run(p, *args)
170
- if processed is None:
171
- processed = process_images(p)
172
-
173
- p.close()
174
-
175
- shared.total_tqdm.clear()
176
-
177
- generation_info_js = processed.js()
178
- if opts.samples_log_stdout:
179
- print(generation_info_js)
180
-
181
- if opts.do_not_show_images:
182
- processed.images = []
183
-
184
- return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
 
1
+ import math
2
+ import os
3
+ import sys
4
+ import traceback
5
+
6
+ import numpy as np
7
+ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
8
+
9
+ from modules import devices, sd_samplers
10
+ from modules.generation_parameters_copypaste import create_override_settings_dict
11
+ from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
12
+ from modules.shared import opts, state
13
+ import modules.shared as shared
14
+ import modules.processing as processing
15
+ from modules.ui import plaintext_to_html
16
+ import modules.images as images
17
+ import modules.scripts
18
+
19
+
20
+ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
21
+ processing.fix_seed(p)
22
+
23
+ images = shared.listfiles(input_dir)
24
+
25
+ is_inpaint_batch = False
26
+ if inpaint_mask_dir:
27
+ inpaint_masks = shared.listfiles(inpaint_mask_dir)
28
+ is_inpaint_batch = len(inpaint_masks) > 0
29
+ if is_inpaint_batch:
30
+ print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
31
+
32
+ print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
33
+
34
+ save_normally = output_dir == ''
35
+
36
+ p.do_not_save_grid = True
37
+ p.do_not_save_samples = not save_normally
38
+
39
+ state.job_count = len(images) * p.n_iter
40
+
41
+ for i, image in enumerate(images):
42
+ state.job = f"{i+1} out of {len(images)}"
43
+ if state.skipped:
44
+ state.skipped = False
45
+
46
+ if state.interrupted:
47
+ break
48
+
49
+ img = Image.open(image)
50
+ # Use the EXIF orientation of photos taken by smartphones.
51
+ img = ImageOps.exif_transpose(img)
52
+ p.init_images = [img] * p.batch_size
53
+
54
+ if is_inpaint_batch:
55
+ # try to find corresponding mask for an image using simple filename matching
56
+ mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
57
+ # if not found use first one ("same mask for all images" use-case)
58
+ if not mask_image_path in inpaint_masks:
59
+ mask_image_path = inpaint_masks[0]
60
+ mask_image = Image.open(mask_image_path)
61
+ p.image_mask = mask_image
62
+
63
+ proc = modules.scripts.scripts_img2img.run(p, *args)
64
+ if proc is None:
65
+ proc = process_images(p)
66
+
67
+ for n, processed_image in enumerate(proc.images):
68
+ filename = os.path.basename(image)
69
+
70
+ if n > 0:
71
+ left, right = os.path.splitext(filename)
72
+ filename = f"{left}-{n}{right}"
73
+
74
+ if not save_normally:
75
+ os.makedirs(output_dir, exist_ok=True)
76
+ if processed_image.mode == 'RGBA':
77
+ processed_image = processed_image.convert("RGB")
78
+ processed_image.save(os.path.join(output_dir, filename))
79
+
80
+
81
+ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
82
+ override_settings = create_override_settings_dict(override_settings_texts)
83
+
84
+ is_batch = mode == 5
85
+
86
+ if mode == 0: # img2img
87
+ image = init_img.convert("RGB")
88
+ mask = None
89
+ elif mode == 1: # img2img sketch
90
+ image = sketch.convert("RGB")
91
+ mask = None
92
+ elif mode == 2: # inpaint
93
+ image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
94
+ alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
95
+ mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
96
+ image = image.convert("RGB")
97
+ elif mode == 3: # inpaint sketch
98
+ image = inpaint_color_sketch
99
+ orig = inpaint_color_sketch_orig or inpaint_color_sketch
100
+ pred = np.any(np.array(image) != np.array(orig), axis=-1)
101
+ mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
102
+ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
103
+ blur = ImageFilter.GaussianBlur(mask_blur)
104
+ image = Image.composite(image.filter(blur), orig, mask.filter(blur))
105
+ image = image.convert("RGB")
106
+ elif mode == 4: # inpaint upload mask
107
+ image = init_img_inpaint
108
+ mask = init_mask_inpaint
109
+ else:
110
+ image = None
111
+ mask = None
112
+
113
+ # Use the EXIF orientation of photos taken by smartphones.
114
+ if image is not None:
115
+ image = ImageOps.exif_transpose(image)
116
+
117
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
118
+
119
+ p = StableDiffusionProcessingImg2Img(
120
+ sd_model=shared.sd_model,
121
+ outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
122
+ outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
123
+ prompt=prompt,
124
+ negative_prompt=negative_prompt,
125
+ styles=prompt_styles,
126
+ seed=seed,
127
+ subseed=subseed,
128
+ subseed_strength=subseed_strength,
129
+ seed_resize_from_h=seed_resize_from_h,
130
+ seed_resize_from_w=seed_resize_from_w,
131
+ seed_enable_extras=seed_enable_extras,
132
+ sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
133
+ batch_size=batch_size,
134
+ n_iter=n_iter,
135
+ steps=steps,
136
+ cfg_scale=cfg_scale,
137
+ width=width,
138
+ height=height,
139
+ restore_faces=restore_faces,
140
+ tiling=tiling,
141
+ init_images=[image],
142
+ mask=mask,
143
+ mask_blur=mask_blur,
144
+ inpainting_fill=inpainting_fill,
145
+ resize_mode=resize_mode,
146
+ denoising_strength=denoising_strength,
147
+ image_cfg_scale=image_cfg_scale,
148
+ inpaint_full_res=inpaint_full_res,
149
+ inpaint_full_res_padding=inpaint_full_res_padding,
150
+ inpainting_mask_invert=inpainting_mask_invert,
151
+ override_settings=override_settings,
152
+ )
153
+
154
+ p.scripts = modules.scripts.scripts_img2img  # img2img should use the img2img script runner, not txt2img
155
+ p.script_args = args
156
+
157
+ if shared.cmd_opts.enable_console_prompts:
158
+ print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
159
+
160
+ p.extra_generation_params["Mask blur"] = mask_blur
161
+
162
+ if is_batch:
163
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
164
+
165
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
166
+
167
+ processed = Processed(p, [], p.seed, "")
168
+ else:
169
+ processed = modules.scripts.scripts_img2img.run(p, *args)
170
+ if processed is None:
171
+ processed = process_images(p)
172
+
173
+ p.close()
174
+
175
+ shared.total_tqdm.clear()
176
+
177
+ generation_info_js = processed.js()
178
+ if opts.samples_log_stdout:
179
+ print(generation_info_js)
180
+
181
+ if opts.do_not_show_images:
182
+ processed.images = []
183
+
184
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
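
As a standalone illustration of the mode 3 ("inpaint sketch") branch above: the mask is derived from the pixels that differ between the user's colored sketch and the original image, then blurred and used to composite the two. The images and the blur radius below are synthetic stand-ins; in the web UI they come from the gradio canvas and the mask_blur slider.

import numpy as np
from PIL import Image, ImageFilter

orig = Image.new("RGB", (64, 64), "white")
sketch = orig.copy()
sketch.paste(Image.new("RGB", (16, 16), "red"), (10, 10))   # pretend the user painted here

pred = np.any(np.array(sketch) != np.array(orig), axis=-1)   # True where any channel changed
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")     # binary L-mode mask

blur = ImageFilter.GaussianBlur(4)                           # 4 is an arbitrary stand-in for mask_blur
blended = Image.composite(sketch.filter(blur), orig, mask.filter(blur))
print(mask.getbbox())   # bounding box of the painted region, e.g. (10, 10, 26, 26)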
sd/stable-diffusion-webui/modules/interrogate.py CHANGED
@@ -1,227 +1,227 @@
1
- import os
2
- import sys
3
- import traceback
4
- from collections import namedtuple
5
- from pathlib import Path
6
- import re
7
-
8
- import torch
9
- import torch.hub
10
-
11
- from torchvision import transforms
12
- from torchvision.transforms.functional import InterpolationMode
13
-
14
- import modules.shared as shared
15
- from modules import devices, paths, shared, lowvram, modelloader, errors
16
-
17
- blip_image_eval_size = 384
18
- clip_model_name = 'ViT-L/14'
19
-
20
- Category = namedtuple("Category", ["name", "topn", "items"])
21
-
22
- re_topn = re.compile(r"\.top(\d+)\.")
23
-
24
- def category_types():
25
- return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
26
-
27
-
28
- def download_default_clip_interrogate_categories(content_dir):
29
- print("Downloading CLIP categories...")
30
-
31
- tmpdir = content_dir + "_tmp"
32
- category_types = ["artists", "flavors", "mediums", "movements"]
33
-
34
- try:
35
- os.makedirs(tmpdir)
36
- for category_type in category_types:
37
- torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
38
- os.rename(tmpdir, content_dir)
39
-
40
- except Exception as e:
41
- errors.display(e, "downloading default CLIP interrogate categories")
42
- finally:
43
- if os.path.exists(tmpdir):
44
- os.remove(tmpdir)
45
-
46
-
47
- class InterrogateModels:
48
- blip_model = None
49
- clip_model = None
50
- clip_preprocess = None
51
- dtype = None
52
- running_on_cpu = None
53
-
54
- def __init__(self, content_dir):
55
- self.loaded_categories = None
56
- self.skip_categories = []
57
- self.content_dir = content_dir
58
- self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
59
-
60
- def categories(self):
61
- if not os.path.exists(self.content_dir):
62
- download_default_clip_interrogate_categories(self.content_dir)
63
-
64
- if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
65
- return self.loaded_categories
66
-
67
- self.loaded_categories = []
68
-
69
- if os.path.exists(self.content_dir):
70
- self.skip_categories = shared.opts.interrogate_clip_skip_categories
71
- category_types = []
72
- for filename in Path(self.content_dir).glob('*.txt'):
73
- category_types.append(filename.stem)
74
- if filename.stem in self.skip_categories:
75
- continue
76
- m = re_topn.search(filename.stem)
77
- topn = 1 if m is None else int(m.group(1))
78
- with open(filename, "r", encoding="utf8") as file:
79
- lines = [x.strip() for x in file.readlines()]
80
-
81
- self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
82
-
83
- return self.loaded_categories
84
-
85
- def create_fake_fairscale(self):
86
- class FakeFairscale:
87
- def checkpoint_wrapper(self):
88
- pass
89
-
90
- sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
91
-
92
- def load_blip_model(self):
93
- self.create_fake_fairscale()
94
- import models.blip
95
-
96
- files = modelloader.load_models(
97
- model_path=os.path.join(paths.models_path, "BLIP"),
98
- model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
99
- ext_filter=[".pth"],
100
- download_name='model_base_caption_capfilt_large.pth',
101
- )
102
-
103
- blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
104
- blip_model.eval()
105
-
106
- return blip_model
107
-
108
- def load_clip_model(self):
109
- import clip
110
-
111
- if self.running_on_cpu:
112
- model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
113
- else:
114
- model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
115
-
116
- model.eval()
117
- model = model.to(devices.device_interrogate)
118
-
119
- return model, preprocess
120
-
121
- def load(self):
122
- if self.blip_model is None:
123
- self.blip_model = self.load_blip_model()
124
- if not shared.cmd_opts.no_half and not self.running_on_cpu:
125
- self.blip_model = self.blip_model.half()
126
-
127
- self.blip_model = self.blip_model.to(devices.device_interrogate)
128
-
129
- if self.clip_model is None:
130
- self.clip_model, self.clip_preprocess = self.load_clip_model()
131
- if not shared.cmd_opts.no_half and not self.running_on_cpu:
132
- self.clip_model = self.clip_model.half()
133
-
134
- self.clip_model = self.clip_model.to(devices.device_interrogate)
135
-
136
- self.dtype = next(self.clip_model.parameters()).dtype
137
-
138
- def send_clip_to_ram(self):
139
- if not shared.opts.interrogate_keep_models_in_memory:
140
- if self.clip_model is not None:
141
- self.clip_model = self.clip_model.to(devices.cpu)
142
-
143
- def send_blip_to_ram(self):
144
- if not shared.opts.interrogate_keep_models_in_memory:
145
- if self.blip_model is not None:
146
- self.blip_model = self.blip_model.to(devices.cpu)
147
-
148
- def unload(self):
149
- self.send_clip_to_ram()
150
- self.send_blip_to_ram()
151
-
152
- devices.torch_gc()
153
-
154
- def rank(self, image_features, text_array, top_count=1):
155
- import clip
156
-
157
- devices.torch_gc()
158
-
159
- if shared.opts.interrogate_clip_dict_limit != 0:
160
- text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
161
-
162
- top_count = min(top_count, len(text_array))
163
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
164
- text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
165
- text_features /= text_features.norm(dim=-1, keepdim=True)
166
-
167
- similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
168
- for i in range(image_features.shape[0]):
169
- similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
170
- similarity /= image_features.shape[0]
171
-
172
- top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
173
- return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
174
-
175
- def generate_caption(self, pil_image):
176
- gpu_image = transforms.Compose([
177
- transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
178
- transforms.ToTensor(),
179
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
180
- ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
181
-
182
- with torch.no_grad():
183
- caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
184
-
185
- return caption[0]
186
-
187
- def interrogate(self, pil_image):
188
- res = ""
189
- shared.state.begin()
190
- shared.state.job = 'interrogate'
191
- try:
192
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
193
- lowvram.send_everything_to_cpu()
194
- devices.torch_gc()
195
-
196
- self.load()
197
-
198
- caption = self.generate_caption(pil_image)
199
- self.send_blip_to_ram()
200
- devices.torch_gc()
201
-
202
- res = caption
203
-
204
- clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
205
-
206
- with torch.no_grad(), devices.autocast():
207
- image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
208
-
209
- image_features /= image_features.norm(dim=-1, keepdim=True)
210
-
211
- for name, topn, items in self.categories():
212
- matches = self.rank(image_features, items, top_count=topn)
213
- for match, score in matches:
214
- if shared.opts.interrogate_return_ranks:
215
- res += f", ({match}:{score/100:.3f})"
216
- else:
217
- res += ", " + match
218
-
219
- except Exception:
220
- print("Error interrogating", file=sys.stderr)
221
- print(traceback.format_exc(), file=sys.stderr)
222
- res += "<error>"
223
-
224
- self.unload()
225
- shared.state.end()
226
-
227
- return res
 
1
+ import os
2
+ import sys
3
+ import traceback
4
+ from collections import namedtuple
5
+ from pathlib import Path
6
+ import re
7
+
8
+ import torch
9
+ import torch.hub
10
+
11
+ from torchvision import transforms
12
+ from torchvision.transforms.functional import InterpolationMode
13
+
14
+ import modules.shared as shared
15
+ from modules import devices, paths, shared, lowvram, modelloader, errors
16
+
17
+ blip_image_eval_size = 384
18
+ clip_model_name = 'ViT-L/14'
19
+
20
+ Category = namedtuple("Category", ["name", "topn", "items"])
21
+
22
+ re_topn = re.compile(r"\.top(\d+)\.")
23
+
24
+ def category_types():
25
+ return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
26
+
27
+
28
+ def download_default_clip_interrogate_categories(content_dir):
29
+ print("Downloading CLIP categories...")
30
+
31
+ tmpdir = content_dir + "_tmp"
32
+ category_types = ["artists", "flavors", "mediums", "movements"]
33
+
34
+ try:
35
+ os.makedirs(tmpdir)
36
+ for category_type in category_types:
37
+ torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
38
+ os.rename(tmpdir, content_dir)
39
+
40
+ except Exception as e:
41
+ errors.display(e, "downloading default CLIP interrogate categories")
42
+ finally:
43
+ if os.path.exists(tmpdir):
44
+ os.remove(tmpdir)
45
+
46
+
47
+ class InterrogateModels:
48
+ blip_model = None
49
+ clip_model = None
50
+ clip_preprocess = None
51
+ dtype = None
52
+ running_on_cpu = None
53
+
54
+ def __init__(self, content_dir):
55
+ self.loaded_categories = None
56
+ self.skip_categories = []
57
+ self.content_dir = content_dir
58
+ self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
59
+
60
+ def categories(self):
61
+ if not os.path.exists(self.content_dir):
62
+ download_default_clip_interrogate_categories(self.content_dir)
63
+
64
+ if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
65
+ return self.loaded_categories
66
+
67
+ self.loaded_categories = []
68
+
69
+ if os.path.exists(self.content_dir):
70
+ self.skip_categories = shared.opts.interrogate_clip_skip_categories
71
+ category_types = []
72
+ for filename in Path(self.content_dir).glob('*.txt'):
73
+ category_types.append(filename.stem)
74
+ if filename.stem in self.skip_categories:
75
+ continue
76
+ m = re_topn.search(filename.stem)
77
+ topn = 1 if m is None else int(m.group(1))
78
+ with open(filename, "r", encoding="utf8") as file:
79
+ lines = [x.strip() for x in file.readlines()]
80
+
81
+ self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
82
+
83
+ return self.loaded_categories
84
+
85
+ def create_fake_fairscale(self):
86
+ class FakeFairscale:
87
+ def checkpoint_wrapper(self):
88
+ pass
89
+
90
+ sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
91
+
92
+ def load_blip_model(self):
93
+ self.create_fake_fairscale()
94
+ import models.blip
95
+
96
+ files = modelloader.load_models(
97
+ model_path=os.path.join(paths.models_path, "BLIP"),
98
+ model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
99
+ ext_filter=[".pth"],
100
+ download_name='model_base_caption_capfilt_large.pth',
101
+ )
102
+
103
+ blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
104
+ blip_model.eval()
105
+
106
+ return blip_model
107
+
108
+ def load_clip_model(self):
109
+ import clip
110
+
111
+ if self.running_on_cpu:
112
+ model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
113
+ else:
114
+ model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
115
+
116
+ model.eval()
117
+ model = model.to(devices.device_interrogate)
118
+
119
+ return model, preprocess
120
+
121
+ def load(self):
122
+ if self.blip_model is None:
123
+ self.blip_model = self.load_blip_model()
124
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
125
+ self.blip_model = self.blip_model.half()
126
+
127
+ self.blip_model = self.blip_model.to(devices.device_interrogate)
128
+
129
+ if self.clip_model is None:
130
+ self.clip_model, self.clip_preprocess = self.load_clip_model()
131
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
132
+ self.clip_model = self.clip_model.half()
133
+
134
+ self.clip_model = self.clip_model.to(devices.device_interrogate)
135
+
136
+ self.dtype = next(self.clip_model.parameters()).dtype
137
+
138
+ def send_clip_to_ram(self):
139
+ if not shared.opts.interrogate_keep_models_in_memory:
140
+ if self.clip_model is not None:
141
+ self.clip_model = self.clip_model.to(devices.cpu)
142
+
143
+ def send_blip_to_ram(self):
144
+ if not shared.opts.interrogate_keep_models_in_memory:
145
+ if self.blip_model is not None:
146
+ self.blip_model = self.blip_model.to(devices.cpu)
147
+
148
+ def unload(self):
149
+ self.send_clip_to_ram()
150
+ self.send_blip_to_ram()
151
+
152
+ devices.torch_gc()
153
+
154
+ def rank(self, image_features, text_array, top_count=1):
155
+ import clip
156
+
157
+ devices.torch_gc()
158
+
159
+ if shared.opts.interrogate_clip_dict_limit != 0:
160
+ text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
161
+
162
+ top_count = min(top_count, len(text_array))
163
+ text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
164
+ text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
165
+ text_features /= text_features.norm(dim=-1, keepdim=True)
166
+
167
+ similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
168
+ for i in range(image_features.shape[0]):
169
+ similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
170
+ similarity /= image_features.shape[0]
171
+
172
+ top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
173
+ return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
174
+
175
+ def generate_caption(self, pil_image):
176
+ gpu_image = transforms.Compose([
177
+ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
178
+ transforms.ToTensor(),
179
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
180
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
181
+
182
+ with torch.no_grad():
183
+ caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
184
+
185
+ return caption[0]
186
+
187
+ def interrogate(self, pil_image):
188
+ res = ""
189
+ shared.state.begin()
190
+ shared.state.job = 'interrogate'
191
+ try:
192
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
193
+ lowvram.send_everything_to_cpu()
194
+ devices.torch_gc()
195
+
196
+ self.load()
197
+
198
+ caption = self.generate_caption(pil_image)
199
+ self.send_blip_to_ram()
200
+ devices.torch_gc()
201
+
202
+ res = caption
203
+
204
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
205
+
206
+ with torch.no_grad(), devices.autocast():
207
+ image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
208
+
209
+ image_features /= image_features.norm(dim=-1, keepdim=True)
210
+
211
+ for name, topn, items in self.categories():
212
+ matches = self.rank(image_features, items, top_count=topn)
213
+ for match, score in matches:
214
+ if shared.opts.interrogate_return_ranks:
215
+ res += f", ({match}:{score/100:.3f})"
216
+ else:
217
+ res += ", " + match
218
+
219
+ except Exception:
220
+ print("Error interrogating", file=sys.stderr)
221
+ print(traceback.format_exc(), file=sys.stderr)
222
+ res += "<error>"
223
+
224
+ self.unload()
225
+ shared.state.end()
226
+
227
+ return res
sd/stable-diffusion-webui/modules/localization.py CHANGED
@@ -1,37 +1,37 @@
1
- import json
2
- import os
3
- import sys
4
- import traceback
5
-
6
-
7
- localizations = {}
8
-
9
-
10
- def list_localizations(dirname):
11
- localizations.clear()
12
-
13
- for file in os.listdir(dirname):
14
- fn, ext = os.path.splitext(file)
15
- if ext.lower() != ".json":
16
- continue
17
-
18
- localizations[fn] = os.path.join(dirname, file)
19
-
20
- from modules import scripts
21
- for file in scripts.list_scripts("localizations", ".json"):
22
- fn, ext = os.path.splitext(file.filename)
23
- localizations[fn] = file.path
24
-
25
-
26
- def localization_js(current_localization_name):
27
- fn = localizations.get(current_localization_name, None)
28
- data = {}
29
- if fn is not None:
30
- try:
31
- with open(fn, "r", encoding="utf8") as file:
32
- data = json.load(file)
33
- except Exception:
34
- print(f"Error loading localization from {fn}:", file=sys.stderr)
35
- print(traceback.format_exc(), file=sys.stderr)
36
-
37
- return f"var localization = {json.dumps(data)}\n"
 
1
+ import json
2
+ import os
3
+ import sys
4
+ import traceback
5
+
6
+
7
+ localizations = {}
8
+
9
+
10
+ def list_localizations(dirname):
11
+ localizations.clear()
12
+
13
+ for file in os.listdir(dirname):
14
+ fn, ext = os.path.splitext(file)
15
+ if ext.lower() != ".json":
16
+ continue
17
+
18
+ localizations[fn] = os.path.join(dirname, file)
19
+
20
+ from modules import scripts
21
+ for file in scripts.list_scripts("localizations", ".json"):
22
+ fn, ext = os.path.splitext(file.filename)
23
+ localizations[fn] = file.path
24
+
25
+
26
+ def localization_js(current_localization_name):
27
+ fn = localizations.get(current_localization_name, None)
28
+ data = {}
29
+ if fn is not None:
30
+ try:
31
+ with open(fn, "r", encoding="utf8") as file:
32
+ data = json.load(file)
33
+ except Exception:
34
+ print(f"Error loading localization from {fn}:", file=sys.stderr)
35
+ print(traceback.format_exc(), file=sys.stderr)
36
+
37
+ return f"var localization = {json.dumps(data)}\n"
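
Finally, a small sketch of what localization_js emits for a hypothetical localization file: a single JavaScript statement defining localization that the front-end scripts read. The translations below are made up for illustration.

import json

data = {"Generate": "Générer", "Seed": "Graine"}   # hypothetical translations
print(f"var localization = {json.dumps(data)}\n")
# var localization = {"Generate": "G\u00e9n\u00e9rer", "Seed": "Graine"}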