Spaces:
Running
on
Zero
Running
on
Zero
artificialguybr
commited on
Commit
•
eadd7b4
1
Parent(s):
f8cfb21
Hi
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +29 -0
- LICENSE +661 -0
- app/app_pixart_dmd.py +343 -0
- app/app_pixart_sigma.py +420 -0
- asset/PixArt.svg +96 -0
- asset/docs/pixart.md +112 -0
- asset/examples.py +36 -0
- asset/logo-sigma.png +0 -0
- asset/logo.png +0 -0
- asset/samples.txt +120 -0
- configs/PixArt_xl2_internal.py +79 -0
- configs/pixart_alpha_config/PixArt_xl2_img1024_dreambooth.py +30 -0
- configs/pixart_alpha_config/PixArt_xl2_img1024_internal.py +29 -0
- configs/pixart_alpha_config/PixArt_xl2_img1024_internalms.py +32 -0
- configs/pixart_alpha_config/PixArt_xl2_img256_internal.py +27 -0
- configs/pixart_alpha_config/PixArt_xl2_img512_internal.py +29 -0
- configs/pixart_alpha_config/PixArt_xl2_img512_internalms.py +31 -0
- configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms.py +46 -0
- configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms_kvcompress.py +51 -0
- configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_lcm.py +52 -0
- configs/pixart_sigma_config/PixArt_sigma_xl2_img256_internal.py +41 -0
- configs/pixart_sigma_config/PixArt_sigma_xl2_img2K_internalms_kvcompress.py +49 -0
- configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py +43 -0
- diffusion/__init__.py +8 -0
- diffusion/data/__init__.py +2 -0
- diffusion/data/builder.py +50 -0
- diffusion/data/datasets/InternalData.py +312 -0
- diffusion/data/datasets/InternalData_ms.py +336 -0
- diffusion/data/datasets/__init__.py +3 -0
- diffusion/data/datasets/utils.py +134 -0
- diffusion/data/transforms.py +30 -0
- diffusion/dpm_solver.py +36 -0
- diffusion/iddpm.py +53 -0
- diffusion/lcm_scheduler.py +459 -0
- diffusion/model/__init__.py +1 -0
- diffusion/model/builder.py +14 -0
- diffusion/model/diffusion_utils.py +88 -0
- diffusion/model/dpm_solver.py +1337 -0
- diffusion/model/edm_sample.py +171 -0
- diffusion/model/gaussian_diffusion.py +1041 -0
- diffusion/model/llava/__init__.py +1 -0
- diffusion/model/llava/llava_mpt.py +280 -0
- diffusion/model/llava/mpt/attention.py +276 -0
- diffusion/model/llava/mpt/blocks.py +41 -0
- diffusion/model/llava/mpt/configuration_mpt.py +118 -0
- diffusion/model/llava/mpt/modeling_mpt.py +308 -0
- diffusion/model/llava/mpt/norm.py +56 -0
- diffusion/model/llava/mpt/param_init_fns.py +181 -0
- diffusion/model/nets/PixArt.py +315 -0
- diffusion/model/nets/PixArtMS.py +293 -0
Dockerfile
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is a sample Dockefile that builds a runtime container and runs the sample Gradio app.
|
2 |
+
# Note, you must pass in the pretrained models when you run the container.
|
3 |
+
|
4 |
+
FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04
|
5 |
+
|
6 |
+
WORKDIR /workspace
|
7 |
+
|
8 |
+
RUN apt-get update && \
|
9 |
+
apt-get install -y \
|
10 |
+
git \
|
11 |
+
python3 \
|
12 |
+
python-is-python3 \
|
13 |
+
python3-pip \
|
14 |
+
python3.10-venv \
|
15 |
+
libgl1 \
|
16 |
+
libgl1-mesa-glx \
|
17 |
+
libglib2.0-0 \
|
18 |
+
&& rm -rf /var/lib/apt/lists/*
|
19 |
+
|
20 |
+
ADD requirements.txt .
|
21 |
+
|
22 |
+
RUN pip install -r requirements.txt
|
23 |
+
|
24 |
+
ADD . .
|
25 |
+
|
26 |
+
RUN chmod a+x docker-entrypoint.sh
|
27 |
+
|
28 |
+
ENV DEMO_PORT=12345
|
29 |
+
ENTRYPOINT [ "/workspace/docker-entrypoint.sh" ]
|
LICENSE
ADDED
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GNU AFFERO GENERAL PUBLIC LICENSE
|
2 |
+
Version 3, 19 November 2007
|
3 |
+
|
4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
6 |
+
of this license document, but changing it is not allowed.
|
7 |
+
|
8 |
+
Preamble
|
9 |
+
|
10 |
+
The GNU Affero General Public License is a free, copyleft license for
|
11 |
+
software and other kinds of works, specifically designed to ensure
|
12 |
+
cooperation with the community in the case of network server software.
|
13 |
+
|
14 |
+
The licenses for most software and other practical works are designed
|
15 |
+
to take away your freedom to share and change the works. By contrast,
|
16 |
+
our General Public Licenses are intended to guarantee your freedom to
|
17 |
+
share and change all versions of a program--to make sure it remains free
|
18 |
+
software for all its users.
|
19 |
+
|
20 |
+
When we speak of free software, we are referring to freedom, not
|
21 |
+
price. Our General Public Licenses are designed to make sure that you
|
22 |
+
have the freedom to distribute copies of free software (and charge for
|
23 |
+
them if you wish), that you receive source code or can get it if you
|
24 |
+
want it, that you can change the software or use pieces of it in new
|
25 |
+
free programs, and that you know you can do these things.
|
26 |
+
|
27 |
+
Developers that use our General Public Licenses protect your rights
|
28 |
+
with two steps: (1) assert copyright on the software, and (2) offer
|
29 |
+
you this License which gives you legal permission to copy, distribute
|
30 |
+
and/or modify the software.
|
31 |
+
|
32 |
+
A secondary benefit of defending all users' freedom is that
|
33 |
+
improvements made in alternate versions of the program, if they
|
34 |
+
receive widespread use, become available for other developers to
|
35 |
+
incorporate. Many developers of free software are heartened and
|
36 |
+
encouraged by the resulting cooperation. However, in the case of
|
37 |
+
software used on network servers, this result may fail to come about.
|
38 |
+
The GNU General Public License permits making a modified version and
|
39 |
+
letting the public access it on a server without ever releasing its
|
40 |
+
source code to the public.
|
41 |
+
|
42 |
+
The GNU Affero General Public License is designed specifically to
|
43 |
+
ensure that, in such cases, the modified source code becomes available
|
44 |
+
to the community. It requires the operator of a network server to
|
45 |
+
provide the source code of the modified version running there to the
|
46 |
+
users of that server. Therefore, public use of a modified version, on
|
47 |
+
a publicly accessible server, gives the public access to the source
|
48 |
+
code of the modified version.
|
49 |
+
|
50 |
+
An older license, called the Affero General Public License and
|
51 |
+
published by Affero, was designed to accomplish similar goals. This is
|
52 |
+
a different license, not a version of the Affero GPL, but Affero has
|
53 |
+
released a new version of the Affero GPL which permits relicensing under
|
54 |
+
this license.
|
55 |
+
|
56 |
+
The precise terms and conditions for copying, distribution and
|
57 |
+
modification follow.
|
58 |
+
|
59 |
+
TERMS AND CONDITIONS
|
60 |
+
|
61 |
+
0. Definitions.
|
62 |
+
|
63 |
+
"This License" refers to version 3 of the GNU Affero General Public License.
|
64 |
+
|
65 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
66 |
+
works, such as semiconductor masks.
|
67 |
+
|
68 |
+
"The Program" refers to any copyrightable work licensed under this
|
69 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
70 |
+
"recipients" may be individuals or organizations.
|
71 |
+
|
72 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
73 |
+
in a fashion requiring copyright permission, other than the making of an
|
74 |
+
exact copy. The resulting work is called a "modified version" of the
|
75 |
+
earlier work or a work "based on" the earlier work.
|
76 |
+
|
77 |
+
A "covered work" means either the unmodified Program or a work based
|
78 |
+
on the Program.
|
79 |
+
|
80 |
+
To "propagate" a work means to do anything with it that, without
|
81 |
+
permission, would make you directly or secondarily liable for
|
82 |
+
infringement under applicable copyright law, except executing it on a
|
83 |
+
computer or modifying a private copy. Propagation includes copying,
|
84 |
+
distribution (with or without modification), making available to the
|
85 |
+
public, and in some countries other activities as well.
|
86 |
+
|
87 |
+
To "convey" a work means any kind of propagation that enables other
|
88 |
+
parties to make or receive copies. Mere interaction with a user through
|
89 |
+
a computer network, with no transfer of a copy, is not conveying.
|
90 |
+
|
91 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
92 |
+
to the extent that it includes a convenient and prominently visible
|
93 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
94 |
+
tells the user that there is no warranty for the work (except to the
|
95 |
+
extent that warranties are provided), that licensees may convey the
|
96 |
+
work under this License, and how to view a copy of this License. If
|
97 |
+
the interface presents a list of user commands or options, such as a
|
98 |
+
menu, a prominent item in the list meets this criterion.
|
99 |
+
|
100 |
+
1. Source Code.
|
101 |
+
|
102 |
+
The "source code" for a work means the preferred form of the work
|
103 |
+
for making modifications to it. "Object code" means any non-source
|
104 |
+
form of a work.
|
105 |
+
|
106 |
+
A "Standard Interface" means an interface that either is an official
|
107 |
+
standard defined by a recognized standards body, or, in the case of
|
108 |
+
interfaces specified for a particular programming language, one that
|
109 |
+
is widely used among developers working in that language.
|
110 |
+
|
111 |
+
The "System Libraries" of an executable work include anything, other
|
112 |
+
than the work as a whole, that (a) is included in the normal form of
|
113 |
+
packaging a Major Component, but which is not part of that Major
|
114 |
+
Component, and (b) serves only to enable use of the work with that
|
115 |
+
Major Component, or to implement a Standard Interface for which an
|
116 |
+
implementation is available to the public in source code form. A
|
117 |
+
"Major Component", in this context, means a major essential component
|
118 |
+
(kernel, window system, and so on) of the specific operating system
|
119 |
+
(if any) on which the executable work runs, or a compiler used to
|
120 |
+
produce the work, or an object code interpreter used to run it.
|
121 |
+
|
122 |
+
The "Corresponding Source" for a work in object code form means all
|
123 |
+
the source code needed to generate, install, and (for an executable
|
124 |
+
work) run the object code and to modify the work, including scripts to
|
125 |
+
control those activities. However, it does not include the work's
|
126 |
+
System Libraries, or general-purpose tools or generally available free
|
127 |
+
programs which are used unmodified in performing those activities but
|
128 |
+
which are not part of the work. For example, Corresponding Source
|
129 |
+
includes interface definition files associated with source files for
|
130 |
+
the work, and the source code for shared libraries and dynamically
|
131 |
+
linked subprograms that the work is specifically designed to require,
|
132 |
+
such as by intimate data communication or control flow between those
|
133 |
+
subprograms and other parts of the work.
|
134 |
+
|
135 |
+
The Corresponding Source need not include anything that users
|
136 |
+
can regenerate automatically from other parts of the Corresponding
|
137 |
+
Source.
|
138 |
+
|
139 |
+
The Corresponding Source for a work in source code form is that
|
140 |
+
same work.
|
141 |
+
|
142 |
+
2. Basic Permissions.
|
143 |
+
|
144 |
+
All rights granted under this License are granted for the term of
|
145 |
+
copyright on the Program, and are irrevocable provided the stated
|
146 |
+
conditions are met. This License explicitly affirms your unlimited
|
147 |
+
permission to run the unmodified Program. The output from running a
|
148 |
+
covered work is covered by this License only if the output, given its
|
149 |
+
content, constitutes a covered work. This License acknowledges your
|
150 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
151 |
+
|
152 |
+
You may make, run and propagate covered works that you do not
|
153 |
+
convey, without conditions so long as your license otherwise remains
|
154 |
+
in force. You may convey covered works to others for the sole purpose
|
155 |
+
of having them make modifications exclusively for you, or provide you
|
156 |
+
with facilities for running those works, provided that you comply with
|
157 |
+
the terms of this License in conveying all material for which you do
|
158 |
+
not control copyright. Those thus making or running the covered works
|
159 |
+
for you must do so exclusively on your behalf, under your direction
|
160 |
+
and control, on terms that prohibit them from making any copies of
|
161 |
+
your copyrighted material outside their relationship with you.
|
162 |
+
|
163 |
+
Conveying under any other circumstances is permitted solely under
|
164 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
165 |
+
makes it unnecessary.
|
166 |
+
|
167 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
168 |
+
|
169 |
+
No covered work shall be deemed part of an effective technological
|
170 |
+
measure under any applicable law fulfilling obligations under article
|
171 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
172 |
+
similar laws prohibiting or restricting circumvention of such
|
173 |
+
measures.
|
174 |
+
|
175 |
+
When you convey a covered work, you waive any legal power to forbid
|
176 |
+
circumvention of technological measures to the extent such circumvention
|
177 |
+
is effected by exercising rights under this License with respect to
|
178 |
+
the covered work, and you disclaim any intention to limit operation or
|
179 |
+
modification of the work as a means of enforcing, against the work's
|
180 |
+
users, your or third parties' legal rights to forbid circumvention of
|
181 |
+
technological measures.
|
182 |
+
|
183 |
+
4. Conveying Verbatim Copies.
|
184 |
+
|
185 |
+
You may convey verbatim copies of the Program's source code as you
|
186 |
+
receive it, in any medium, provided that you conspicuously and
|
187 |
+
appropriately publish on each copy an appropriate copyright notice;
|
188 |
+
keep intact all notices stating that this License and any
|
189 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
190 |
+
keep intact all notices of the absence of any warranty; and give all
|
191 |
+
recipients a copy of this License along with the Program.
|
192 |
+
|
193 |
+
You may charge any price or no price for each copy that you convey,
|
194 |
+
and you may offer support or warranty protection for a fee.
|
195 |
+
|
196 |
+
5. Conveying Modified Source Versions.
|
197 |
+
|
198 |
+
You may convey a work based on the Program, or the modifications to
|
199 |
+
produce it from the Program, in the form of source code under the
|
200 |
+
terms of section 4, provided that you also meet all of these conditions:
|
201 |
+
|
202 |
+
a) The work must carry prominent notices stating that you modified
|
203 |
+
it, and giving a relevant date.
|
204 |
+
|
205 |
+
b) The work must carry prominent notices stating that it is
|
206 |
+
released under this License and any conditions added under section
|
207 |
+
7. This requirement modifies the requirement in section 4 to
|
208 |
+
"keep intact all notices".
|
209 |
+
|
210 |
+
c) You must license the entire work, as a whole, under this
|
211 |
+
License to anyone who comes into possession of a copy. This
|
212 |
+
License will therefore apply, along with any applicable section 7
|
213 |
+
additional terms, to the whole of the work, and all its parts,
|
214 |
+
regardless of how they are packaged. This License gives no
|
215 |
+
permission to license the work in any other way, but it does not
|
216 |
+
invalidate such permission if you have separately received it.
|
217 |
+
|
218 |
+
d) If the work has interactive user interfaces, each must display
|
219 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
220 |
+
interfaces that do not display Appropriate Legal Notices, your
|
221 |
+
work need not make them do so.
|
222 |
+
|
223 |
+
A compilation of a covered work with other separate and independent
|
224 |
+
works, which are not by their nature extensions of the covered work,
|
225 |
+
and which are not combined with it such as to form a larger program,
|
226 |
+
in or on a volume of a storage or distribution medium, is called an
|
227 |
+
"aggregate" if the compilation and its resulting copyright are not
|
228 |
+
used to limit the access or legal rights of the compilation's users
|
229 |
+
beyond what the individual works permit. Inclusion of a covered work
|
230 |
+
in an aggregate does not cause this License to apply to the other
|
231 |
+
parts of the aggregate.
|
232 |
+
|
233 |
+
6. Conveying Non-Source Forms.
|
234 |
+
|
235 |
+
You may convey a covered work in object code form under the terms
|
236 |
+
of sections 4 and 5, provided that you also convey the
|
237 |
+
machine-readable Corresponding Source under the terms of this License,
|
238 |
+
in one of these ways:
|
239 |
+
|
240 |
+
a) Convey the object code in, or embodied in, a physical product
|
241 |
+
(including a physical distribution medium), accompanied by the
|
242 |
+
Corresponding Source fixed on a durable physical medium
|
243 |
+
customarily used for software interchange.
|
244 |
+
|
245 |
+
b) Convey the object code in, or embodied in, a physical product
|
246 |
+
(including a physical distribution medium), accompanied by a
|
247 |
+
written offer, valid for at least three years and valid for as
|
248 |
+
long as you offer spare parts or customer support for that product
|
249 |
+
model, to give anyone who possesses the object code either (1) a
|
250 |
+
copy of the Corresponding Source for all the software in the
|
251 |
+
product that is covered by this License, on a durable physical
|
252 |
+
medium customarily used for software interchange, for a price no
|
253 |
+
more than your reasonable cost of physically performing this
|
254 |
+
conveying of source, or (2) access to copy the
|
255 |
+
Corresponding Source from a network server at no charge.
|
256 |
+
|
257 |
+
c) Convey individual copies of the object code with a copy of the
|
258 |
+
written offer to provide the Corresponding Source. This
|
259 |
+
alternative is allowed only occasionally and noncommercially, and
|
260 |
+
only if you received the object code with such an offer, in accord
|
261 |
+
with subsection 6b.
|
262 |
+
|
263 |
+
d) Convey the object code by offering access from a designated
|
264 |
+
place (gratis or for a charge), and offer equivalent access to the
|
265 |
+
Corresponding Source in the same way through the same place at no
|
266 |
+
further charge. You need not require recipients to copy the
|
267 |
+
Corresponding Source along with the object code. If the place to
|
268 |
+
copy the object code is a network server, the Corresponding Source
|
269 |
+
may be on a different server (operated by you or a third party)
|
270 |
+
that supports equivalent copying facilities, provided you maintain
|
271 |
+
clear directions next to the object code saying where to find the
|
272 |
+
Corresponding Source. Regardless of what server hosts the
|
273 |
+
Corresponding Source, you remain obligated to ensure that it is
|
274 |
+
available for as long as needed to satisfy these requirements.
|
275 |
+
|
276 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
277 |
+
you inform other peers where the object code and Corresponding
|
278 |
+
Source of the work are being offered to the general public at no
|
279 |
+
charge under subsection 6d.
|
280 |
+
|
281 |
+
A separable portion of the object code, whose source code is excluded
|
282 |
+
from the Corresponding Source as a System Library, need not be
|
283 |
+
included in conveying the object code work.
|
284 |
+
|
285 |
+
A "User Product" is either (1) a "consumer product", which means any
|
286 |
+
tangible personal property which is normally used for personal, family,
|
287 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
288 |
+
into a dwelling. In determining whether a product is a consumer product,
|
289 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
290 |
+
product received by a particular user, "normally used" refers to a
|
291 |
+
typical or common use of that class of product, regardless of the status
|
292 |
+
of the particular user or of the way in which the particular user
|
293 |
+
actually uses, or expects or is expected to use, the product. A product
|
294 |
+
is a consumer product regardless of whether the product has substantial
|
295 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
296 |
+
the only significant mode of use of the product.
|
297 |
+
|
298 |
+
"Installation Information" for a User Product means any methods,
|
299 |
+
procedures, authorization keys, or other information required to install
|
300 |
+
and execute modified versions of a covered work in that User Product from
|
301 |
+
a modified version of its Corresponding Source. The information must
|
302 |
+
suffice to ensure that the continued functioning of the modified object
|
303 |
+
code is in no case prevented or interfered with solely because
|
304 |
+
modification has been made.
|
305 |
+
|
306 |
+
If you convey an object code work under this section in, or with, or
|
307 |
+
specifically for use in, a User Product, and the conveying occurs as
|
308 |
+
part of a transaction in which the right of possession and use of the
|
309 |
+
User Product is transferred to the recipient in perpetuity or for a
|
310 |
+
fixed term (regardless of how the transaction is characterized), the
|
311 |
+
Corresponding Source conveyed under this section must be accompanied
|
312 |
+
by the Installation Information. But this requirement does not apply
|
313 |
+
if neither you nor any third party retains the ability to install
|
314 |
+
modified object code on the User Product (for example, the work has
|
315 |
+
been installed in ROM).
|
316 |
+
|
317 |
+
The requirement to provide Installation Information does not include a
|
318 |
+
requirement to continue to provide support service, warranty, or updates
|
319 |
+
for a work that has been modified or installed by the recipient, or for
|
320 |
+
the User Product in which it has been modified or installed. Access to a
|
321 |
+
network may be denied when the modification itself materially and
|
322 |
+
adversely affects the operation of the network or violates the rules and
|
323 |
+
protocols for communication across the network.
|
324 |
+
|
325 |
+
Corresponding Source conveyed, and Installation Information provided,
|
326 |
+
in accord with this section must be in a format that is publicly
|
327 |
+
documented (and with an implementation available to the public in
|
328 |
+
source code form), and must require no special password or key for
|
329 |
+
unpacking, reading or copying.
|
330 |
+
|
331 |
+
7. Additional Terms.
|
332 |
+
|
333 |
+
"Additional permissions" are terms that supplement the terms of this
|
334 |
+
License by making exceptions from one or more of its conditions.
|
335 |
+
Additional permissions that are applicable to the entire Program shall
|
336 |
+
be treated as though they were included in this License, to the extent
|
337 |
+
that they are valid under applicable law. If additional permissions
|
338 |
+
apply only to part of the Program, that part may be used separately
|
339 |
+
under those permissions, but the entire Program remains governed by
|
340 |
+
this License without regard to the additional permissions.
|
341 |
+
|
342 |
+
When you convey a copy of a covered work, you may at your option
|
343 |
+
remove any additional permissions from that copy, or from any part of
|
344 |
+
it. (Additional permissions may be written to require their own
|
345 |
+
removal in certain cases when you modify the work.) You may place
|
346 |
+
additional permissions on material, added by you to a covered work,
|
347 |
+
for which you have or can give appropriate copyright permission.
|
348 |
+
|
349 |
+
Notwithstanding any other provision of this License, for material you
|
350 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
351 |
+
that material) supplement the terms of this License with terms:
|
352 |
+
|
353 |
+
a) Disclaiming warranty or limiting liability differently from the
|
354 |
+
terms of sections 15 and 16 of this License; or
|
355 |
+
|
356 |
+
b) Requiring preservation of specified reasonable legal notices or
|
357 |
+
author attributions in that material or in the Appropriate Legal
|
358 |
+
Notices displayed by works containing it; or
|
359 |
+
|
360 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
361 |
+
requiring that modified versions of such material be marked in
|
362 |
+
reasonable ways as different from the original version; or
|
363 |
+
|
364 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
365 |
+
authors of the material; or
|
366 |
+
|
367 |
+
e) Declining to grant rights under trademark law for use of some
|
368 |
+
trade names, trademarks, or service marks; or
|
369 |
+
|
370 |
+
f) Requiring indemnification of licensors and authors of that
|
371 |
+
material by anyone who conveys the material (or modified versions of
|
372 |
+
it) with contractual assumptions of liability to the recipient, for
|
373 |
+
any liability that these contractual assumptions directly impose on
|
374 |
+
those licensors and authors.
|
375 |
+
|
376 |
+
All other non-permissive additional terms are considered "further
|
377 |
+
restrictions" within the meaning of section 10. If the Program as you
|
378 |
+
received it, or any part of it, contains a notice stating that it is
|
379 |
+
governed by this License along with a term that is a further
|
380 |
+
restriction, you may remove that term. If a license document contains
|
381 |
+
a further restriction but permits relicensing or conveying under this
|
382 |
+
License, you may add to a covered work material governed by the terms
|
383 |
+
of that license document, provided that the further restriction does
|
384 |
+
not survive such relicensing or conveying.
|
385 |
+
|
386 |
+
If you add terms to a covered work in accord with this section, you
|
387 |
+
must place, in the relevant source files, a statement of the
|
388 |
+
additional terms that apply to those files, or a notice indicating
|
389 |
+
where to find the applicable terms.
|
390 |
+
|
391 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
392 |
+
form of a separately written license, or stated as exceptions;
|
393 |
+
the above requirements apply either way.
|
394 |
+
|
395 |
+
8. Termination.
|
396 |
+
|
397 |
+
You may not propagate or modify a covered work except as expressly
|
398 |
+
provided under this License. Any attempt otherwise to propagate or
|
399 |
+
modify it is void, and will automatically terminate your rights under
|
400 |
+
this License (including any patent licenses granted under the third
|
401 |
+
paragraph of section 11).
|
402 |
+
|
403 |
+
However, if you cease all violation of this License, then your
|
404 |
+
license from a particular copyright holder is reinstated (a)
|
405 |
+
provisionally, unless and until the copyright holder explicitly and
|
406 |
+
finally terminates your license, and (b) permanently, if the copyright
|
407 |
+
holder fails to notify you of the violation by some reasonable means
|
408 |
+
prior to 60 days after the cessation.
|
409 |
+
|
410 |
+
Moreover, your license from a particular copyright holder is
|
411 |
+
reinstated permanently if the copyright holder notifies you of the
|
412 |
+
violation by some reasonable means, this is the first time you have
|
413 |
+
received notice of violation of this License (for any work) from that
|
414 |
+
copyright holder, and you cure the violation prior to 30 days after
|
415 |
+
your receipt of the notice.
|
416 |
+
|
417 |
+
Termination of your rights under this section does not terminate the
|
418 |
+
licenses of parties who have received copies or rights from you under
|
419 |
+
this License. If your rights have been terminated and not permanently
|
420 |
+
reinstated, you do not qualify to receive new licenses for the same
|
421 |
+
material under section 10.
|
422 |
+
|
423 |
+
9. Acceptance Not Required for Having Copies.
|
424 |
+
|
425 |
+
You are not required to accept this License in order to receive or
|
426 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
427 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
428 |
+
to receive a copy likewise does not require acceptance. However,
|
429 |
+
nothing other than this License grants you permission to propagate or
|
430 |
+
modify any covered work. These actions infringe copyright if you do
|
431 |
+
not accept this License. Therefore, by modifying or propagating a
|
432 |
+
covered work, you indicate your acceptance of this License to do so.
|
433 |
+
|
434 |
+
10. Automatic Licensing of Downstream Recipients.
|
435 |
+
|
436 |
+
Each time you convey a covered work, the recipient automatically
|
437 |
+
receives a license from the original licensors, to run, modify and
|
438 |
+
propagate that work, subject to this License. You are not responsible
|
439 |
+
for enforcing compliance by third parties with this License.
|
440 |
+
|
441 |
+
An "entity transaction" is a transaction transferring control of an
|
442 |
+
organization, or substantially all assets of one, or subdividing an
|
443 |
+
organization, or merging organizations. If propagation of a covered
|
444 |
+
work results from an entity transaction, each party to that
|
445 |
+
transaction who receives a copy of the work also receives whatever
|
446 |
+
licenses to the work the party's predecessor in interest had or could
|
447 |
+
give under the previous paragraph, plus a right to possession of the
|
448 |
+
Corresponding Source of the work from the predecessor in interest, if
|
449 |
+
the predecessor has it or can get it with reasonable efforts.
|
450 |
+
|
451 |
+
You may not impose any further restrictions on the exercise of the
|
452 |
+
rights granted or affirmed under this License. For example, you may
|
453 |
+
not impose a license fee, royalty, or other charge for exercise of
|
454 |
+
rights granted under this License, and you may not initiate litigation
|
455 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
456 |
+
any patent claim is infringed by making, using, selling, offering for
|
457 |
+
sale, or importing the Program or any portion of it.
|
458 |
+
|
459 |
+
11. Patents.
|
460 |
+
|
461 |
+
A "contributor" is a copyright holder who authorizes use under this
|
462 |
+
License of the Program or a work on which the Program is based. The
|
463 |
+
work thus licensed is called the contributor's "contributor version".
|
464 |
+
|
465 |
+
A contributor's "essential patent claims" are all patent claims
|
466 |
+
owned or controlled by the contributor, whether already acquired or
|
467 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
468 |
+
by this License, of making, using, or selling its contributor version,
|
469 |
+
but do not include claims that would be infringed only as a
|
470 |
+
consequence of further modification of the contributor version. For
|
471 |
+
purposes of this definition, "control" includes the right to grant
|
472 |
+
patent sublicenses in a manner consistent with the requirements of
|
473 |
+
this License.
|
474 |
+
|
475 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
476 |
+
patent license under the contributor's essential patent claims, to
|
477 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
478 |
+
propagate the contents of its contributor version.
|
479 |
+
|
480 |
+
In the following three paragraphs, a "patent license" is any express
|
481 |
+
agreement or commitment, however denominated, not to enforce a patent
|
482 |
+
(such as an express permission to practice a patent or covenant not to
|
483 |
+
sue for patent infringement). To "grant" such a patent license to a
|
484 |
+
party means to make such an agreement or commitment not to enforce a
|
485 |
+
patent against the party.
|
486 |
+
|
487 |
+
If you convey a covered work, knowingly relying on a patent license,
|
488 |
+
and the Corresponding Source of the work is not available for anyone
|
489 |
+
to copy, free of charge and under the terms of this License, through a
|
490 |
+
publicly available network server or other readily accessible means,
|
491 |
+
then you must either (1) cause the Corresponding Source to be so
|
492 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
493 |
+
patent license for this particular work, or (3) arrange, in a manner
|
494 |
+
consistent with the requirements of this License, to extend the patent
|
495 |
+
license to downstream recipients. "Knowingly relying" means you have
|
496 |
+
actual knowledge that, but for the patent license, your conveying the
|
497 |
+
covered work in a country, or your recipient's use of the covered work
|
498 |
+
in a country, would infringe one or more identifiable patents in that
|
499 |
+
country that you have reason to believe are valid.
|
500 |
+
|
501 |
+
If, pursuant to or in connection with a single transaction or
|
502 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
503 |
+
covered work, and grant a patent license to some of the parties
|
504 |
+
receiving the covered work authorizing them to use, propagate, modify
|
505 |
+
or convey a specific copy of the covered work, then the patent license
|
506 |
+
you grant is automatically extended to all recipients of the covered
|
507 |
+
work and works based on it.
|
508 |
+
|
509 |
+
A patent license is "discriminatory" if it does not include within
|
510 |
+
the scope of its coverage, prohibits the exercise of, or is
|
511 |
+
conditioned on the non-exercise of one or more of the rights that are
|
512 |
+
specifically granted under this License. You may not convey a covered
|
513 |
+
work if you are a party to an arrangement with a third party that is
|
514 |
+
in the business of distributing software, under which you make payment
|
515 |
+
to the third party based on the extent of your activity of conveying
|
516 |
+
the work, and under which the third party grants, to any of the
|
517 |
+
parties who would receive the covered work from you, a discriminatory
|
518 |
+
patent license (a) in connection with copies of the covered work
|
519 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
520 |
+
for and in connection with specific products or compilations that
|
521 |
+
contain the covered work, unless you entered into that arrangement,
|
522 |
+
or that patent license was granted, prior to 28 March 2007.
|
523 |
+
|
524 |
+
Nothing in this License shall be construed as excluding or limiting
|
525 |
+
any implied license or other defenses to infringement that may
|
526 |
+
otherwise be available to you under applicable patent law.
|
527 |
+
|
528 |
+
12. No Surrender of Others' Freedom.
|
529 |
+
|
530 |
+
If conditions are imposed on you (whether by court order, agreement or
|
531 |
+
otherwise) that contradict the conditions of this License, they do not
|
532 |
+
excuse you from the conditions of this License. If you cannot convey a
|
533 |
+
covered work so as to satisfy simultaneously your obligations under this
|
534 |
+
License and any other pertinent obligations, then as a consequence you may
|
535 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
536 |
+
to collect a royalty for further conveying from those to whom you convey
|
537 |
+
the Program, the only way you could satisfy both those terms and this
|
538 |
+
License would be to refrain entirely from conveying the Program.
|
539 |
+
|
540 |
+
13. Remote Network Interaction; Use with the GNU General Public License.
|
541 |
+
|
542 |
+
Notwithstanding any other provision of this License, if you modify the
|
543 |
+
Program, your modified version must prominently offer all users
|
544 |
+
interacting with it remotely through a computer network (if your version
|
545 |
+
supports such interaction) an opportunity to receive the Corresponding
|
546 |
+
Source of your version by providing access to the Corresponding Source
|
547 |
+
from a network server at no charge, through some standard or customary
|
548 |
+
means of facilitating copying of software. This Corresponding Source
|
549 |
+
shall include the Corresponding Source for any work covered by version 3
|
550 |
+
of the GNU General Public License that is incorporated pursuant to the
|
551 |
+
following paragraph.
|
552 |
+
|
553 |
+
Notwithstanding any other provision of this License, you have
|
554 |
+
permission to link or combine any covered work with a work licensed
|
555 |
+
under version 3 of the GNU General Public License into a single
|
556 |
+
combined work, and to convey the resulting work. The terms of this
|
557 |
+
License will continue to apply to the part which is the covered work,
|
558 |
+
but the work with which it is combined will remain governed by version
|
559 |
+
3 of the GNU General Public License.
|
560 |
+
|
561 |
+
14. Revised Versions of this License.
|
562 |
+
|
563 |
+
The Free Software Foundation may publish revised and/or new versions of
|
564 |
+
the GNU Affero General Public License from time to time. Such new versions
|
565 |
+
will be similar in spirit to the present version, but may differ in detail to
|
566 |
+
address new problems or concerns.
|
567 |
+
|
568 |
+
Each version is given a distinguishing version number. If the
|
569 |
+
Program specifies that a certain numbered version of the GNU Affero General
|
570 |
+
Public License "or any later version" applies to it, you have the
|
571 |
+
option of following the terms and conditions either of that numbered
|
572 |
+
version or of any later version published by the Free Software
|
573 |
+
Foundation. If the Program does not specify a version number of the
|
574 |
+
GNU Affero General Public License, you may choose any version ever published
|
575 |
+
by the Free Software Foundation.
|
576 |
+
|
577 |
+
If the Program specifies that a proxy can decide which future
|
578 |
+
versions of the GNU Affero General Public License can be used, that proxy's
|
579 |
+
public statement of acceptance of a version permanently authorizes you
|
580 |
+
to choose that version for the Program.
|
581 |
+
|
582 |
+
Later license versions may give you additional or different
|
583 |
+
permissions. However, no additional obligations are imposed on any
|
584 |
+
author or copyright holder as a result of your choosing to follow a
|
585 |
+
later version.
|
586 |
+
|
587 |
+
15. Disclaimer of Warranty.
|
588 |
+
|
589 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
590 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
591 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
592 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
593 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
594 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
595 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
596 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
597 |
+
|
598 |
+
16. Limitation of Liability.
|
599 |
+
|
600 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
601 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
602 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
603 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
604 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
605 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
606 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
607 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
608 |
+
SUCH DAMAGES.
|
609 |
+
|
610 |
+
17. Interpretation of Sections 15 and 16.
|
611 |
+
|
612 |
+
If the disclaimer of warranty and limitation of liability provided
|
613 |
+
above cannot be given local legal effect according to their terms,
|
614 |
+
reviewing courts shall apply local law that most closely approximates
|
615 |
+
an absolute waiver of all civil liability in connection with the
|
616 |
+
Program, unless a warranty or assumption of liability accompanies a
|
617 |
+
copy of the Program in return for a fee.
|
618 |
+
|
619 |
+
END OF TERMS AND CONDITIONS
|
620 |
+
|
621 |
+
How to Apply These Terms to Your New Programs
|
622 |
+
|
623 |
+
If you develop a new program, and you want it to be of the greatest
|
624 |
+
possible use to the public, the best way to achieve this is to make it
|
625 |
+
free software which everyone can redistribute and change under these terms.
|
626 |
+
|
627 |
+
To do so, attach the following notices to the program. It is safest
|
628 |
+
to attach them to the start of each source file to most effectively
|
629 |
+
state the exclusion of warranty; and each file should have at least
|
630 |
+
the "copyright" line and a pointer to where the full notice is found.
|
631 |
+
|
632 |
+
<one line to give the program's name and a brief idea of what it does.>
|
633 |
+
Copyright (C) <year> <name of author>
|
634 |
+
|
635 |
+
This program is free software: you can redistribute it and/or modify
|
636 |
+
it under the terms of the GNU Affero General Public License as published
|
637 |
+
by the Free Software Foundation, either version 3 of the License, or
|
638 |
+
(at your option) any later version.
|
639 |
+
|
640 |
+
This program is distributed in the hope that it will be useful,
|
641 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
642 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
643 |
+
GNU Affero General Public License for more details.
|
644 |
+
|
645 |
+
You should have received a copy of the GNU Affero General Public License
|
646 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
647 |
+
|
648 |
+
Also add information on how to contact you by electronic and paper mail.
|
649 |
+
|
650 |
+
If your software can interact with users remotely through a computer
|
651 |
+
network, you should also make sure that it provides a way for users to
|
652 |
+
get its source. For example, if your program is a web application, its
|
653 |
+
interface could display a "Source" link that leads users to an archive
|
654 |
+
of the code. There are many ways you could offer source, and different
|
655 |
+
solutions will be better for different programs; see section 13 for the
|
656 |
+
specific requirements.
|
657 |
+
|
658 |
+
You should also get your employer (if you work as a programmer) or school,
|
659 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
660 |
+
For more information on this, and how to apply and follow the GNU AGPL, see
|
661 |
+
<https://www.gnu.org/licenses/>.
|
app/app_pixart_dmd.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
from __future__ import annotations
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
current_file_path = Path(__file__).resolve()
|
9 |
+
sys.path.insert(0, str(current_file_path.parent.parent))
|
10 |
+
import random
|
11 |
+
import gradio as gr
|
12 |
+
import numpy as np
|
13 |
+
import uuid
|
14 |
+
from diffusers import ConsistencyDecoderVAE, PixArtAlphaPipeline, Transformer2DModel, DDPMScheduler
|
15 |
+
import torch
|
16 |
+
from typing import Tuple
|
17 |
+
from datetime import datetime
|
18 |
+
from scripts.diffusers_patches import pipeline_pixart_alpha_call
|
19 |
+
|
20 |
+
DESCRIPTION = """![Logo](https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma-project/master/static/images/logo-sigma.png)
|
21 |
+
# PixArt-Alpha One Step 512px
|
22 |
+
#### [PixArt-Alpha-DMD 512px](https://github.com/PixArt-alpha/PixArt-sigma) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-Alpha-DMD-XL-2-512x512](https://huggingface.co/PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512) checkpoint.
|
23 |
+
#### English prompts ONLY; 提示词仅限英文
|
24 |
+
### <span style='color: red;'>We only use 8 V100 GPUs for PixArt-DMD training. There's still plenty of room for improvement.
|
25 |
+
"""
|
26 |
+
if not torch.cuda.is_available():
|
27 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
28 |
+
|
29 |
+
MAX_SEED = np.iinfo(np.int32).max
|
30 |
+
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
31 |
+
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "6000"))
|
32 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
33 |
+
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
34 |
+
PORT = int(os.getenv("DEMO_PORT", "15432"))
|
35 |
+
|
36 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
37 |
+
|
38 |
+
style_list = [
|
39 |
+
{
|
40 |
+
"name": "(No style)",
|
41 |
+
"prompt": "{prompt}",
|
42 |
+
"negative_prompt": "",
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"name": "Cinematic",
|
46 |
+
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
47 |
+
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"name": "Photographic",
|
51 |
+
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
52 |
+
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"name": "Anime",
|
56 |
+
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
57 |
+
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"name": "Manga",
|
61 |
+
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
62 |
+
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"name": "Digital Art",
|
66 |
+
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
67 |
+
"negative_prompt": "photo, photorealistic, realism, ugly",
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"name": "Pixel art",
|
71 |
+
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
72 |
+
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"name": "Fantasy art",
|
76 |
+
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
77 |
+
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"name": "Neonpunk",
|
81 |
+
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
82 |
+
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"name": "3D Model",
|
86 |
+
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
87 |
+
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
88 |
+
},
|
89 |
+
]
|
90 |
+
|
91 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
92 |
+
STYLE_NAMES = list(styles.keys())
|
93 |
+
DEFAULT_STYLE_NAME = "(No style)"
|
94 |
+
SCHEDULE_NAME = ["PixArt-DMD"]
|
95 |
+
DEFAULT_SCHEDULE_NAME = "PixArt-DMD"
|
96 |
+
NUM_IMAGES_PER_PROMPT = 2
|
97 |
+
|
98 |
+
|
99 |
+
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
|
100 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
101 |
+
if not negative:
|
102 |
+
negative = ""
|
103 |
+
return p.replace("{prompt}", positive), n + negative
|
104 |
+
|
105 |
+
|
106 |
+
def get_args():
|
107 |
+
parser = argparse.ArgumentParser()
|
108 |
+
parser.add_argument('--model_path', default="PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", type=str)
|
109 |
+
parser.add_argument(
|
110 |
+
'--pipeline_load_from', default="PixArt-alpha/PixArt-XL-2-1024-MS", type=str,
|
111 |
+
help="Download for loading text_encoder, "
|
112 |
+
"tokenizer and vae from https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS")
|
113 |
+
parser.add_argument('--T5_token_max_length', default=120, type=int, help='max length of tokens for T5')
|
114 |
+
return parser.parse_args()
|
115 |
+
|
116 |
+
|
117 |
+
args = get_args()
|
118 |
+
|
119 |
+
if torch.cuda.is_available():
|
120 |
+
weight_dtype = torch.float16
|
121 |
+
T5_token_max_length = args.T5_token_max_length
|
122 |
+
model_path = args.model_path
|
123 |
+
if 'Sigma' in args.model_path:
|
124 |
+
T5_token_max_length = 300
|
125 |
+
|
126 |
+
pipe = PixArtAlphaPipeline.from_pretrained(
|
127 |
+
args.pipeline_load_from,
|
128 |
+
transformer=None,
|
129 |
+
torch_dtype=weight_dtype,
|
130 |
+
)
|
131 |
+
pipe.transformer = Transformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=weight_dtype)
|
132 |
+
pipe.scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
133 |
+
|
134 |
+
print("Changing __call__ method of PixArtAlphaPipeline using scripts.diffusers_patches.pipeline_pixart_alpha_call")
|
135 |
+
setattr(PixArtAlphaPipeline, '__call__', pipeline_pixart_alpha_call)
|
136 |
+
|
137 |
+
if os.getenv('CONSISTENCY_DECODER', False):
|
138 |
+
print("Using DALL-E 3 Consistency Decoder")
|
139 |
+
pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
140 |
+
|
141 |
+
if ENABLE_CPU_OFFLOAD:
|
142 |
+
pipe.enable_model_cpu_offload()
|
143 |
+
else:
|
144 |
+
pipe.to(device)
|
145 |
+
print("Loaded on Device!")
|
146 |
+
|
147 |
+
# speed-up T5
|
148 |
+
pipe.text_encoder.to_bettertransformer()
|
149 |
+
|
150 |
+
if USE_TORCH_COMPILE:
|
151 |
+
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
|
152 |
+
print("Model Compiled!")
|
153 |
+
|
154 |
+
|
155 |
+
def save_image(img, seed=''):
|
156 |
+
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
157 |
+
save_path = os.path.join(f'output/online_demo_img/{datetime.now().date()}')
|
158 |
+
os.umask(0o000) # file permission: 666; dir permission: 777
|
159 |
+
os.makedirs(save_path, exist_ok=True)
|
160 |
+
unique_name = os.path.join(save_path, unique_name)
|
161 |
+
img.save(unique_name)
|
162 |
+
return unique_name
|
163 |
+
|
164 |
+
|
165 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
166 |
+
if randomize_seed:
|
167 |
+
seed = random.randint(0, MAX_SEED)
|
168 |
+
return seed
|
169 |
+
|
170 |
+
|
171 |
+
@torch.no_grad()
|
172 |
+
@torch.inference_mode()
|
173 |
+
def generate(
|
174 |
+
prompt: str,
|
175 |
+
negative_prompt: str = "",
|
176 |
+
style: str = DEFAULT_STYLE_NAME,
|
177 |
+
use_negative_prompt: bool = False,
|
178 |
+
num_imgs: int = 1,
|
179 |
+
seed: int = 0,
|
180 |
+
width: int = 1024,
|
181 |
+
height: int = 1024,
|
182 |
+
randomize_seed: bool = False,
|
183 |
+
use_resolution_binning: bool = True,
|
184 |
+
progress=gr.Progress(track_tqdm=True),
|
185 |
+
):
|
186 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
187 |
+
generator = torch.Generator().manual_seed(seed)
|
188 |
+
print(f"{PORT}: {model_path}")
|
189 |
+
print(prompt)
|
190 |
+
|
191 |
+
if not use_negative_prompt:
|
192 |
+
negative_prompt = None # type: ignore
|
193 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
194 |
+
|
195 |
+
images = pipe(
|
196 |
+
prompt=prompt,
|
197 |
+
timesteps=[400],
|
198 |
+
width=width,
|
199 |
+
height=height,
|
200 |
+
guidance_scale=1,
|
201 |
+
num_inference_steps=1,
|
202 |
+
generator=generator,
|
203 |
+
num_images_per_prompt=num_imgs,
|
204 |
+
use_resolution_binning=use_resolution_binning,
|
205 |
+
output_type="pil",
|
206 |
+
max_sequence_length=T5_token_max_length,
|
207 |
+
).images
|
208 |
+
|
209 |
+
image_paths = [save_image(img, seed) for img in images]
|
210 |
+
print(image_paths)
|
211 |
+
return image_paths, seed
|
212 |
+
|
213 |
+
|
214 |
+
examples = [
|
215 |
+
"A small cactus with a happy face in the Sahara desert.",
|
216 |
+
"an astronaut sitting in a diner, eating fries, cinematic, analog film",
|
217 |
+
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
218 |
+
"stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.",
|
219 |
+
"professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
220 |
+
"beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background",
|
221 |
+
"Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism",
|
222 |
+
"anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur",
|
223 |
+
"The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8",
|
224 |
+
]
|
225 |
+
|
226 |
+
with gr.Blocks(css="scripts/style.css") as demo:
|
227 |
+
gr.Markdown(DESCRIPTION)
|
228 |
+
gr.DuplicateButton(
|
229 |
+
value="Duplicate Space for private use",
|
230 |
+
elem_id="duplicate-button",
|
231 |
+
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
232 |
+
)
|
233 |
+
with gr.Row(equal_height=False):
|
234 |
+
with gr.Group():
|
235 |
+
with gr.Row():
|
236 |
+
prompt = gr.Text(
|
237 |
+
label="Prompt",
|
238 |
+
show_label=False,
|
239 |
+
max_lines=1,
|
240 |
+
placeholder="Enter your prompt",
|
241 |
+
container=False,
|
242 |
+
)
|
243 |
+
run_button = gr.Button("Run", scale=0)
|
244 |
+
result = gr.Gallery(label="Result", show_label=False)
|
245 |
+
# with gr.Accordion("Advanced options", open=False):
|
246 |
+
with gr.Group():
|
247 |
+
with gr.Row():
|
248 |
+
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
249 |
+
with gr.Row(visible=True):
|
250 |
+
schedule = gr.Radio(
|
251 |
+
show_label=True,
|
252 |
+
container=True,
|
253 |
+
interactive=True,
|
254 |
+
choices=SCHEDULE_NAME,
|
255 |
+
value=DEFAULT_SCHEDULE_NAME,
|
256 |
+
label="Sampler Schedule",
|
257 |
+
visible=True,
|
258 |
+
)
|
259 |
+
num_imgs = gr.Slider(
|
260 |
+
label="Num Images",
|
261 |
+
minimum=1,
|
262 |
+
maximum=8,
|
263 |
+
step=1,
|
264 |
+
value=NUM_IMAGES_PER_PROMPT,
|
265 |
+
)
|
266 |
+
style_selection = gr.Radio(
|
267 |
+
show_label=True,
|
268 |
+
container=True,
|
269 |
+
interactive=True,
|
270 |
+
choices=STYLE_NAMES,
|
271 |
+
value=DEFAULT_STYLE_NAME,
|
272 |
+
label="Image Style",
|
273 |
+
)
|
274 |
+
negative_prompt = gr.Text(
|
275 |
+
label="Negative prompt",
|
276 |
+
max_lines=1,
|
277 |
+
placeholder="Enter a negative prompt",
|
278 |
+
visible=True,
|
279 |
+
)
|
280 |
+
seed = gr.Slider(
|
281 |
+
label="Seed",
|
282 |
+
minimum=0,
|
283 |
+
maximum=MAX_SEED,
|
284 |
+
step=1,
|
285 |
+
value=0,
|
286 |
+
)
|
287 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
288 |
+
with gr.Row(visible=True):
|
289 |
+
width = gr.Slider(
|
290 |
+
label="Width",
|
291 |
+
minimum=256,
|
292 |
+
maximum=MAX_IMAGE_SIZE,
|
293 |
+
step=32,
|
294 |
+
value=512,
|
295 |
+
)
|
296 |
+
height = gr.Slider(
|
297 |
+
label="Height",
|
298 |
+
minimum=256,
|
299 |
+
maximum=MAX_IMAGE_SIZE,
|
300 |
+
step=32,
|
301 |
+
value=512,
|
302 |
+
)
|
303 |
+
|
304 |
+
gr.Examples(
|
305 |
+
examples=examples,
|
306 |
+
inputs=prompt,
|
307 |
+
outputs=[result, seed],
|
308 |
+
fn=generate,
|
309 |
+
cache_examples=CACHE_EXAMPLES,
|
310 |
+
)
|
311 |
+
|
312 |
+
use_negative_prompt.change(
|
313 |
+
fn=lambda x: gr.update(visible=x),
|
314 |
+
inputs=use_negative_prompt,
|
315 |
+
outputs=negative_prompt,
|
316 |
+
api_name=False,
|
317 |
+
)
|
318 |
+
|
319 |
+
gr.on(
|
320 |
+
triggers=[
|
321 |
+
prompt.submit,
|
322 |
+
negative_prompt.submit,
|
323 |
+
run_button.click,
|
324 |
+
],
|
325 |
+
fn=generate,
|
326 |
+
inputs=[
|
327 |
+
prompt,
|
328 |
+
negative_prompt,
|
329 |
+
style_selection,
|
330 |
+
use_negative_prompt,
|
331 |
+
num_imgs,
|
332 |
+
seed,
|
333 |
+
width,
|
334 |
+
height,
|
335 |
+
schedule,
|
336 |
+
randomize_seed,
|
337 |
+
],
|
338 |
+
outputs=[result, seed],
|
339 |
+
api_name="run",
|
340 |
+
)
|
341 |
+
|
342 |
+
if __name__ == "__main__":
|
343 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=PORT, debug=True)
|
app/app_pixart_sigma.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
from __future__ import annotations
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
from pathlib import Path
|
7 |
+
current_file_path = Path(__file__).resolve()
|
8 |
+
sys.path.insert(0, str(current_file_path.parent.parent))
|
9 |
+
import random
|
10 |
+
import gradio as gr
|
11 |
+
import numpy as np
|
12 |
+
import uuid
|
13 |
+
from diffusers import ConsistencyDecoderVAE, DPMSolverMultistepScheduler, Transformer2DModel, AutoencoderKL
|
14 |
+
import torch
|
15 |
+
from typing import Tuple
|
16 |
+
from datetime import datetime
|
17 |
+
from diffusion.sa_solver_diffusers import SASolverScheduler
|
18 |
+
from peft import PeftModel
|
19 |
+
from scripts.diffusers_patches import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline
|
20 |
+
|
21 |
+
|
22 |
+
DESCRIPTION = """![Logo](https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma-project/master/static/images/logo-sigma.png)
|
23 |
+
# PixArt-Sigma 1024px
|
24 |
+
#### [PixArt-Sigma 1024px](https://github.com/PixArt-alpha/PixArt-sigma) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-alpha/PixArt-XL-2-1024-MS](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS) checkpoint.
|
25 |
+
#### English prompts ONLY; 提示词仅限英文
|
26 |
+
### <span style='color: red;'>You may change the DPM-Solver inference steps from 14 to 20, or DPM-Solver Guidance scale from 4.5 to 3.5 if you didn't get satisfied results.
|
27 |
+
"""
|
28 |
+
if not torch.cuda.is_available():
|
29 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
30 |
+
|
31 |
+
MAX_SEED = np.iinfo(np.int32).max
|
32 |
+
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
33 |
+
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "6000"))
|
34 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
35 |
+
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
36 |
+
PORT = int(os.getenv("DEMO_PORT", "15432"))
|
37 |
+
|
38 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
39 |
+
|
40 |
+
|
41 |
+
style_list = [
|
42 |
+
{
|
43 |
+
"name": "(No style)",
|
44 |
+
"prompt": "{prompt}",
|
45 |
+
"negative_prompt": "",
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"name": "Cinematic",
|
49 |
+
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
50 |
+
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"name": "Photographic",
|
54 |
+
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
55 |
+
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"name": "Anime",
|
59 |
+
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
60 |
+
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"name": "Manga",
|
64 |
+
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
65 |
+
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"name": "Digital Art",
|
69 |
+
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
70 |
+
"negative_prompt": "photo, photorealistic, realism, ugly",
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"name": "Pixel art",
|
74 |
+
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
75 |
+
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"name": "Fantasy art",
|
79 |
+
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
80 |
+
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"name": "Neonpunk",
|
84 |
+
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
85 |
+
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"name": "3D Model",
|
89 |
+
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
90 |
+
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
91 |
+
},
|
92 |
+
]
|
93 |
+
|
94 |
+
|
95 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
96 |
+
STYLE_NAMES = list(styles.keys())
|
97 |
+
DEFAULT_STYLE_NAME = "(No style)"
|
98 |
+
SCHEDULE_NAME = ["DPM-Solver", "SA-Solver"]
|
99 |
+
DEFAULT_SCHEDULE_NAME = "DPM-Solver"
|
100 |
+
NUM_IMAGES_PER_PROMPT = 1
|
101 |
+
|
102 |
+
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
|
103 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
104 |
+
if not negative:
|
105 |
+
negative = ""
|
106 |
+
return p.replace("{prompt}", positive), n + negative
|
107 |
+
|
108 |
+
|
109 |
+
def get_args():
|
110 |
+
parser = argparse.ArgumentParser()
|
111 |
+
parser.add_argument('--is_lora', action='store_true', help='enable lora ckpt loading')
|
112 |
+
parser.add_argument('--repo_id', default="PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", type=str)
|
113 |
+
parser.add_argument('--lora_repo_id', default=None, type=str)
|
114 |
+
parser.add_argument('--model_path', default=None, type=str)
|
115 |
+
parser.add_argument(
|
116 |
+
'--pipeline_load_from', default="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", type=str,
|
117 |
+
help="Download for loading text_encoder, tokenizer and vae "
|
118 |
+
"from https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS")
|
119 |
+
parser.add_argument('--T5_token_max_length', default=120, type=int, help='max length of tokens for T5')
|
120 |
+
return parser.parse_args()
|
121 |
+
|
122 |
+
|
123 |
+
args = get_args()
|
124 |
+
|
125 |
+
if torch.cuda.is_available():
|
126 |
+
weight_dtype = torch.float16
|
127 |
+
T5_token_max_length = args.T5_token_max_length
|
128 |
+
model_path = args.model_path
|
129 |
+
if 'Sigma' in args.model_path:
|
130 |
+
T5_token_max_length = 300
|
131 |
+
|
132 |
+
# tmp patches for diffusers PixArtSigmaPipeline Implementation
|
133 |
+
print(
|
134 |
+
"Changing _init_patched_inputs method of diffusers.models.Transformer2DModel "
|
135 |
+
"using scripts.diffusers_patches.pixart_sigma_init_patched_inputs")
|
136 |
+
setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
|
137 |
+
|
138 |
+
if not args.is_lora:
|
139 |
+
transformer = Transformer2DModel.from_pretrained(
|
140 |
+
model_path,
|
141 |
+
subfolder='transformer',
|
142 |
+
torch_dtype=weight_dtype,
|
143 |
+
)
|
144 |
+
pipe = PixArtSigmaPipeline.from_pretrained(
|
145 |
+
args.pipeline_load_from,
|
146 |
+
transformer=transformer,
|
147 |
+
torch_dtype=weight_dtype,
|
148 |
+
use_safetensors=True,
|
149 |
+
)
|
150 |
+
else:
|
151 |
+
assert args.lora_repo_id is not None
|
152 |
+
transformer = Transformer2DModel.from_pretrained(args.repo_id, subfolder="transformer", torch_dtype=torch.float16)
|
153 |
+
transformer = PeftModel.from_pretrained(transformer, args.lora_repo_id)
|
154 |
+
pipe = PixArtSigmaPipeline.from_pretrained(
|
155 |
+
args.repo_id,
|
156 |
+
transformer=transformer,
|
157 |
+
torch_dtype=torch.float16,
|
158 |
+
use_safetensors=True,
|
159 |
+
)
|
160 |
+
del transformer
|
161 |
+
|
162 |
+
|
163 |
+
if os.getenv('CONSISTENCY_DECODER', False):
|
164 |
+
print("Using DALL-E 3 Consistency Decoder")
|
165 |
+
pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
166 |
+
|
167 |
+
if ENABLE_CPU_OFFLOAD:
|
168 |
+
pipe.enable_model_cpu_offload()
|
169 |
+
else:
|
170 |
+
pipe.to(device)
|
171 |
+
print("Loaded on Device!")
|
172 |
+
|
173 |
+
# speed-up T5
|
174 |
+
pipe.text_encoder.to_bettertransformer()
|
175 |
+
|
176 |
+
if USE_TORCH_COMPILE:
|
177 |
+
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
|
178 |
+
print("Model Compiled!")
|
179 |
+
|
180 |
+
|
181 |
+
def save_image(img, seed=''):
|
182 |
+
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
183 |
+
save_path = os.path.join(f'output/online_demo_img/{datetime.now().date()}')
|
184 |
+
os.umask(0o000) # file permission: 666; dir permission: 777
|
185 |
+
os.makedirs(save_path, exist_ok=True)
|
186 |
+
unique_name = os.path.join(save_path, unique_name)
|
187 |
+
img.save(unique_name)
|
188 |
+
return unique_name
|
189 |
+
|
190 |
+
|
191 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
192 |
+
if randomize_seed:
|
193 |
+
seed = random.randint(0, MAX_SEED)
|
194 |
+
return seed
|
195 |
+
|
196 |
+
|
197 |
+
@torch.no_grad()
|
198 |
+
@torch.inference_mode()
|
199 |
+
def generate(
|
200 |
+
prompt: str,
|
201 |
+
negative_prompt: str = "",
|
202 |
+
style: str = DEFAULT_STYLE_NAME,
|
203 |
+
use_negative_prompt: bool = False,
|
204 |
+
num_imgs: int = 1,
|
205 |
+
seed: int = 0,
|
206 |
+
width: int = 1024,
|
207 |
+
height: int = 1024,
|
208 |
+
schedule: str = 'DPM-Solver',
|
209 |
+
dpms_guidance_scale: float = 4.5,
|
210 |
+
sas_guidance_scale: float = 3,
|
211 |
+
dpms_inference_steps: int = 20,
|
212 |
+
sas_inference_steps: int = 25,
|
213 |
+
randomize_seed: bool = False,
|
214 |
+
use_resolution_binning: bool = True,
|
215 |
+
progress=gr.Progress(track_tqdm=True),
|
216 |
+
):
|
217 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
218 |
+
generator = torch.Generator().manual_seed(seed)
|
219 |
+
print(f"{PORT}: {model_path}")
|
220 |
+
print(prompt)
|
221 |
+
|
222 |
+
if schedule == 'DPM-Solver':
|
223 |
+
if not isinstance(pipe.scheduler, DPMSolverMultistepScheduler):
|
224 |
+
pipe.scheduler = DPMSolverMultistepScheduler()
|
225 |
+
num_inference_steps = dpms_inference_steps
|
226 |
+
guidance_scale = dpms_guidance_scale
|
227 |
+
elif schedule == "SA-Solver":
|
228 |
+
if not isinstance(pipe.scheduler, SASolverScheduler):
|
229 |
+
pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config, algorithm_type='data_prediction', tau_func=lambda t: 1 if 200 <= t <= 800 else 0, predictor_order=2, corrector_order=2)
|
230 |
+
num_inference_steps = sas_inference_steps
|
231 |
+
guidance_scale = sas_guidance_scale
|
232 |
+
else:
|
233 |
+
raise ValueError(f"Unknown schedule: {schedule}")
|
234 |
+
|
235 |
+
if not use_negative_prompt:
|
236 |
+
negative_prompt = None # type: ignore
|
237 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
238 |
+
|
239 |
+
images = pipe(
|
240 |
+
prompt=prompt,
|
241 |
+
width=width,
|
242 |
+
height=height,
|
243 |
+
guidance_scale=guidance_scale,
|
244 |
+
num_inference_steps=num_inference_steps,
|
245 |
+
generator=generator,
|
246 |
+
num_images_per_prompt=num_imgs,
|
247 |
+
use_resolution_binning=use_resolution_binning,
|
248 |
+
output_type="pil",
|
249 |
+
max_sequence_length=args.T5_token_max_length,
|
250 |
+
).images
|
251 |
+
|
252 |
+
image_paths = [save_image(img, seed) for img in images]
|
253 |
+
print(image_paths)
|
254 |
+
return image_paths, seed
|
255 |
+
|
256 |
+
|
257 |
+
examples = [
|
258 |
+
"A small cactus with a happy face in the Sahara desert.",
|
259 |
+
"an astronaut sitting in a diner, eating fries, cinematic, analog film",
|
260 |
+
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
261 |
+
"stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.",
|
262 |
+
"professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
263 |
+
"beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background",
|
264 |
+
"Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism",
|
265 |
+
"anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur",
|
266 |
+
"The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8",
|
267 |
+
]
|
268 |
+
|
269 |
+
with gr.Blocks(css="scripts/style.css") as demo:
|
270 |
+
gr.Markdown(DESCRIPTION)
|
271 |
+
gr.DuplicateButton(
|
272 |
+
value="Duplicate Space for private use",
|
273 |
+
elem_id="duplicate-button",
|
274 |
+
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
275 |
+
)
|
276 |
+
with gr.Row(equal_height=False):
|
277 |
+
with gr.Group():
|
278 |
+
with gr.Row():
|
279 |
+
prompt = gr.Text(
|
280 |
+
label="Prompt",
|
281 |
+
show_label=False,
|
282 |
+
max_lines=1,
|
283 |
+
placeholder="Enter your prompt",
|
284 |
+
container=False,
|
285 |
+
)
|
286 |
+
run_button = gr.Button("Run", scale=0)
|
287 |
+
result = gr.Gallery(label="Result", show_label=False)
|
288 |
+
# with gr.Accordion("Advanced options", open=False):
|
289 |
+
with gr.Group():
|
290 |
+
with gr.Row():
|
291 |
+
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
292 |
+
with gr.Row(visible=True):
|
293 |
+
schedule = gr.Radio(
|
294 |
+
show_label=True,
|
295 |
+
container=True,
|
296 |
+
interactive=True,
|
297 |
+
choices=SCHEDULE_NAME,
|
298 |
+
value=DEFAULT_SCHEDULE_NAME,
|
299 |
+
label="Sampler Schedule",
|
300 |
+
visible=True,
|
301 |
+
)
|
302 |
+
num_imgs = gr.Slider(
|
303 |
+
label="Num Images",
|
304 |
+
minimum=1,
|
305 |
+
maximum=8,
|
306 |
+
step=1,
|
307 |
+
value=1,
|
308 |
+
)
|
309 |
+
style_selection = gr.Radio(
|
310 |
+
show_label=True,
|
311 |
+
container=True,
|
312 |
+
interactive=True,
|
313 |
+
choices=STYLE_NAMES,
|
314 |
+
value=DEFAULT_STYLE_NAME,
|
315 |
+
label="Image Style",
|
316 |
+
)
|
317 |
+
negative_prompt = gr.Text(
|
318 |
+
label="Negative prompt",
|
319 |
+
max_lines=1,
|
320 |
+
placeholder="Enter a negative prompt",
|
321 |
+
visible=True,
|
322 |
+
)
|
323 |
+
seed = gr.Slider(
|
324 |
+
label="Seed",
|
325 |
+
minimum=0,
|
326 |
+
maximum=MAX_SEED,
|
327 |
+
step=1,
|
328 |
+
value=0,
|
329 |
+
)
|
330 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
331 |
+
with gr.Row(visible=True):
|
332 |
+
width = gr.Slider(
|
333 |
+
label="Width",
|
334 |
+
minimum=256,
|
335 |
+
maximum=MAX_IMAGE_SIZE,
|
336 |
+
step=32,
|
337 |
+
value=1024,
|
338 |
+
)
|
339 |
+
height = gr.Slider(
|
340 |
+
label="Height",
|
341 |
+
minimum=256,
|
342 |
+
maximum=MAX_IMAGE_SIZE,
|
343 |
+
step=32,
|
344 |
+
value=1024,
|
345 |
+
)
|
346 |
+
with gr.Row():
|
347 |
+
dpms_guidance_scale = gr.Slider(
|
348 |
+
label="DPM-Solver Guidance scale",
|
349 |
+
minimum=1,
|
350 |
+
maximum=10,
|
351 |
+
step=0.1,
|
352 |
+
value=4.5,
|
353 |
+
)
|
354 |
+
dpms_inference_steps = gr.Slider(
|
355 |
+
label="DPM-Solver inference steps",
|
356 |
+
minimum=5,
|
357 |
+
maximum=40,
|
358 |
+
step=1,
|
359 |
+
value=14,
|
360 |
+
)
|
361 |
+
with gr.Row():
|
362 |
+
sas_guidance_scale = gr.Slider(
|
363 |
+
label="SA-Solver Guidance scale",
|
364 |
+
minimum=1,
|
365 |
+
maximum=10,
|
366 |
+
step=0.1,
|
367 |
+
value=3,
|
368 |
+
)
|
369 |
+
sas_inference_steps = gr.Slider(
|
370 |
+
label="SA-Solver inference steps",
|
371 |
+
minimum=10,
|
372 |
+
maximum=40,
|
373 |
+
step=1,
|
374 |
+
value=25,
|
375 |
+
)
|
376 |
+
|
377 |
+
gr.Examples(
|
378 |
+
examples=examples,
|
379 |
+
inputs=prompt,
|
380 |
+
outputs=[result, seed],
|
381 |
+
fn=generate,
|
382 |
+
cache_examples=CACHE_EXAMPLES,
|
383 |
+
)
|
384 |
+
|
385 |
+
use_negative_prompt.change(
|
386 |
+
fn=lambda x: gr.update(visible=x),
|
387 |
+
inputs=use_negative_prompt,
|
388 |
+
outputs=negative_prompt,
|
389 |
+
api_name=False,
|
390 |
+
)
|
391 |
+
|
392 |
+
gr.on(
|
393 |
+
triggers=[
|
394 |
+
prompt.submit,
|
395 |
+
negative_prompt.submit,
|
396 |
+
run_button.click,
|
397 |
+
],
|
398 |
+
fn=generate,
|
399 |
+
inputs=[
|
400 |
+
prompt,
|
401 |
+
negative_prompt,
|
402 |
+
style_selection,
|
403 |
+
use_negative_prompt,
|
404 |
+
num_imgs,
|
405 |
+
seed,
|
406 |
+
width,
|
407 |
+
height,
|
408 |
+
schedule,
|
409 |
+
dpms_guidance_scale,
|
410 |
+
sas_guidance_scale,
|
411 |
+
dpms_inference_steps,
|
412 |
+
sas_inference_steps,
|
413 |
+
randomize_seed,
|
414 |
+
],
|
415 |
+
outputs=[result, seed],
|
416 |
+
api_name="run",
|
417 |
+
)
|
418 |
+
|
419 |
+
if __name__ == "__main__":
|
420 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=PORT, debug=True)
|
asset/PixArt.svg
ADDED
asset/docs/pixart.md
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
4 |
+
the License. You may obtain a copy of the License at
|
5 |
+
|
6 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
7 |
+
|
8 |
+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
9 |
+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
10 |
+
specific language governing permissions and limitations under the License.
|
11 |
+
-->
|
12 |
+
|
13 |
+
[//]: # ((reference from [hugging Face](https://github.com/huggingface/diffusers/blob/docs/8bit-inference-pixart/docs/source/en/api/pipelines/pixart.md)))
|
14 |
+
|
15 |
+
## Running the `PixArtAlphaPipeline` in under 8GB GPU VRAM
|
16 |
+
|
17 |
+
It is possible to run the [`PixArtAlphaPipeline`] under 8GB GPU VRAM by loading the text encoder in 8-bit numerical precision. Let's walk through a full-fledged example.
|
18 |
+
|
19 |
+
First, install the `bitsandbytes` library:
|
20 |
+
|
21 |
+
```bash
|
22 |
+
pip install -U bitsandbytes
|
23 |
+
```
|
24 |
+
|
25 |
+
Then load the text encoder in 8-bit:
|
26 |
+
|
27 |
+
```python
|
28 |
+
from transformers import T5EncoderModel
|
29 |
+
from diffusers import PixArtAlphaPipeline
|
30 |
+
|
31 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
32 |
+
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
33 |
+
subfolder="text_encoder",
|
34 |
+
load_in_8bit=True,
|
35 |
+
device_map="auto",
|
36 |
+
|
37 |
+
)
|
38 |
+
pipe = PixArtAlphaPipeline.from_pretrained(
|
39 |
+
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
40 |
+
text_encoder=text_encoder,
|
41 |
+
transformer=None,
|
42 |
+
device_map="auto"
|
43 |
+
)
|
44 |
+
```
|
45 |
+
|
46 |
+
Now, use the `pipe` to encode a prompt:
|
47 |
+
|
48 |
+
```python
|
49 |
+
with torch.no_grad():
|
50 |
+
prompt = "cute cat"
|
51 |
+
prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)
|
52 |
+
|
53 |
+
del text_encoder
|
54 |
+
del pipe
|
55 |
+
flush()
|
56 |
+
```
|
57 |
+
|
58 |
+
`flush()` is just a utility function to clear the GPU VRAM and is implemented like so:
|
59 |
+
|
60 |
+
```python
|
61 |
+
import gc
|
62 |
+
|
63 |
+
def flush():
|
64 |
+
gc.collect()
|
65 |
+
torch.cuda.empty_cache()
|
66 |
+
```
|
67 |
+
|
68 |
+
Then compute the latents providing the prompt embeddings as inputs:
|
69 |
+
|
70 |
+
```python
|
71 |
+
pipe = PixArtAlphaPipeline.from_pretrained(
|
72 |
+
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
73 |
+
text_encoder=None,
|
74 |
+
torch_dtype=torch.float16,
|
75 |
+
).to("cuda")
|
76 |
+
|
77 |
+
latents = pipe(
|
78 |
+
negative_prompt=None,
|
79 |
+
prompt_embeds=prompt_embeds,
|
80 |
+
negative_prompt_embeds=negative_embeds,
|
81 |
+
prompt_attention_mask=prompt_attention_mask,
|
82 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
83 |
+
num_images_per_prompt=1,
|
84 |
+
output_type="latent",
|
85 |
+
).images
|
86 |
+
|
87 |
+
del pipe.transformer
|
88 |
+
flush()
|
89 |
+
```
|
90 |
+
|
91 |
+
Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
|
92 |
+
|
93 |
+
Once the latents are computed, pass it off the VAE to decode into a real image:
|
94 |
+
|
95 |
+
```python
|
96 |
+
with torch.no_grad():
|
97 |
+
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
|
98 |
+
image = pipe.image_processor.postprocess(image, output_type="pil")
|
99 |
+
image.save("cat.png")
|
100 |
+
```
|
101 |
+
|
102 |
+
All of this, put together, should allow you to run [`PixArtAlphaPipeline`] under 8GB GPU VRAM.
|
103 |
+
|
104 |
+
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png)
|
105 |
+
|
106 |
+
Find the script [here](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e) that can be run end-to-end to report the memory being used.
|
107 |
+
|
108 |
+
<Tip warning={true}>
|
109 |
+
|
110 |
+
Text embeddings computed in 8-bit can have an impact on the quality of the generated images because of the information loss in the representation space induced by the reduced precision. It's recommended to compare the outputs with and without 8-bit.
|
111 |
+
|
112 |
+
</Tip>
|
asset/examples.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
examples = [
|
3 |
+
[
|
4 |
+
"A small cactus with a happy face in the Sahara desert.",
|
5 |
+
"dpm-solver", 20, 4.5,
|
6 |
+
],
|
7 |
+
[
|
8 |
+
"An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history"
|
9 |
+
"of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits "
|
10 |
+
"mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret "
|
11 |
+
"and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile "
|
12 |
+
"as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and "
|
13 |
+
"the Parisian streets and city in the background, depth of field, cinematic 35mm film.",
|
14 |
+
"dpm-solver", 20, 4.5,
|
15 |
+
],
|
16 |
+
[
|
17 |
+
"An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. "
|
18 |
+
"Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. "
|
19 |
+
"The quote 'Find the universe within you' is etched in bold letters across the horizon."
|
20 |
+
"blue and pink, brilliantly illuminated in the background.",
|
21 |
+
"dpm-solver", 20, 4.5,
|
22 |
+
],
|
23 |
+
[
|
24 |
+
"A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.",
|
25 |
+
"dpm-solver", 20, 4.5,
|
26 |
+
],
|
27 |
+
[
|
28 |
+
"A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.",
|
29 |
+
"dpm-solver", 20, 4.5,
|
30 |
+
],
|
31 |
+
[
|
32 |
+
"a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, "
|
33 |
+
"national geographic photo, 8k resolution, crayon art, interactive artwork",
|
34 |
+
"dpm-solver", 20, 4.5,
|
35 |
+
]
|
36 |
+
]
|
asset/logo-sigma.png
ADDED
asset/logo.png
ADDED
asset/samples.txt
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
A small cactus with a happy face in the Sahara desert.
|
2 |
+
Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.
|
3 |
+
beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background
|
4 |
+
stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.
|
5 |
+
nature vs human nature, surreal, UHD, 8k, hyper details, rich colors, photograph.
|
6 |
+
Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism
|
7 |
+
anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur
|
8 |
+
The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8
|
9 |
+
Bright scene, aerial view, ancient city, fantasy, gorgeous light, mirror reflection, high detail, wide angle lens.
|
10 |
+
8k uhd A man looks up at the starry sky, lonely and ethereal, Minimalism, Chaotic composition Op Art
|
11 |
+
A middle-aged woman of Asian descent, her dark hair streaked with silver, appears fractured and splintered, intricately embedded within a sea of broken porcelain. The porcelain glistens with splatter paint patterns in a harmonious blend of glossy and matte blues, greens, oranges, and reds, capturing her dance in a surreal juxtaposition of movement and stillness. Her skin tone, a light hue like the porcelain, adds an almost mystical quality to her form.
|
12 |
+
A 4k dslr image of a lemur wearing a red magician hat and a blue coat performing magic tricks with cards in a garden.
|
13 |
+
A alpaca made of colorful building blocks, cyberpunk
|
14 |
+
A baby painter trying to draw very simple picture, white background
|
15 |
+
A boy and a girl fall in love
|
16 |
+
A dog that has been meditating all the time
|
17 |
+
A man is sitting in a chair with his chin resting on his hand. The chair, along with the man's feet, are submerged in the sea. Strikingly, the man's back is on fire.
|
18 |
+
A painter study hard to learn how to draw with many concepts in the air, white background
|
19 |
+
A painter with low quality, white background, pixel art
|
20 |
+
A person standing on the desert, desert waves, gossip illustration, half red, half blue, abstract image of sand, clear style, trendy illustration, outdoor, top view, clear style, precision art, ultra high definition image
|
21 |
+
A silhouette of a grand piano overlooking a dusky cityscape viewed from a top-floor penthouse, rendered in the bold and vivid sytle of a vintage travel poster.
|
22 |
+
A sureal parallel world where mankind avoid extinction by preserving nature, epic trees, water streams, various flowers, intricate details, rich colors, rich vegetation, cinematic, symmetrical, beautiful lighting, V-Ray render, sun rays, magical lights, photography
|
23 |
+
A woman is shopping for fresh produce at the farmer's market.
|
24 |
+
A worker that looks like a mixture of cow and horse is working hard to type code
|
25 |
+
A young man dressed in ancient Chinese clothing, Asian people, White robe, Handsome, Hand gestures forming a spell, Martial arts and fairy-like vibe, Carrying a legendary-level giant sword on the back, Game character, Surrounded by runes, Cyberpunk style, neon lights, best quality, masterpiece, cg, hdr, high-definition, extremely detailed, photorealistic, epic, character design, detailed face, superhero, hero, detailed UHD, real-time, vfx, 3D rendering, 8k
|
26 |
+
An alien octopus floats through a protal reading a newspaper
|
27 |
+
An epressive oil painting of a basketbal player dunking, depicted as an explosion of a nebula
|
28 |
+
art collection style and fashion shoot, in the style of made of glass, dark blue and light pink, paul rand, solarpunk, camille vivier, beth didonato hair, barbiecore, hyper-realistic
|
29 |
+
artistic
|
30 |
+
beautiful secen
|
31 |
+
Crocodile in a sweater
|
32 |
+
Design a letter A, 3D stereoscopic Ice material Interior light blue Conceptual product design Futuristic Blind box toy Handcrafted Exquisite 3D effect Full body display Ultra-high precision Ultra-detailed Perfect lighting OC Renderer Blender 8k Ultra-sharp Ultra-noise reduction
|
33 |
+
Floating,colossal,futuristic statue in the sky, awe-inspiring and serenein the style of Stuart Lippincott:2with detailed composition and subtle geometric elements.This sanctuary-ike atmosphere features crisp clarity and soft amber tones.In contrasttiny human figures surround the statueThe pieceincorporates flowing draperiesreminiscent of Shwedoff and Philip McKay's stylesemphasizing thejuxtaposition between the powerful presence of the statue and thevulnerability of the minuscule human figuresshwedoff
|
34 |
+
knolling of a drawing tools for painter
|
35 |
+
Leonardo da Vinci's Last Supper content, Van Goph's Starry Night Style
|
36 |
+
Luffy from ONEPIECE, handsome face, fantasy
|
37 |
+
photography shot through an outdoor window of a coffee shop with neon sign lighting, window glares and reflections, depth of field, {little girl with red hair sitting at a table, portrait, kodak portra 800,105 mm f1.8
|
38 |
+
poster of a mechanical cat, techical Schematics viewed from front and side view on light white blueprint paper, illustartion drafting style, illustation, typography, conceptual art, dark fantasy steampunk, cinematic, dark fantasy
|
39 |
+
The girl in the car is filled with goldfish and flowers, goldfish can fly, Kawaguchi Renko's art, natural posture, holiday dadcore, youthful energy and pressure, body stretching, goldfish simulation movies in the sky, super details, and dreamy high photography. Colorful. Covered by water and goldfish, indoor scene, close-up shot in XT4 movie
|
40 |
+
The image features a woman wearing a red shirt with an icon. She appears to be posing for the camera, and her outfit includes a pair of jeans. The woman seems to be in a good mood, as she is smiling. The background of the image is blurry, focusing more on the woman and her attire.
|
41 |
+
The towel was on top of the hard counter.
|
42 |
+
A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds.
|
43 |
+
I want to supplement vitamin c, please help me paint related food.
|
44 |
+
A vibrant yellow banana-shaped couch sits in a cozy living room, its curve cradling a pile of colorful cushions. on the wooden floor, a patterned rug adds a touch of eclectic charm, and a potted plant sits in the corner, reaching towards the sunlight filtering through the window.
|
45 |
+
A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.
|
46 |
+
A blue jay standing on a large basket of rainbow macarons.
|
47 |
+
A bucket bag made of blue suede. The bag is decorated with intricate golden paisley patterns. The handle of the bag is made of rubies and pearls.
|
48 |
+
An alien octopus floats through a portal reading a newspaper.
|
49 |
+
bird's eye view of a city.
|
50 |
+
beautiful scene
|
51 |
+
A 2D animation of a folk music band composed of anthropomorphic autumn leaves, each playing traditional bluegrass instruments, amidst a rustic forest setting dappled with the soft light of a harvest moon.
|
52 |
+
In front of a deep black backdrop, a figure of middle years, her Tongan skin rich and glowing, is captured mid-twirl, her curly hair flowing like a storm behind her. Her attire resembles a whirlwind of marble and porcelain fragments. Illuminated by the gleam of scattered porcelain shards, creating a dreamlike atmosphere, the dancer manages to appear fragmented, yet maintains a harmonious and fluid form.
|
53 |
+
Digital illustration of a beach scene crafted from yarn. The sandy beach is depicted with beige yarn, waves are made of blue and white yarn crashing onto the shore. A yarn sun sets on the horizon, casting a warm glow. Yarn palm trees sway gently, and little yarn seashells dot the shoreline.
|
54 |
+
Illustration of a chic chair with a design reminiscent of a pumpkin’s form, with deep orange cushioning, in a stylish loft setting.
|
55 |
+
A detailed oil painting of an old sea captain, steering his ship through a storm. Saltwater is splashing against his weathered face, determination in his eyes. Twirling malevolent clouds are seen above and stern waves threaten to submerge the ship while seagulls dive and twirl through the chaotic landscape. Thunder and lights embark in the distance, illuminating the scene with an eerie green glow.
|
56 |
+
An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. The quote 'Find the universe within you' is etched in bold letters across the horizon.
|
57 |
+
A modern architectural building with large glass windows, situated on a cliff overlooking a serene ocean at sunset
|
58 |
+
photo of an ancient shipwreck nestled on the ocean floor. Marine plants have claimed the wooden structure, and fish swim in and out of its hollow spaces. Sunken treasures and old cannons are scattered around, providing a glimpse into the past
|
59 |
+
A 3D render of a coffee mug placed on a window sill during a stormy day. The storm outside the window is reflected in the coffee, with miniature lightning bolts and turbulent waves seen inside the mug. The room is dimly lit, adding to the dramatic atmosphere.A minimap diorama of a cafe adorned with indoor plants. Wooden beams crisscross above, and a cold brew station stands out with tiny bottles and glasses.
|
60 |
+
An antique botanical illustration drawn with fine lines and a touch of watercolour whimsy, depicting a strange lily crossed with a Venus flytrap, its petals poised as if ready to snap shut on any unsuspecting insects.An illustration inspired by old-world botanical sketches blends a cactus with lilac blooms into a Möbius strip, using detailed lines and subtle watercolor touches to capture nature's diverse beauty and mathematical intrigue.
|
61 |
+
An ink sketch style illustration of a small hedgehog holding a piece of watermelon with its tiny paws, taking little bites with its eyes closed in delight.Photo of a lychee-inspired spherical chair, with a bumpy white exterior and plush interior, set against a tropical wallpaper.
|
62 |
+
3d digital art of an adorable ghost, glowing within, holding a heart shaped pumpkin, Halloween, super cute, spooky haunted house background
|
63 |
+
professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.
|
64 |
+
an astronaut sitting in a diner, eating fries, cinematic, analog film
|
65 |
+
Chinese architecture, ancient style,mountain, bird, lotus, pond, big tree, 4K Unity, octane rendering.
|
66 |
+
Ethereal fantasy concept art of thunder god with hammer. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy.
|
67 |
+
A Japanese girl walking along a path, surrounding by blooming oriental cherry, pink petal slowly falling down to the ground
|
68 |
+
A Ukiyoe style painting, an astronaut riding a unicorn, In the background there is an ancient Japanese architecture
|
69 |
+
Steampunk makeup, in the style of vray tracing, colorful impasto, uhd image, indonesian art, fine feather details with bright red and yellow and green and pink and orange colours, intricate patterns and details, dark cyan and amber makeup. Rich colourful plumes. Victorian style.
|
70 |
+
A cute teddy bear in front of a plain white wall, warm and brown fur, soft and fluffy
|
71 |
+
The beautiful scenery of Seattle, painting by Al Capp.
|
72 |
+
Photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang.
|
73 |
+
An astronaut riding a horse on the moon, oil painting by Van Gogh.
|
74 |
+
A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky
|
75 |
+
Realistic oil painting of a stunning model merged in multicolor splash made of finely torn paper, eye contact, walking with class in a street.
|
76 |
+
a chinese model is sitting on a train, magazine cover, clothes made of plastic, photorealistic,futuristic style, gray and green light, movie lighting, 32K HD
|
77 |
+
a handsome 24 years old boy in the middle with sky color background wearing eye glasses, it's super detailed with anime style, it's a portrait with delicated eyes and nice looking face
|
78 |
+
a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, national geographic photo, 8k resolution, crayon art, interactive artwork
|
79 |
+
3D rendering miniature scene design, Many tall buildings, A winding urban road runs through the middle,a lot of cars on the road, transparent material pipeline transports Materials, ,there are many people around, in thestyle of light orange and yellow, graphic design- inspired illustrations, classic still-life, beeple, josan gon-zalez, manga-influenced, miniature dioramas, in thestyle of playful and whimsical designs, graphic de-sign-inspired illustrations, minimalism, hyperrealismlomo lca, e-commerce C4D style, e-commerce posterUl, UX, octane render, blender
|
80 |
+
Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works
|
81 |
+
A cute orange kitten sliding down an aqua slide. happy excited. 16mm lens in front. we see his excitement and scared in the eye. vibrant colors. water splashing on the lens
|
82 |
+
Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.
|
83 |
+
A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures.
|
84 |
+
An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film.
|
85 |
+
A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.
|
86 |
+
A New Zealand female business owner stands and is happy that his business is growing by having good VoIP and broadband supplied by Voyager Internet. This business owner is dressed semi casual and is standing with a funky office space in the background. The image is light and bright and is well lit. This image needs to be shot like a professional photo shoot using a Canon R6 with high quality 25mm lens. This image has a shallow depth of field
|
87 |
+
The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8
|
88 |
+
Editorial photoshoot of a old woman, high fashion 2000s fashion
|
89 |
+
Mural Painted of Prince in Purple Rain on side of 5 story brick building next to zen garden vacant lot in the urban center district, rgb
|
90 |
+
Cozy Scandinavian living room, there is a cat sleeping on the couch, depth of field
|
91 |
+
Street style centered straight shot photo shot on Afga Vista 400, lense 50mm, of a two women,skin to skin touch face, emotion, hughing, natural blond hair, natural features, ultra detailed, skin texture, Rembrandt light, soft shadows
|
92 |
+
Frog, in forest, colorful, no watermark, no signature, in forest, 8k
|
93 |
+
selfie of a woman and her lion cub on the plains
|
94 |
+
A fisherman fixing his net sitting on a beautiful tropical beach at sunset with bending palm trees fishing gear and a small boat on shore
|
95 |
+
Coast, decorative painting, horizon, modern, fashionable, full of abstract feeling, full of imagination, the picture reveals the sense of time passing, there is a feeling of the end of the world
|
96 |
+
A close up of a branch of a tree and a golden bug on the top a leaf, shutterstock contest winner,ecological art, depth of field, shallow depth of field, macro photography
|
97 |
+
Outdoor style fashion photo, full – body shot of a man with short brown hair, happy and smiling, he is standing on his hipster bicycle wearing a light blue long sleeved blouse with closed buttons and dark blue jeans trousers, in the background the exterior of an Aldi store, fully lit background, natural afternoon lighting
|
98 |
+
beautiful woman sniper, wearing soviet army uniform, one eye on sniper lens, in snow ground
|
99 |
+
A very attractive and natural woman, sitting on a yoka mat, breathing, eye closed, no make up, intense satisfaction, she looks like she is intensely relaxed, yoga class, sunrise, 35mm
|
100 |
+
a close up of a helmet on a person, digital art, inspired by Han Gan, cloisonnism, female, victorian armor, ultramarine, best of behance, anton fadeev 8 k, fined detail, sci-fi character, elegant armor, fantasy art behance
|
101 |
+
a melting apple
|
102 |
+
yellow FIAT 500 Cinquecento 1957 driving through liechtenstein castle with a lot of banknotes scattered behind ,filled with wads of cash , car color yellow, license plate R-33
|
103 |
+
tented resort in the desert, rocky and sandy terrain, 5 star hotel, beautiful landscape, landscape photography, depth of view, Fujifilm GFX 100 –uplight
|
104 |
+
Full body shot, a French woman, Photography, French Streets background, backlighting, rim light, Fujifilm.
|
105 |
+
Modern luxury contemporary luxury home interiors house, in the style of mimicking ruined materials, ray tracing, haunting houses, and stone, capture the essence of nature, gray and bronze, dynamic outdoor shots.
|
106 |
+
Over the shoulder game perspective, game screen of Diablo 4, Inside the gorgeous palace is the wet ground, The necromancer knelt before the king, and a horde of skeletons he summoned stood at his side, cinematic light.
|
107 |
+
Color photo of a corgi made of transparent glass, standing on the riverside in Yosemite National Park.
|
108 |
+
Happy dreamy owl monster sitting on a tree branch, colorful glittering particles, forest background, detailed feathers.
|
109 |
+
Game-Art - An island with different geographical properties and multiple small cities floating in space
|
110 |
+
Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee.
|
111 |
+
A car made out of vegetables.
|
112 |
+
A serene lakeside during autumn with trees displaying a palette of fiery colors.
|
113 |
+
A realistic landscape shot of the Northern Lights dancing over a snowy mountain range in Iceland.
|
114 |
+
A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky.
|
115 |
+
Drone view of waves crashing against the rugged cliffs along Big Sur’s Garay Point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore.
|
116 |
+
A curvy timber house near a sea, designed by Zaha Hadid, represent the image of a cold, modern architecture, at night, white lighting, highly detailed.
|
117 |
+
Eiffel Tower was Made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building, Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds in the building's classic architecture.
|
118 |
+
Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works.
|
119 |
+
A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.
|
120 |
+
A close-up photo of a person. The subject is a woman. She wore a blue coat with a gray dress underneath. She has blue eyes and blond hair, and wears a pair of earrings. Behind are blurred city buildings and streets.
|
configs/PixArt_xl2_internal.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data_root = '/data/data'
|
2 |
+
data = dict(type='InternalData', root='images', image_list_json=['data_info.json'], transform='default_train', load_vae_feat=True, load_t5_feat=True)
|
3 |
+
image_size = 256 # the generated image resolution
|
4 |
+
train_batch_size = 32
|
5 |
+
eval_batch_size = 16
|
6 |
+
use_fsdp=False # if use FSDP mode
|
7 |
+
valid_num=0 # take as valid aspect-ratio when sample number >= valid_num
|
8 |
+
fp32_attention = True
|
9 |
+
# model setting
|
10 |
+
model = 'PixArt_XL_2'
|
11 |
+
aspect_ratio_type = None # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
|
12 |
+
multi_scale = False # if use multiscale dataset model training
|
13 |
+
pe_interpolation = 1.0 # positional embedding interpolation
|
14 |
+
# qk norm
|
15 |
+
qk_norm = False
|
16 |
+
# kv token compression
|
17 |
+
kv_compress = False
|
18 |
+
kv_compress_config = {
|
19 |
+
'sampling': None,
|
20 |
+
'scale_factor': 1,
|
21 |
+
'kv_compress_layer': [],
|
22 |
+
}
|
23 |
+
|
24 |
+
# training setting
|
25 |
+
num_workers=4
|
26 |
+
train_sampling_steps = 1000
|
27 |
+
visualize=False
|
28 |
+
eval_sampling_steps = 250
|
29 |
+
model_max_length = 120
|
30 |
+
lora_rank = 4
|
31 |
+
num_epochs = 80
|
32 |
+
gradient_accumulation_steps = 1
|
33 |
+
grad_checkpointing = False
|
34 |
+
gradient_clip = 1.0
|
35 |
+
gc_step = 1
|
36 |
+
auto_lr = dict(rule='sqrt')
|
37 |
+
|
38 |
+
# we use different weight decay with the official implementation since it results better result
|
39 |
+
optimizer = dict(type='AdamW', lr=1e-4, weight_decay=3e-2, eps=1e-10)
|
40 |
+
lr_schedule = 'constant'
|
41 |
+
lr_schedule_args = dict(num_warmup_steps=500)
|
42 |
+
|
43 |
+
save_image_epochs = 1
|
44 |
+
save_model_epochs = 1
|
45 |
+
save_model_steps=1000000
|
46 |
+
|
47 |
+
sample_posterior = True
|
48 |
+
mixed_precision = 'fp16'
|
49 |
+
scale_factor = 0.18215 # ldm vae: 0.18215; sdxl vae: 0.13025
|
50 |
+
ema_rate = 0.9999
|
51 |
+
tensorboard_mox_interval = 50
|
52 |
+
log_interval = 50
|
53 |
+
cfg_scale = 4
|
54 |
+
mask_type='null'
|
55 |
+
num_group_tokens=0
|
56 |
+
mask_loss_coef=0.
|
57 |
+
load_mask_index=False # load prepared mask_type index
|
58 |
+
# load model settings
|
59 |
+
vae_pretrained = "/cache/pretrained_models/sd-vae-ft-ema"
|
60 |
+
load_from = None
|
61 |
+
resume_from = dict(checkpoint=None, load_ema=False, resume_optimizer=True, resume_lr_scheduler=True)
|
62 |
+
snr_loss=False
|
63 |
+
real_prompt_ratio = 1.0
|
64 |
+
# classifier free guidance
|
65 |
+
class_dropout_prob = 0.1
|
66 |
+
# work dir settings
|
67 |
+
work_dir = '/cache/exps/'
|
68 |
+
s3_work_dir = None
|
69 |
+
micro_condition = False
|
70 |
+
seed = 43
|
71 |
+
skip_step=0
|
72 |
+
|
73 |
+
# LCM
|
74 |
+
loss_type = 'huber'
|
75 |
+
huber_c = 0.001
|
76 |
+
num_ddim_timesteps=50
|
77 |
+
w_max = 15.0
|
78 |
+
w_min = 3.0
|
79 |
+
ema_decay = 0.95
|
configs/pixart_alpha_config/PixArt_xl2_img1024_dreambooth.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'data/dreambooth/dataset'
|
3 |
+
|
4 |
+
data = dict(type='DreamBooth', root='dog6', prompt=['a photo of sks dog'], transform='default_train', load_vae_feat=True)
|
5 |
+
image_size = 1024
|
6 |
+
|
7 |
+
# model setting
|
8 |
+
model = 'PixArtMS_XL_2' # model for multi-scale training
|
9 |
+
fp32_attention = True
|
10 |
+
load_from = 'Path/to/PixArt-XL-2-1024-MS.pth'
|
11 |
+
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
|
12 |
+
aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
|
13 |
+
multi_scale = True # if use multiscale dataset model training
|
14 |
+
pe_interpolation = 2.0
|
15 |
+
|
16 |
+
# training setting
|
17 |
+
num_workers=1
|
18 |
+
train_batch_size = 1
|
19 |
+
num_epochs = 200
|
20 |
+
gradient_accumulation_steps = 1
|
21 |
+
grad_checkpointing = True
|
22 |
+
gradient_clip = 0.01
|
23 |
+
optimizer = dict(type='AdamW', lr=5e-6, weight_decay=3e-2, eps=1e-10)
|
24 |
+
lr_schedule_args = dict(num_warmup_steps=0)
|
25 |
+
auto_lr = None
|
26 |
+
|
27 |
+
log_interval = 1
|
28 |
+
save_model_epochs=10000
|
29 |
+
save_model_steps=100
|
30 |
+
work_dir = 'output/debug'
|
configs/pixart_alpha_config/PixArt_xl2_img1024_internal.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'data'
|
3 |
+
image_list_json = ['data_info.json',]
|
4 |
+
|
5 |
+
data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
|
6 |
+
image_size = 1024
|
7 |
+
|
8 |
+
# model setting
|
9 |
+
model = 'PixArt_XL_2'
|
10 |
+
fp32_attention = True
|
11 |
+
load_from = None
|
12 |
+
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
|
13 |
+
pe_interpolation = 2.0
|
14 |
+
|
15 |
+
# training setting
|
16 |
+
num_workers=10
|
17 |
+
train_batch_size = 2 # 32
|
18 |
+
num_epochs = 200 # 3
|
19 |
+
gradient_accumulation_steps = 1
|
20 |
+
grad_checkpointing = True
|
21 |
+
gradient_clip = 0.01
|
22 |
+
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
|
23 |
+
lr_schedule_args = dict(num_warmup_steps=1000)
|
24 |
+
|
25 |
+
eval_sampling_steps = 200
|
26 |
+
log_interval = 20
|
27 |
+
save_model_epochs=1
|
28 |
+
save_model_steps=2000
|
29 |
+
work_dir = 'output/debug'
|
configs/pixart_alpha_config/PixArt_xl2_img1024_internalms.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'data'
|
3 |
+
image_list_json = ['data_info.json',]
|
4 |
+
|
5 |
+
data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
|
6 |
+
image_size = 1024
|
7 |
+
|
8 |
+
# model setting
|
9 |
+
model = 'PixArtMS_XL_2' # model for multi-scale training
|
10 |
+
fp32_attention = True
|
11 |
+
load_from = None
|
12 |
+
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
|
13 |
+
aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
|
14 |
+
multi_scale = True # if use multiscale dataset model training
|
15 |
+
pe_interpolation = 2.0
|
16 |
+
|
17 |
+
# training setting
|
18 |
+
num_workers=10
|
19 |
+
train_batch_size = 12 # max 14 for PixArt-xL/2 when grad_checkpoint
|
20 |
+
num_epochs = 10 # 3
|
21 |
+
gradient_accumulation_steps = 1
|
22 |
+
grad_checkpointing = True
|
23 |
+
gradient_clip = 0.01
|
24 |
+
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
|
25 |
+
lr_schedule_args = dict(num_warmup_steps=1000)
|
26 |
+
save_model_epochs=1
|
27 |
+
save_model_steps=2000
|
28 |
+
|
29 |
+
log_interval = 20
|
30 |
+
eval_sampling_steps = 200
|
31 |
+
work_dir = 'output/debug'
|
32 |
+
micro_condition = True
|
configs/pixart_alpha_config/PixArt_xl2_img256_internal.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'data'
|
3 |
+
image_list_json = ['data_info.json',]
|
4 |
+
|
5 |
+
data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
|
6 |
+
image_size = 256
|
7 |
+
|
8 |
+
# model setting
|
9 |
+
model = 'PixArt_XL_2'
|
10 |
+
fp32_attention = True
|
11 |
+
load_from = None
|
12 |
+
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
|
13 |
+
# training setting
|
14 |
+
eval_sampling_steps = 200
|
15 |
+
|
16 |
+
num_workers=10
|
17 |
+
train_batch_size = 176 # 32 # max 96 for PixArt-L/4 when grad_checkpoint
|
18 |
+
num_epochs = 200 # 3
|
19 |
+
gradient_accumulation_steps = 1
|
20 |
+
grad_checkpointing = True
|
21 |
+
gradient_clip = 0.01
|
22 |
+
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
|
23 |
+
lr_schedule_args = dict(num_warmup_steps=1000)
|
24 |
+
|
25 |
+
log_interval = 20
|
26 |
+
save_model_epochs=5
|
27 |
+
work_dir = 'output/debug'
|
configs/pixart_alpha_config/PixArt_xl2_img512_internal.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'data'
|
3 |
+
image_list_json = ['data_info.json',]
|
4 |
+
|
5 |
+
data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
|
6 |
+
image_size = 512
|
7 |
+
|
8 |
+
# model setting
|
9 |
+
model = 'PixArt_XL_2'
|
10 |
+
fp32_attention = True
|
11 |
+
load_from = None
|
12 |
+
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
|
13 |
+
pe_interpolation = 1.0
|
14 |
+
|
15 |
+
# training setting
|
16 |
+
use_fsdp=False # if use FSDP mode
|
17 |
+
num_workers=10
|
18 |
+
train_batch_size = 38 # 32
|
19 |
+
num_epochs = 200 # 3
|
20 |
+
gradient_accumulation_steps = 1
|
21 |
+
grad_checkpointing = True
|
22 |
+
gradient_clip = 0.01
|
23 |
+
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
|
24 |
+
lr_schedule_args = dict(num_warmup_steps=1000)
|
25 |
+
|
26 |
+
eval_sampling_steps = 200
|
27 |
+
log_interval = 20
|
28 |
+
save_model_epochs=1
|
29 |
+
work_dir = 'output/debug'
|
configs/pixart_alpha_config/PixArt_xl2_img512_internalms.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'data'
|
3 |
+
image_list_json = ['data_info.json',]
|
4 |
+
|
5 |
+
data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
|
6 |
+
image_size = 512
|
7 |
+
|
8 |
+
# model setting
|
9 |
+
model = 'PixArtMS_XL_2' # model for multi-scale training
|
10 |
+
fp32_attention = True
|
11 |
+
load_from = None
|
12 |
+
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
|
13 |
+
aspect_ratio_type = 'ASPECT_RATIO_512' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
|
14 |
+
multi_scale = True # if use multiscale dataset model training
|
15 |
+
pe_interpolation = 1.0
|
16 |
+
|
17 |
+
# training setting
|
18 |
+
num_workers=10
|
19 |
+
train_batch_size = 40 # max 40 for PixArt-xL/2 when grad_checkpoint
|
20 |
+
num_epochs = 20 # 3
|
21 |
+
gradient_accumulation_steps = 1
|
22 |
+
grad_checkpointing = True
|
23 |
+
gradient_clip = 0.01
|
24 |
+
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
|
25 |
+
lr_schedule_args = dict(num_warmup_steps=1000)
|
26 |
+
save_model_epochs=1
|
27 |
+
save_model_steps=2000
|
28 |
+
|
29 |
+
log_interval = 20
|
30 |
+
eval_sampling_steps = 200
|
31 |
+
work_dir = 'output/debug'
|
configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'pixart-sigma-toy-dataset'
|
3 |
+
image_list_json = ['data_info.json']
|
4 |
+
|
5 |
+
data = dict(
|
6 |
+
type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
|
7 |
+
load_vae_feat=False, load_t5_feat=False
|
8 |
+
)
|
9 |
+
image_size = 1024
|
10 |
+
|
11 |
+
# model setting
|
12 |
+
model = 'PixArtMS_XL_2'
|
13 |
+
mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
|
14 |
+
fp32_attention = True
|
15 |
+
load_from = None
|
16 |
+
resume_from = None
|
17 |
+
vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
|
18 |
+
aspect_ratio_type = 'ASPECT_RATIO_1024'
|
19 |
+
multi_scale = True # if use multiscale dataset model training
|
20 |
+
pe_interpolation = 2.0
|
21 |
+
|
22 |
+
# training setting
|
23 |
+
num_workers = 10
|
24 |
+
train_batch_size = 2 # 3 for w.o feature extraction; 12 for feature extraction
|
25 |
+
num_epochs = 2 # 3
|
26 |
+
gradient_accumulation_steps = 1
|
27 |
+
grad_checkpointing = True
|
28 |
+
gradient_clip = 0.01
|
29 |
+
optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
|
30 |
+
lr_schedule_args = dict(num_warmup_steps=1000)
|
31 |
+
|
32 |
+
eval_sampling_steps = 500
|
33 |
+
visualize = True
|
34 |
+
log_interval = 20
|
35 |
+
save_model_epochs = 1
|
36 |
+
save_model_steps = 1000
|
37 |
+
work_dir = 'output/debug'
|
38 |
+
|
39 |
+
# pixart-sigma
|
40 |
+
scale_factor = 0.13025
|
41 |
+
real_prompt_ratio = 0.5
|
42 |
+
model_max_length = 300
|
43 |
+
class_dropout_prob = 0.1
|
44 |
+
|
45 |
+
qk_norm = False
|
46 |
+
skip_step = 0 # skip steps during data loading
|
configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms_kvcompress.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'data'
|
3 |
+
image_list_json = ['data_info.json']
|
4 |
+
|
5 |
+
data = dict(
|
6 |
+
type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
|
7 |
+
load_vae_feat=False, load_t5_feat=False
|
8 |
+
)
|
9 |
+
image_size = 1024
|
10 |
+
|
11 |
+
# model setting
|
12 |
+
model = 'PixArtMS_XL_2'
|
13 |
+
mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
|
14 |
+
fp32_attention = True
|
15 |
+
load_from = None
|
16 |
+
resume_from = None
|
17 |
+
vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
|
18 |
+
aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
|
19 |
+
multi_scale = True # if use multiscale dataset model training
|
20 |
+
pe_interpolation = 2.0
|
21 |
+
|
22 |
+
# training setting
|
23 |
+
num_workers = 10
|
24 |
+
train_batch_size = 4 # 16
|
25 |
+
num_epochs = 2 # 3
|
26 |
+
gradient_accumulation_steps = 1
|
27 |
+
grad_checkpointing = True
|
28 |
+
gradient_clip = 0.01
|
29 |
+
optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
|
30 |
+
lr_schedule_args = dict(num_warmup_steps=500)
|
31 |
+
|
32 |
+
eval_sampling_steps = 250
|
33 |
+
visualize = True
|
34 |
+
log_interval = 10
|
35 |
+
save_model_epochs = 1
|
36 |
+
save_model_steps = 1000
|
37 |
+
work_dir = 'output/debug'
|
38 |
+
|
39 |
+
# pixart-sigma
|
40 |
+
scale_factor = 0.13025
|
41 |
+
real_prompt_ratio = 0.5
|
42 |
+
model_max_length = 300
|
43 |
+
class_dropout_prob = 0.1
|
44 |
+
kv_compress = True
|
45 |
+
kv_compress_config = {
|
46 |
+
'sampling': 'conv', # ['conv', 'uniform', 'ave']
|
47 |
+
'scale_factor': 2,
|
48 |
+
'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
|
49 |
+
}
|
50 |
+
qk_norm = False
|
51 |
+
skip_step = 0 # skip steps during data loading
|
configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_lcm.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'pixart-sigma-toy-dataset'
|
3 |
+
image_list_json = ['data_info.json']
|
4 |
+
|
5 |
+
data = dict(
|
6 |
+
type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
|
7 |
+
load_vae_feat=True, load_t5_feat=True,
|
8 |
+
)
|
9 |
+
image_size = 1024
|
10 |
+
|
11 |
+
# model setting
|
12 |
+
model = 'PixArtMS_XL_2' # model for multi-scale training
|
13 |
+
fp32_attention = False
|
14 |
+
load_from = None
|
15 |
+
resume_from = None
|
16 |
+
vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
|
17 |
+
aspect_ratio_type = 'ASPECT_RATIO_1024'
|
18 |
+
multi_scale = True # if use multiscale dataset model training
|
19 |
+
pe_interpolation = 2.0
|
20 |
+
|
21 |
+
# training setting
|
22 |
+
num_workers = 4
|
23 |
+
train_batch_size = 12 # max 12 for PixArt-xL/2 when grad_checkpoint
|
24 |
+
num_epochs = 10 # 3
|
25 |
+
gradient_accumulation_steps = 1
|
26 |
+
grad_checkpointing = True
|
27 |
+
gradient_clip = 0.01
|
28 |
+
optimizer = dict(type='CAMEWrapper', lr=1e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
|
29 |
+
lr_schedule_args = dict(num_warmup_steps=100)
|
30 |
+
save_model_epochs = 10
|
31 |
+
save_model_steps = 1000
|
32 |
+
valid_num = 0 # take as valid aspect-ratio when sample number >= valid_num
|
33 |
+
|
34 |
+
log_interval = 10
|
35 |
+
eval_sampling_steps = 5
|
36 |
+
visualize = True
|
37 |
+
work_dir = 'output/debug'
|
38 |
+
|
39 |
+
# pixart-sigma
|
40 |
+
scale_factor = 0.13025
|
41 |
+
real_prompt_ratio = 0.5
|
42 |
+
model_max_length = 300
|
43 |
+
class_dropout_prob = 0.1
|
44 |
+
|
45 |
+
# LCM
|
46 |
+
loss_type = 'huber'
|
47 |
+
huber_c = 0.001
|
48 |
+
num_ddim_timesteps = 50
|
49 |
+
w_max = 15.0
|
50 |
+
w_min = 3.0
|
51 |
+
ema_decay = 0.95
|
52 |
+
cfg_scale = 4.5
|
configs/pixart_sigma_config/PixArt_sigma_xl2_img256_internal.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'pixart-sigma-toy-dataset'
|
3 |
+
image_list_json = ['data_info.json']
|
4 |
+
|
5 |
+
data = dict(
|
6 |
+
type='InternalDataSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
|
7 |
+
load_vae_feat=False, load_t5_feat=False,
|
8 |
+
)
|
9 |
+
image_size = 256
|
10 |
+
|
11 |
+
# model setting
|
12 |
+
model = 'PixArt_XL_2'
|
13 |
+
mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
|
14 |
+
fp32_attention = True
|
15 |
+
load_from = "output/pretrained_models/PixArt-Sigma-XL-2-256x256.pth" # https://huggingface.co/PixArt-alpha/PixArt-Sigma
|
16 |
+
resume_from = None
|
17 |
+
vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
|
18 |
+
multi_scale = False # if use multiscale dataset model training
|
19 |
+
pe_interpolation = 0.5
|
20 |
+
|
21 |
+
# training setting
|
22 |
+
num_workers = 10
|
23 |
+
train_batch_size = 64 # 64 as default
|
24 |
+
num_epochs = 200 # 3
|
25 |
+
gradient_accumulation_steps = 1
|
26 |
+
grad_checkpointing = True
|
27 |
+
gradient_clip = 0.01
|
28 |
+
optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
|
29 |
+
lr_schedule_args = dict(num_warmup_steps=1000)
|
30 |
+
|
31 |
+
eval_sampling_steps = 500
|
32 |
+
log_interval = 20
|
33 |
+
save_model_epochs = 5
|
34 |
+
save_model_steps = 2500
|
35 |
+
work_dir = 'output/debug'
|
36 |
+
|
37 |
+
# pixart-sigma
|
38 |
+
scale_factor = 0.13025
|
39 |
+
real_prompt_ratio = 0.5
|
40 |
+
model_max_length = 300
|
41 |
+
class_dropout_prob = 0.1
|
configs/pixart_sigma_config/PixArt_sigma_xl2_img2K_internalms_kvcompress.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'data'
|
3 |
+
image_list_json = ['data_info.json']
|
4 |
+
|
5 |
+
data = dict(
|
6 |
+
type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
|
7 |
+
load_vae_feat=False, load_t5_feat=False
|
8 |
+
)
|
9 |
+
image_size = 2048
|
10 |
+
|
11 |
+
# model setting
|
12 |
+
model = 'PixArtMS_XL_2'
|
13 |
+
mixed_precision = 'fp16'
|
14 |
+
fp32_attention = True
|
15 |
+
load_from = None
|
16 |
+
resume_from = None
|
17 |
+
vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
|
18 |
+
aspect_ratio_type = 'ASPECT_RATIO_2048' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
|
19 |
+
multi_scale = True # if use multiscale dataset model training
|
20 |
+
pe_interpolation = 4.0
|
21 |
+
|
22 |
+
# training setting
|
23 |
+
num_workers = 10
|
24 |
+
train_batch_size = 4 # 48
|
25 |
+
num_epochs = 10 # 3
|
26 |
+
gradient_accumulation_steps = 1
|
27 |
+
grad_checkpointing = True
|
28 |
+
gradient_clip = 0.01
|
29 |
+
optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
|
30 |
+
lr_schedule_args = dict(num_warmup_steps=100)
|
31 |
+
|
32 |
+
eval_sampling_steps = 100
|
33 |
+
visualize = True
|
34 |
+
log_interval = 10
|
35 |
+
save_model_epochs = 10
|
36 |
+
save_model_steps = 100
|
37 |
+
work_dir = 'output/debug'
|
38 |
+
|
39 |
+
# pixart-sigma
|
40 |
+
scale_factor = 0.13025
|
41 |
+
real_prompt_ratio = 0.5
|
42 |
+
model_max_length = 300
|
43 |
+
class_dropout_prob = 0.1
|
44 |
+
kv_compress = False
|
45 |
+
kv_compress_config = {
|
46 |
+
'sampling': 'conv',
|
47 |
+
'scale_factor': 2,
|
48 |
+
'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
|
49 |
+
}
|
configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['../PixArt_xl2_internal.py']
|
2 |
+
data_root = 'pixart-sigma-toy-dataset'
|
3 |
+
image_list_json = ['data_info.json']
|
4 |
+
|
5 |
+
data = dict(
|
6 |
+
type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
|
7 |
+
load_vae_feat=False, load_t5_feat=False,
|
8 |
+
)
|
9 |
+
image_size = 512
|
10 |
+
|
11 |
+
# model setting
|
12 |
+
model = 'PixArtMS_XL_2'
|
13 |
+
mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
|
14 |
+
fp32_attention = True
|
15 |
+
load_from = "output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth" # https://huggingface.co/PixArt-alpha/PixArt-Sigma
|
16 |
+
resume_from = None
|
17 |
+
vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
|
18 |
+
aspect_ratio_type = 'ASPECT_RATIO_512'
|
19 |
+
multi_scale = True # if use multiscale dataset model training
|
20 |
+
pe_interpolation = 1.0
|
21 |
+
|
22 |
+
# training setting
|
23 |
+
num_workers = 10
|
24 |
+
train_batch_size = 2 # 48 as default
|
25 |
+
num_epochs = 10 # 3
|
26 |
+
gradient_accumulation_steps = 1
|
27 |
+
grad_checkpointing = True
|
28 |
+
gradient_clip = 0.01
|
29 |
+
optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
|
30 |
+
lr_schedule_args = dict(num_warmup_steps=1000)
|
31 |
+
|
32 |
+
eval_sampling_steps = 500
|
33 |
+
visualize = True
|
34 |
+
log_interval = 20
|
35 |
+
save_model_epochs = 5
|
36 |
+
save_model_steps = 2500
|
37 |
+
work_dir = 'output/debug'
|
38 |
+
|
39 |
+
# pixart-sigma
|
40 |
+
scale_factor = 0.13025
|
41 |
+
real_prompt_ratio = 0.5
|
42 |
+
model_max_length = 300
|
43 |
+
class_dropout_prob = 0.1
|
diffusion/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
from .iddpm import IDDPM
|
7 |
+
from .dpm_solver import DPMS
|
8 |
+
from .sa_sampler import SASolverSampler
|
diffusion/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .datasets import *
|
2 |
+
from .transforms import get_transform
|
diffusion/data/builder.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
|
4 |
+
from mmcv import Registry, build_from_cfg
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
|
7 |
+
from diffusion.data.transforms import get_transform
|
8 |
+
from diffusion.utils.logger import get_root_logger
|
9 |
+
|
10 |
+
DATASETS = Registry('datasets')
|
11 |
+
|
12 |
+
DATA_ROOT = '/cache/data'
|
13 |
+
|
14 |
+
|
15 |
+
def set_data_root(data_root):
|
16 |
+
global DATA_ROOT
|
17 |
+
DATA_ROOT = data_root
|
18 |
+
|
19 |
+
|
20 |
+
def get_data_path(data_dir):
|
21 |
+
if os.path.isabs(data_dir):
|
22 |
+
return data_dir
|
23 |
+
global DATA_ROOT
|
24 |
+
return os.path.join(DATA_ROOT, data_dir)
|
25 |
+
|
26 |
+
|
27 |
+
def build_dataset(cfg, resolution=224, **kwargs):
|
28 |
+
logger = get_root_logger()
|
29 |
+
|
30 |
+
dataset_type = cfg.get('type')
|
31 |
+
logger.info(f"Constructing dataset {dataset_type}...")
|
32 |
+
t = time.time()
|
33 |
+
transform = cfg.pop('transform', 'default_train')
|
34 |
+
transform = get_transform(transform, resolution)
|
35 |
+
dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs))
|
36 |
+
logger.info(f"Dataset {dataset_type} constructed. time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}")
|
37 |
+
return dataset
|
38 |
+
|
39 |
+
|
40 |
+
def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs):
|
41 |
+
if 'batch_sampler' in kwargs:
|
42 |
+
dataloader = DataLoader(dataset, batch_sampler=kwargs['batch_sampler'], num_workers=num_workers, pin_memory=True)
|
43 |
+
else:
|
44 |
+
dataloader = DataLoader(dataset,
|
45 |
+
batch_size=batch_size,
|
46 |
+
shuffle=shuffle,
|
47 |
+
num_workers=num_workers,
|
48 |
+
pin_memory=True,
|
49 |
+
**kwargs)
|
50 |
+
return dataloader
|
diffusion/data/datasets/InternalData.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from diffusers.utils.torch_utils import randn_tensor
|
9 |
+
from torchvision import transforms as T
|
10 |
+
from diffusion.data.builder import get_data_path, DATASETS
|
11 |
+
from diffusion.utils.logger import get_root_logger
|
12 |
+
|
13 |
+
import json
|
14 |
+
|
15 |
+
@DATASETS.register_module()
|
16 |
+
class InternalData(Dataset):
|
17 |
+
def __init__(self,
|
18 |
+
root,
|
19 |
+
image_list_json='data_info.json',
|
20 |
+
transform=None,
|
21 |
+
resolution=256,
|
22 |
+
sample_subset=None,
|
23 |
+
load_vae_feat=False,
|
24 |
+
input_size=32,
|
25 |
+
patch_size=2,
|
26 |
+
mask_ratio=0.0,
|
27 |
+
load_mask_index=False,
|
28 |
+
max_length=120,
|
29 |
+
config=None,
|
30 |
+
**kwargs):
|
31 |
+
self.root = get_data_path(root)
|
32 |
+
self.transform = transform
|
33 |
+
self.load_vae_feat = load_vae_feat
|
34 |
+
self.ori_imgs_nums = 0
|
35 |
+
self.resolution = resolution
|
36 |
+
self.N = int(resolution // (input_size // patch_size))
|
37 |
+
self.mask_ratio = mask_ratio
|
38 |
+
self.load_mask_index = load_mask_index
|
39 |
+
self.max_lenth = max_length
|
40 |
+
self.meta_data_clean = []
|
41 |
+
self.img_samples = []
|
42 |
+
self.txt_feat_samples = []
|
43 |
+
self.vae_feat_samples = []
|
44 |
+
self.mask_index_samples = []
|
45 |
+
self.prompt_samples = []
|
46 |
+
|
47 |
+
image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
|
48 |
+
for json_file in image_list_json:
|
49 |
+
meta_data = self.load_json(os.path.join(self.root, 'partition', json_file))
|
50 |
+
self.ori_imgs_nums += len(meta_data)
|
51 |
+
meta_data_clean = [item for item in meta_data if item['ratio'] <= 4]
|
52 |
+
self.meta_data_clean.extend(meta_data_clean)
|
53 |
+
self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean])
|
54 |
+
self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean])
|
55 |
+
self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_features_{resolution}resolution/noflip', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean])
|
56 |
+
self.prompt_samples.extend([item['prompt'] for item in meta_data_clean])
|
57 |
+
|
58 |
+
# Set loader and extensions
|
59 |
+
if load_vae_feat:
|
60 |
+
self.transform = None
|
61 |
+
self.loader = self.vae_feat_loader
|
62 |
+
else:
|
63 |
+
self.loader = default_loader
|
64 |
+
|
65 |
+
if sample_subset is not None:
|
66 |
+
self.sample_subset(sample_subset) # sample dataset for local debug
|
67 |
+
logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
|
68 |
+
logger.info(f"T5 max token length: {self.max_lenth}")
|
69 |
+
|
70 |
+
def getdata(self, index):
|
71 |
+
img_path = self.img_samples[index]
|
72 |
+
npz_path = self.txt_feat_samples[index]
|
73 |
+
npy_path = self.vae_feat_samples[index]
|
74 |
+
prompt = self.prompt_samples[index]
|
75 |
+
data_info = {
|
76 |
+
'img_hw': torch.tensor([torch.tensor(self.resolution), torch.tensor(self.resolution)], dtype=torch.float32),
|
77 |
+
'aspect_ratio': torch.tensor(1.)
|
78 |
+
}
|
79 |
+
|
80 |
+
img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path)
|
81 |
+
txt_info = np.load(npz_path)
|
82 |
+
txt_fea = torch.from_numpy(txt_info['caption_feature']) # 1xTx4096
|
83 |
+
attention_mask = torch.ones(1, 1, txt_fea.shape[1]) # 1x1xT
|
84 |
+
if 'attention_mask' in txt_info.keys():
|
85 |
+
attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
|
86 |
+
if txt_fea.shape[1] != self.max_lenth:
|
87 |
+
txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1)
|
88 |
+
attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1)
|
89 |
+
|
90 |
+
if self.transform:
|
91 |
+
img = self.transform(img)
|
92 |
+
|
93 |
+
data_info['prompt'] = prompt
|
94 |
+
return img, txt_fea, attention_mask, data_info
|
95 |
+
|
96 |
+
def __getitem__(self, idx):
|
97 |
+
for _ in range(20):
|
98 |
+
try:
|
99 |
+
return self.getdata(idx)
|
100 |
+
except Exception as e:
|
101 |
+
print(f"Error details: {str(e)}")
|
102 |
+
idx = np.random.randint(len(self))
|
103 |
+
raise RuntimeError('Too many bad data.')
|
104 |
+
|
105 |
+
def get_data_info(self, idx):
|
106 |
+
data_info = self.meta_data_clean[idx]
|
107 |
+
return {'height': data_info['height'], 'width': data_info['width']}
|
108 |
+
|
109 |
+
@staticmethod
|
110 |
+
def vae_feat_loader(path):
|
111 |
+
# [mean, std]
|
112 |
+
mean, std = torch.from_numpy(np.load(path)).chunk(2)
|
113 |
+
sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype)
|
114 |
+
return mean + std * sample
|
115 |
+
|
116 |
+
def load_ori_img(self, img_path):
|
117 |
+
# 加载图像并转换为Tensor
|
118 |
+
transform = T.Compose([
|
119 |
+
T.Resize(256), # Image.BICUBIC
|
120 |
+
T.CenterCrop(256),
|
121 |
+
T.ToTensor(),
|
122 |
+
])
|
123 |
+
return transform(Image.open(img_path))
|
124 |
+
|
125 |
+
def load_json(self, file_path):
|
126 |
+
with open(file_path, 'r') as f:
|
127 |
+
meta_data = json.load(f)
|
128 |
+
|
129 |
+
return meta_data
|
130 |
+
|
131 |
+
def sample_subset(self, ratio):
|
132 |
+
sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio))
|
133 |
+
self.img_samples = [self.img_samples[i] for i in sampled_idx]
|
134 |
+
|
135 |
+
def __len__(self):
|
136 |
+
return len(self.img_samples)
|
137 |
+
|
138 |
+
def __getattr__(self, name):
|
139 |
+
if name == "set_epoch":
|
140 |
+
return lambda epoch: None
|
141 |
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
142 |
+
|
143 |
+
@DATASETS.register_module()
|
144 |
+
class InternalDataSigma(Dataset):
|
145 |
+
def __init__(self,
|
146 |
+
root,
|
147 |
+
image_list_json='data_info.json',
|
148 |
+
transform=None,
|
149 |
+
resolution=256,
|
150 |
+
sample_subset=None,
|
151 |
+
load_vae_feat=False,
|
152 |
+
load_t5_feat=False,
|
153 |
+
input_size=32,
|
154 |
+
patch_size=2,
|
155 |
+
mask_ratio=0.0,
|
156 |
+
mask_type='null',
|
157 |
+
load_mask_index=False,
|
158 |
+
real_prompt_ratio=1.0,
|
159 |
+
max_length=300,
|
160 |
+
config=None,
|
161 |
+
**kwargs):
|
162 |
+
self.root = get_data_path(root)
|
163 |
+
self.transform = transform
|
164 |
+
self.load_vae_feat = load_vae_feat
|
165 |
+
self.load_t5_feat = load_t5_feat
|
166 |
+
self.ori_imgs_nums = 0
|
167 |
+
self.resolution = resolution
|
168 |
+
self.N = int(resolution // (input_size // patch_size))
|
169 |
+
self.mask_ratio = mask_ratio
|
170 |
+
self.load_mask_index = load_mask_index
|
171 |
+
self.mask_type = mask_type
|
172 |
+
self.real_prompt_ratio = real_prompt_ratio
|
173 |
+
self.max_lenth = max_length
|
174 |
+
self.meta_data_clean = []
|
175 |
+
self.img_samples = []
|
176 |
+
self.txt_samples = []
|
177 |
+
self.sharegpt4v_txt_samples = []
|
178 |
+
self.txt_feat_samples = []
|
179 |
+
self.vae_feat_samples = []
|
180 |
+
self.mask_index_samples = []
|
181 |
+
self.gpt4v_txt_feat_samples = []
|
182 |
+
self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32
|
183 |
+
logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
|
184 |
+
logger.info(f"T5 max token length: {self.max_lenth}")
|
185 |
+
logger.info(f"ratio of real user prompt: {self.real_prompt_ratio}")
|
186 |
+
|
187 |
+
image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
|
188 |
+
for json_file in image_list_json:
|
189 |
+
meta_data = self.load_json(os.path.join(self.root, json_file))
|
190 |
+
logger.info(f"{json_file} data volume: {len(meta_data)}")
|
191 |
+
self.ori_imgs_nums += len(meta_data)
|
192 |
+
meta_data_clean = [item for item in meta_data if item['ratio'] <= 4.5]
|
193 |
+
self.meta_data_clean.extend(meta_data_clean)
|
194 |
+
self.img_samples.extend([
|
195 |
+
os.path.join(self.root.replace('InternData', 'InternImgs'), item['path']) for item in meta_data_clean
|
196 |
+
])
|
197 |
+
self.txt_samples.extend([item['prompt'] for item in meta_data_clean])
|
198 |
+
self.sharegpt4v_txt_samples.extend([item['sharegpt4v'] if 'sharegpt4v' in item else '' for item in meta_data_clean])
|
199 |
+
self.txt_feat_samples.extend([
|
200 |
+
os.path.join(
|
201 |
+
self.root,
|
202 |
+
'caption_features_new',
|
203 |
+
item['path'].rsplit('/', 1)[-1].replace('.png', '.npz')
|
204 |
+
) for item in meta_data_clean
|
205 |
+
])
|
206 |
+
self.gpt4v_txt_feat_samples.extend([
|
207 |
+
os.path.join(
|
208 |
+
self.root,
|
209 |
+
'sharegpt4v_caption_features_new',
|
210 |
+
item['path'].rsplit('/', 1)[-1].replace('.png', '.npz')
|
211 |
+
) for item in meta_data_clean
|
212 |
+
])
|
213 |
+
self.vae_feat_samples.extend(
|
214 |
+
[
|
215 |
+
os.path.join(
|
216 |
+
self.root,
|
217 |
+
f'img_sdxl_vae_features_{resolution}resolution_new',
|
218 |
+
item['path'].rsplit('/', 1)[-1].replace('.png', '.npy')
|
219 |
+
) for item in meta_data_clean
|
220 |
+
])
|
221 |
+
|
222 |
+
# Set loader and extensions
|
223 |
+
if load_vae_feat:
|
224 |
+
self.transform = None
|
225 |
+
self.loader = self.vae_feat_loader
|
226 |
+
else:
|
227 |
+
self.loader = default_loader
|
228 |
+
|
229 |
+
if sample_subset is not None:
|
230 |
+
self.sample_subset(sample_subset) # sample dataset for local debug
|
231 |
+
|
232 |
+
def getdata(self, index):
|
233 |
+
img_path = self.img_samples[index]
|
234 |
+
real_prompt = random.random() < self.real_prompt_ratio
|
235 |
+
npz_path = self.txt_feat_samples[index] if real_prompt else self.gpt4v_txt_feat_samples[index]
|
236 |
+
txt = self.txt_samples[index] if real_prompt else self.sharegpt4v_txt_samples[index]
|
237 |
+
npy_path = self.vae_feat_samples[index]
|
238 |
+
data_info = {'img_hw': torch.tensor([torch.tensor(self.resolution), torch.tensor(self.resolution)], dtype=torch.float32),
|
239 |
+
'aspect_ratio': torch.tensor(1.)}
|
240 |
+
|
241 |
+
if self.load_vae_feat:
|
242 |
+
img = self.loader(npy_path)
|
243 |
+
else:
|
244 |
+
img = self.loader(img_path)
|
245 |
+
|
246 |
+
attention_mask = torch.ones(1, 1, self.max_lenth) # 1x1xT
|
247 |
+
if self.load_t5_feat:
|
248 |
+
txt_info = np.load(npz_path)
|
249 |
+
txt_fea = torch.from_numpy(txt_info['caption_feature']) # 1xTx4096
|
250 |
+
if 'attention_mask' in txt_info.keys():
|
251 |
+
attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
|
252 |
+
if txt_fea.shape[1] != self.max_lenth:
|
253 |
+
txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1)
|
254 |
+
attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1)
|
255 |
+
else:
|
256 |
+
txt_fea = txt
|
257 |
+
|
258 |
+
if self.transform:
|
259 |
+
img = self.transform(img)
|
260 |
+
|
261 |
+
data_info["mask_type"] = self.mask_type
|
262 |
+
return img, txt_fea, attention_mask.to(torch.int16), data_info
|
263 |
+
|
264 |
+
def __getitem__(self, idx):
|
265 |
+
for _ in range(20):
|
266 |
+
try:
|
267 |
+
data = self.getdata(idx)
|
268 |
+
return data
|
269 |
+
except Exception as e:
|
270 |
+
print(f"Error details {self.img_samples[idx]}: {str(e)}")
|
271 |
+
idx = np.random.randint(len(self))
|
272 |
+
raise RuntimeError('Too many bad data.')
|
273 |
+
|
274 |
+
def get_data_info(self, idx):
|
275 |
+
data_info = self.meta_data_clean[idx]
|
276 |
+
return {'height': data_info['height'], 'width': data_info['width']}
|
277 |
+
|
278 |
+
@staticmethod
|
279 |
+
def vae_feat_loader(path):
|
280 |
+
# [mean, std]
|
281 |
+
mean, std = torch.from_numpy(np.load(path)).chunk(2)
|
282 |
+
sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype)
|
283 |
+
return mean + std * sample
|
284 |
+
|
285 |
+
def load_ori_img(self, img_path):
|
286 |
+
# 加载图像并转换为Tensor
|
287 |
+
transform = T.Compose([
|
288 |
+
T.Resize(256), # Image.BICUBIC
|
289 |
+
T.CenterCrop(256),
|
290 |
+
T.ToTensor(),
|
291 |
+
])
|
292 |
+
img = transform(Image.open(img_path))
|
293 |
+
return img
|
294 |
+
|
295 |
+
def load_json(self, file_path):
|
296 |
+
with open(file_path, 'r') as f:
|
297 |
+
meta_data = json.load(f)
|
298 |
+
|
299 |
+
return meta_data
|
300 |
+
|
301 |
+
def sample_subset(self, ratio):
|
302 |
+
sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio))
|
303 |
+
self.img_samples = [self.img_samples[i] for i in sampled_idx]
|
304 |
+
|
305 |
+
def __len__(self):
|
306 |
+
return len(self.img_samples)
|
307 |
+
|
308 |
+
def __getattr__(self, name):
|
309 |
+
if name == "set_epoch":
|
310 |
+
return lambda epoch: None
|
311 |
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
312 |
+
|
diffusion/data/datasets/InternalData_ms.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
from torchvision.datasets.folder import default_loader
|
6 |
+
from diffusion.data.datasets.InternalData import InternalData, InternalDataSigma
|
7 |
+
from diffusion.data.builder import get_data_path, DATASETS
|
8 |
+
from diffusion.utils.logger import get_root_logger
|
9 |
+
import torchvision.transforms as T
|
10 |
+
from torchvision.transforms.functional import InterpolationMode
|
11 |
+
from diffusion.data.datasets.utils import *
|
12 |
+
|
13 |
+
def get_closest_ratio(height: float, width: float, ratios: dict):
|
14 |
+
aspect_ratio = height / width
|
15 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
16 |
+
return ratios[closest_ratio], float(closest_ratio)
|
17 |
+
|
18 |
+
|
19 |
+
@DATASETS.register_module()
|
20 |
+
class InternalDataMS(InternalData):
|
21 |
+
def __init__(self,
|
22 |
+
root,
|
23 |
+
image_list_json='data_info.json',
|
24 |
+
transform=None,
|
25 |
+
resolution=256,
|
26 |
+
sample_subset=None,
|
27 |
+
load_vae_feat=False,
|
28 |
+
input_size=32,
|
29 |
+
patch_size=2,
|
30 |
+
mask_ratio=0.0,
|
31 |
+
mask_type='null',
|
32 |
+
load_mask_index=False,
|
33 |
+
real_prompt_ratio=1.0,
|
34 |
+
max_length=120,
|
35 |
+
config=None,
|
36 |
+
**kwargs):
|
37 |
+
self.root = get_data_path(root)
|
38 |
+
self.transform = transform
|
39 |
+
self.load_vae_feat = load_vae_feat
|
40 |
+
self.ori_imgs_nums = 0
|
41 |
+
self.resolution = resolution
|
42 |
+
self.N = int(resolution // (input_size // patch_size))
|
43 |
+
self.mask_ratio = mask_ratio
|
44 |
+
self.load_mask_index = load_mask_index
|
45 |
+
self.mask_type = mask_type
|
46 |
+
self.real_prompt_ratio = real_prompt_ratio
|
47 |
+
self.max_lenth = max_length
|
48 |
+
self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1])
|
49 |
+
self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type')) # base aspect ratio
|
50 |
+
self.meta_data_clean = []
|
51 |
+
self.img_samples = []
|
52 |
+
self.txt_feat_samples = []
|
53 |
+
self.vae_feat_samples = []
|
54 |
+
self.mask_index_samples = []
|
55 |
+
self.ratio_index = {}
|
56 |
+
self.ratio_nums = {}
|
57 |
+
# self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32
|
58 |
+
for k, v in self.aspect_ratio.items():
|
59 |
+
self.ratio_index[float(k)] = [] # used for self.getitem
|
60 |
+
self.ratio_nums[float(k)] = 0 # used for batch-sampler
|
61 |
+
|
62 |
+
image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
|
63 |
+
for json_file in image_list_json:
|
64 |
+
meta_data = self.load_json(os.path.join(self.root, json_file))
|
65 |
+
self.ori_imgs_nums += len(meta_data)
|
66 |
+
meta_data_clean = [item for item in meta_data if item['ratio'] <= 4]
|
67 |
+
self.meta_data_clean.extend(meta_data_clean)
|
68 |
+
self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean])
|
69 |
+
self.txt_feat_samples.extend([os.path.join(self.root, 'caption_features', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean])
|
70 |
+
self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_fatures_{resolution}_multiscale/ms', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean])
|
71 |
+
|
72 |
+
# Set loader and extensions
|
73 |
+
if load_vae_feat:
|
74 |
+
self.transform = None
|
75 |
+
self.loader = self.vae_feat_loader
|
76 |
+
else:
|
77 |
+
self.loader = default_loader
|
78 |
+
|
79 |
+
if sample_subset is not None:
|
80 |
+
self.sample_subset(sample_subset) # sample dataset for local debug
|
81 |
+
|
82 |
+
# scan the dataset for ratio static
|
83 |
+
for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]):
|
84 |
+
ori_h, ori_w = info['height'], info['width']
|
85 |
+
closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
|
86 |
+
self.ratio_nums[closest_ratio] += 1
|
87 |
+
if len(self.ratio_index[closest_ratio]) == 0:
|
88 |
+
self.ratio_index[closest_ratio].append(i)
|
89 |
+
# print(self.ratio_nums)
|
90 |
+
logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
|
91 |
+
logger.info(f"T5 max token length: {self.max_lenth}")
|
92 |
+
|
93 |
+
def getdata(self, index):
|
94 |
+
img_path = self.img_samples[index]
|
95 |
+
npz_path = self.txt_feat_samples[index]
|
96 |
+
npy_path = self.vae_feat_samples[index]
|
97 |
+
ori_h, ori_w = self.meta_data_clean[index]['height'], self.meta_data_clean[index]['width']
|
98 |
+
|
99 |
+
# Calculate the closest aspect ratio and resize & crop image[w, h]
|
100 |
+
closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
|
101 |
+
closest_size = list(map(lambda x: int(x), closest_size))
|
102 |
+
self.closest_ratio = closest_ratio
|
103 |
+
|
104 |
+
if self.load_vae_feat:
|
105 |
+
try:
|
106 |
+
img = self.loader(npy_path)
|
107 |
+
if index not in self.ratio_index[closest_ratio]:
|
108 |
+
self.ratio_index[closest_ratio].append(index)
|
109 |
+
except Exception:
|
110 |
+
index = random.choice(self.ratio_index[closest_ratio])
|
111 |
+
return self.getdata(index)
|
112 |
+
h, w = (img.shape[1], img.shape[2])
|
113 |
+
assert h, w == (ori_h//8, ori_w//8)
|
114 |
+
else:
|
115 |
+
img = self.loader(img_path)
|
116 |
+
h, w = (img.size[1], img.size[0])
|
117 |
+
assert h, w == (ori_h, ori_w)
|
118 |
+
|
119 |
+
data_info = {'img_hw': torch.tensor([ori_h, ori_w], dtype=torch.float32)}
|
120 |
+
data_info['aspect_ratio'] = closest_ratio
|
121 |
+
data_info["mask_type"] = self.mask_type
|
122 |
+
|
123 |
+
txt_info = np.load(npz_path)
|
124 |
+
txt_fea = torch.from_numpy(txt_info['caption_feature'])
|
125 |
+
attention_mask = torch.ones(1, 1, txt_fea.shape[1])
|
126 |
+
if 'attention_mask' in txt_info.keys():
|
127 |
+
attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
|
128 |
+
|
129 |
+
if not self.load_vae_feat:
|
130 |
+
if closest_size[0] / ori_h > closest_size[1] / ori_w:
|
131 |
+
resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
|
132 |
+
else:
|
133 |
+
resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
|
134 |
+
self.transform = T.Compose([
|
135 |
+
T.Lambda(lambda img: img.convert('RGB')),
|
136 |
+
T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC
|
137 |
+
T.CenterCrop(closest_size),
|
138 |
+
T.ToTensor(),
|
139 |
+
T.Normalize([.5], [.5]),
|
140 |
+
])
|
141 |
+
|
142 |
+
if self.transform:
|
143 |
+
img = self.transform(img)
|
144 |
+
|
145 |
+
return img, txt_fea, attention_mask, data_info
|
146 |
+
|
147 |
+
def __getitem__(self, idx):
|
148 |
+
for _ in range(20):
|
149 |
+
try:
|
150 |
+
return self.getdata(idx)
|
151 |
+
except Exception as e:
|
152 |
+
print(f"Error details: {str(e)}")
|
153 |
+
idx = random.choice(self.ratio_index[self.closest_ratio])
|
154 |
+
raise RuntimeError('Too many bad data.')
|
155 |
+
|
156 |
+
|
157 |
+
@DATASETS.register_module()
|
158 |
+
class InternalDataMSSigma(InternalDataSigma):
|
159 |
+
def __init__(self,
|
160 |
+
root,
|
161 |
+
image_list_json='data_info.json',
|
162 |
+
transform=None,
|
163 |
+
resolution=256,
|
164 |
+
sample_subset=None,
|
165 |
+
load_vae_feat=False,
|
166 |
+
load_t5_feat=False,
|
167 |
+
input_size=32,
|
168 |
+
patch_size=2,
|
169 |
+
mask_ratio=0.0,
|
170 |
+
mask_type='null',
|
171 |
+
load_mask_index=False,
|
172 |
+
real_prompt_ratio=1.0,
|
173 |
+
max_length=300,
|
174 |
+
config=None,
|
175 |
+
**kwargs):
|
176 |
+
self.root = get_data_path(root)
|
177 |
+
self.transform = transform
|
178 |
+
self.load_vae_feat = load_vae_feat
|
179 |
+
self.load_t5_feat = load_t5_feat
|
180 |
+
self.ori_imgs_nums = 0
|
181 |
+
self.resolution = resolution
|
182 |
+
self.N = int(resolution // (input_size // patch_size))
|
183 |
+
self.mask_ratio = mask_ratio
|
184 |
+
self.load_mask_index = load_mask_index
|
185 |
+
self.mask_type = mask_type
|
186 |
+
self.real_prompt_ratio = real_prompt_ratio
|
187 |
+
self.max_lenth = max_length
|
188 |
+
self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1])
|
189 |
+
self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type')) # base aspect ratio
|
190 |
+
self.meta_data_clean = []
|
191 |
+
self.img_samples = []
|
192 |
+
self.txt_samples = []
|
193 |
+
self.sharegpt4v_txt_samples = []
|
194 |
+
self.txt_feat_samples = []
|
195 |
+
self.vae_feat_samples = []
|
196 |
+
self.mask_index_samples = []
|
197 |
+
self.ratio_index = {}
|
198 |
+
self.ratio_nums = {}
|
199 |
+
self.gpt4v_txt_feat_samples = []
|
200 |
+
self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32
|
201 |
+
self.interpolate_model = InterpolationMode.BICUBIC
|
202 |
+
if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
|
203 |
+
self.interpolate_model = InterpolationMode.LANCZOS
|
204 |
+
suffix = ''
|
205 |
+
for k, v in self.aspect_ratio.items():
|
206 |
+
self.ratio_index[float(k)] = [] # used for self.getitem
|
207 |
+
self.ratio_nums[float(k)] = 0 # used for batch-sampler
|
208 |
+
logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
|
209 |
+
logger.info(f"T5 max token length: {self.max_lenth}")
|
210 |
+
logger.info(f"ratio of real user prompt: {self.real_prompt_ratio}")
|
211 |
+
|
212 |
+
image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
|
213 |
+
for json_file in image_list_json:
|
214 |
+
meta_data = self.load_json(os.path.join(self.root, json_file))
|
215 |
+
logger.info(f"{json_file} data volume: {len(meta_data)}")
|
216 |
+
self.ori_imgs_nums += len(meta_data)
|
217 |
+
meta_data_clean = [item for item in meta_data if item['ratio'] <= 4.5]
|
218 |
+
self.meta_data_clean.extend(meta_data_clean)
|
219 |
+
self.img_samples.extend([
|
220 |
+
os.path.join(self.root.replace('InternData'+suffix, 'InternImgs'), item['path']) for item in meta_data_clean
|
221 |
+
])
|
222 |
+
self.txt_samples.extend([item['prompt'] for item in meta_data_clean])
|
223 |
+
self.sharegpt4v_txt_samples.extend([item['sharegpt4v'] if 'sharegpt4v' in item else '' for item in meta_data_clean])
|
224 |
+
self.txt_feat_samples.extend([
|
225 |
+
os.path.join(
|
226 |
+
self.root,
|
227 |
+
'caption_features_new',
|
228 |
+
'_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')
|
229 |
+
) for item in meta_data_clean
|
230 |
+
])
|
231 |
+
self.gpt4v_txt_feat_samples.extend([
|
232 |
+
os.path.join(
|
233 |
+
self.root,
|
234 |
+
'sharegpt4v_caption_features_new',
|
235 |
+
'_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')
|
236 |
+
) for item in meta_data_clean
|
237 |
+
])
|
238 |
+
self.vae_feat_samples.extend(
|
239 |
+
[
|
240 |
+
os.path.join(
|
241 |
+
self.root + suffix,
|
242 |
+
f'img_sdxl_vae_features_{resolution}resolution_ms_new',
|
243 |
+
'_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')
|
244 |
+
) for item in meta_data_clean
|
245 |
+
])
|
246 |
+
|
247 |
+
if self.real_prompt_ratio < 1:
|
248 |
+
assert len(self.sharegpt4v_txt_samples[0]) != 0
|
249 |
+
|
250 |
+
# Set loader and extensions
|
251 |
+
if load_vae_feat:
|
252 |
+
self.transform = None
|
253 |
+
self.loader = self.vae_feat_loader
|
254 |
+
else:
|
255 |
+
self.loader = default_loader
|
256 |
+
|
257 |
+
if sample_subset is not None:
|
258 |
+
self.sample_subset(sample_subset) # sample dataset for local debug
|
259 |
+
|
260 |
+
# scan the dataset for ratio static
|
261 |
+
for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]):
|
262 |
+
ori_h, ori_w = info['height'], info['width']
|
263 |
+
closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
|
264 |
+
self.ratio_nums[closest_ratio] += 1
|
265 |
+
if len(self.ratio_index[closest_ratio]) == 0:
|
266 |
+
self.ratio_index[closest_ratio].append(i)
|
267 |
+
|
268 |
+
def getdata(self, index):
|
269 |
+
img_path = self.img_samples[index]
|
270 |
+
real_prompt = random.random() < self.real_prompt_ratio
|
271 |
+
npz_path = self.txt_feat_samples[index] if real_prompt else self.gpt4v_txt_feat_samples[index]
|
272 |
+
txt = self.txt_samples[index] if real_prompt else self.sharegpt4v_txt_samples[index]
|
273 |
+
npy_path = self.vae_feat_samples[index]
|
274 |
+
data_info = {}
|
275 |
+
ori_h, ori_w = self.meta_data_clean[index]['height'], self.meta_data_clean[index]['width']
|
276 |
+
|
277 |
+
# Calculate the closest aspect ratio and resize & crop image[w, h]
|
278 |
+
closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
|
279 |
+
closest_size = list(map(lambda x: int(x), closest_size))
|
280 |
+
self.closest_ratio = closest_ratio
|
281 |
+
|
282 |
+
if self.load_vae_feat:
|
283 |
+
img = self.loader(npy_path)
|
284 |
+
if index not in self.ratio_index[closest_ratio]:
|
285 |
+
self.ratio_index[closest_ratio].append(index)
|
286 |
+
h, w = (img.shape[1], img.shape[2])
|
287 |
+
assert h, w == (ori_h//8, ori_w//8)
|
288 |
+
else:
|
289 |
+
img = self.loader(img_path)
|
290 |
+
h, w = (img.size[1], img.size[0])
|
291 |
+
assert h, w == (ori_h, ori_w)
|
292 |
+
|
293 |
+
data_info['img_hw'] = torch.tensor([ori_h, ori_w], dtype=torch.float32)
|
294 |
+
data_info['aspect_ratio'] = closest_ratio
|
295 |
+
data_info["mask_type"] = self.mask_type
|
296 |
+
|
297 |
+
attention_mask = torch.ones(1, 1, self.max_lenth)
|
298 |
+
if self.load_t5_feat:
|
299 |
+
txt_info = np.load(npz_path)
|
300 |
+
txt_fea = torch.from_numpy(txt_info['caption_feature'])
|
301 |
+
if 'attention_mask' in txt_info.keys():
|
302 |
+
attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
|
303 |
+
if txt_fea.shape[1] != self.max_lenth:
|
304 |
+
txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1).to(self.weight_dtype)
|
305 |
+
attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1)
|
306 |
+
else:
|
307 |
+
txt_fea = txt
|
308 |
+
|
309 |
+
if not self.load_vae_feat:
|
310 |
+
if closest_size[0] / ori_h > closest_size[1] / ori_w:
|
311 |
+
resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
|
312 |
+
else:
|
313 |
+
resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
|
314 |
+
self.transform = T.Compose([
|
315 |
+
T.Lambda(lambda img: img.convert('RGB')),
|
316 |
+
T.Resize(resize_size, interpolation=self.interpolate_model), # Image.BICUBIC
|
317 |
+
T.CenterCrop(closest_size),
|
318 |
+
T.ToTensor(),
|
319 |
+
T.Normalize([.5], [.5]),
|
320 |
+
])
|
321 |
+
|
322 |
+
if self.transform:
|
323 |
+
img = self.transform(img)
|
324 |
+
|
325 |
+
return img, txt_fea, attention_mask.to(torch.int16), data_info
|
326 |
+
|
327 |
+
def __getitem__(self, idx):
|
328 |
+
for _ in range(20):
|
329 |
+
try:
|
330 |
+
data = self.getdata(idx)
|
331 |
+
return data
|
332 |
+
except Exception as e:
|
333 |
+
print(f"Error details: {str(e)}")
|
334 |
+
idx = random.choice(self.ratio_index[self.closest_ratio])
|
335 |
+
raise RuntimeError('Too many bad data.')
|
336 |
+
|
diffusion/data/datasets/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .InternalData import InternalData, InternalDataSigma
|
2 |
+
from .InternalData_ms import InternalDataMS, InternalDataSigma
|
3 |
+
from .utils import *
|
diffusion/data/datasets/utils.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
ASPECT_RATIO_2880 = {
|
3 |
+
'0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0],
|
4 |
+
'0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0],
|
5 |
+
'0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0],
|
6 |
+
'0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0],
|
7 |
+
'0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0],
|
8 |
+
'1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0],
|
9 |
+
'1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0],
|
10 |
+
'1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0],
|
11 |
+
'2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0],
|
12 |
+
'3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0]
|
13 |
+
}
|
14 |
+
|
15 |
+
ASPECT_RATIO_2048 = {
|
16 |
+
'0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0],
|
17 |
+
'0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
|
18 |
+
'0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
|
19 |
+
'0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
|
20 |
+
'0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
|
21 |
+
'1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
|
22 |
+
'1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
|
23 |
+
'1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
|
24 |
+
'2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0],
|
25 |
+
'3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0]
|
26 |
+
}
|
27 |
+
|
28 |
+
ASPECT_RATIO_1024 = {
|
29 |
+
'0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
|
30 |
+
'0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
|
31 |
+
'0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
|
32 |
+
'0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
|
33 |
+
'0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
|
34 |
+
'1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
|
35 |
+
'1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
|
36 |
+
'1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
|
37 |
+
'2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
|
38 |
+
'3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
|
39 |
+
}
|
40 |
+
|
41 |
+
ASPECT_RATIO_512 = {
|
42 |
+
'0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
|
43 |
+
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
|
44 |
+
'0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
|
45 |
+
'0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
|
46 |
+
'0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
|
47 |
+
'1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
|
48 |
+
'1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
|
49 |
+
'1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
|
50 |
+
'2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
|
51 |
+
'3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
|
52 |
+
}
|
53 |
+
|
54 |
+
ASPECT_RATIO_256 = {
|
55 |
+
'0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
|
56 |
+
'0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
|
57 |
+
'0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
|
58 |
+
'0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
|
59 |
+
'0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
|
60 |
+
'1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
|
61 |
+
'1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
|
62 |
+
'1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
|
63 |
+
'2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
|
64 |
+
'3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
|
65 |
+
}
|
66 |
+
|
67 |
+
ASPECT_RATIO_256_TEST = {
|
68 |
+
'0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
|
69 |
+
'0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
|
70 |
+
'0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
|
71 |
+
'0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
|
72 |
+
'0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
|
73 |
+
'1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
|
74 |
+
'1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
|
75 |
+
'1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
|
76 |
+
'2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
|
77 |
+
'4.0': [512.0, 128.0]
|
78 |
+
}
|
79 |
+
|
80 |
+
ASPECT_RATIO_512_TEST = {
|
81 |
+
'0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
|
82 |
+
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
|
83 |
+
'0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
|
84 |
+
'0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
|
85 |
+
'0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
|
86 |
+
'1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
|
87 |
+
'1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
|
88 |
+
'1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
|
89 |
+
'2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
|
90 |
+
'4.0': [1024.0, 256.0]
|
91 |
+
}
|
92 |
+
|
93 |
+
ASPECT_RATIO_1024_TEST = {
|
94 |
+
'0.25': [512., 2048.], '0.28': [512., 1856.],
|
95 |
+
'0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
|
96 |
+
'0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
|
97 |
+
'0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
|
98 |
+
'0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
|
99 |
+
'1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
|
100 |
+
'1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
|
101 |
+
'1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
|
102 |
+
'2.5': [1600., 640.], '3.0': [1728., 576.],
|
103 |
+
'4.0': [2048., 512.],
|
104 |
+
}
|
105 |
+
|
106 |
+
ASPECT_RATIO_2048_TEST = {
|
107 |
+
'0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0],
|
108 |
+
'0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
|
109 |
+
'0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
|
110 |
+
'0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
|
111 |
+
'0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
|
112 |
+
'1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
|
113 |
+
'1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
|
114 |
+
'1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
|
115 |
+
'2.5': [3200.0, 1280.0], '3.0': [3456.0, 1152.0],
|
116 |
+
'4.0': [4096.0, 1024.0]
|
117 |
+
}
|
118 |
+
|
119 |
+
ASPECT_RATIO_2880_TEST = {
|
120 |
+
'0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0],
|
121 |
+
'0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0],
|
122 |
+
'0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0],
|
123 |
+
'0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0],
|
124 |
+
'0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0],
|
125 |
+
'1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0],
|
126 |
+
'1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0],
|
127 |
+
'1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0],
|
128 |
+
'2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0],
|
129 |
+
'4.0': [8192.0, 2048.0],
|
130 |
+
}
|
131 |
+
|
132 |
+
def get_chunks(lst, n):
|
133 |
+
for i in range(0, len(lst), n):
|
134 |
+
yield lst[i:i + n]
|
diffusion/data/transforms.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchvision.transforms as T
|
2 |
+
|
3 |
+
TRANSFORMS = dict()
|
4 |
+
|
5 |
+
|
6 |
+
def register_transform(transform):
|
7 |
+
name = transform.__name__
|
8 |
+
if name in TRANSFORMS:
|
9 |
+
raise RuntimeError(f'Transform {name} has already registered.')
|
10 |
+
TRANSFORMS.update({name: transform})
|
11 |
+
|
12 |
+
|
13 |
+
def get_transform(type, resolution):
|
14 |
+
transform = TRANSFORMS[type](resolution)
|
15 |
+
transform = T.Compose(transform)
|
16 |
+
transform.image_size = resolution
|
17 |
+
return transform
|
18 |
+
|
19 |
+
|
20 |
+
@register_transform
|
21 |
+
def default_train(n_px):
|
22 |
+
transform = [
|
23 |
+
T.Lambda(lambda img: img.convert('RGB')),
|
24 |
+
T.Resize(n_px), # Image.BICUBIC
|
25 |
+
T.CenterCrop(n_px),
|
26 |
+
# T.RandomHorizontalFlip(),
|
27 |
+
T.ToTensor(),
|
28 |
+
T.Normalize([.5], [.5]),
|
29 |
+
]
|
30 |
+
return transform
|
diffusion/dpm_solver.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .model import gaussian_diffusion as gd
|
3 |
+
from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP
|
4 |
+
|
5 |
+
|
6 |
+
def DPMS(
|
7 |
+
model,
|
8 |
+
condition,
|
9 |
+
uncondition,
|
10 |
+
cfg_scale,
|
11 |
+
model_type='noise', # or "x_start" or "v" or "score"
|
12 |
+
noise_schedule="linear",
|
13 |
+
guidance_type='classifier-free',
|
14 |
+
model_kwargs={},
|
15 |
+
diffusion_steps=1000
|
16 |
+
):
|
17 |
+
betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
|
18 |
+
|
19 |
+
## 1. Define the noise schedule.
|
20 |
+
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
|
21 |
+
|
22 |
+
## 2. Convert your discrete-time `model` to the continuous-time
|
23 |
+
## noise prediction model. Here is an example for a diffusion model
|
24 |
+
## `model` with the noise prediction type ("noise") .
|
25 |
+
model_fn = model_wrapper(
|
26 |
+
model,
|
27 |
+
noise_schedule,
|
28 |
+
model_type=model_type,
|
29 |
+
model_kwargs=model_kwargs,
|
30 |
+
guidance_type=guidance_type,
|
31 |
+
condition=condition,
|
32 |
+
unconditional_condition=uncondition,
|
33 |
+
guidance_scale=cfg_scale,
|
34 |
+
)
|
35 |
+
## 3. Define dpm-solver and sample by multistep DPM-Solver.
|
36 |
+
return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
|
diffusion/iddpm.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
from diffusion.model.respace import SpacedDiffusion, space_timesteps
|
6 |
+
from .model import gaussian_diffusion as gd
|
7 |
+
|
8 |
+
|
9 |
+
def IDDPM(
|
10 |
+
timestep_respacing,
|
11 |
+
noise_schedule="linear",
|
12 |
+
use_kl=False,
|
13 |
+
sigma_small=False,
|
14 |
+
predict_xstart=False,
|
15 |
+
learn_sigma=True,
|
16 |
+
pred_sigma=True,
|
17 |
+
rescale_learned_sigmas=False,
|
18 |
+
diffusion_steps=1000,
|
19 |
+
snr=False,
|
20 |
+
return_startx=False,
|
21 |
+
):
|
22 |
+
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
|
23 |
+
if use_kl:
|
24 |
+
loss_type = gd.LossType.RESCALED_KL
|
25 |
+
elif rescale_learned_sigmas:
|
26 |
+
loss_type = gd.LossType.RESCALED_MSE
|
27 |
+
else:
|
28 |
+
loss_type = gd.LossType.MSE
|
29 |
+
if timestep_respacing is None or timestep_respacing == "":
|
30 |
+
timestep_respacing = [diffusion_steps]
|
31 |
+
return SpacedDiffusion(
|
32 |
+
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
|
33 |
+
betas=betas,
|
34 |
+
model_mean_type=(
|
35 |
+
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
|
36 |
+
),
|
37 |
+
model_var_type=(
|
38 |
+
((
|
39 |
+
gd.ModelVarType.FIXED_LARGE
|
40 |
+
if not sigma_small
|
41 |
+
else gd.ModelVarType.FIXED_SMALL
|
42 |
+
)
|
43 |
+
if not learn_sigma
|
44 |
+
else gd.ModelVarType.LEARNED_RANGE
|
45 |
+
)
|
46 |
+
if pred_sigma
|
47 |
+
else None
|
48 |
+
),
|
49 |
+
loss_type=loss_type,
|
50 |
+
snr=snr,
|
51 |
+
return_startx=return_startx,
|
52 |
+
# rescale_timesteps=rescale_timesteps,
|
53 |
+
)
|
diffusion/lcm_scheduler.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16 |
+
# and https://github.com/hojonathanho/diffusion
|
17 |
+
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
|
25 |
+
from diffusers import ConfigMixin, SchedulerMixin
|
26 |
+
from diffusers.configuration_utils import register_to_config
|
27 |
+
from diffusers.utils import BaseOutput
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
32 |
+
class LCMSchedulerOutput(BaseOutput):
|
33 |
+
"""
|
34 |
+
Output class for the scheduler's `step` function output.
|
35 |
+
Args:
|
36 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
37 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
38 |
+
denoising loop.
|
39 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
40 |
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
41 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
42 |
+
"""
|
43 |
+
|
44 |
+
prev_sample: torch.FloatTensor
|
45 |
+
denoised: Optional[torch.FloatTensor] = None
|
46 |
+
|
47 |
+
|
48 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
49 |
+
def betas_for_alpha_bar(
|
50 |
+
num_diffusion_timesteps,
|
51 |
+
max_beta=0.999,
|
52 |
+
alpha_transform_type="cosine",
|
53 |
+
):
|
54 |
+
"""
|
55 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
56 |
+
(1-beta) over time from t = [0,1].
|
57 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
58 |
+
to that part of the diffusion process.
|
59 |
+
Args:
|
60 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
61 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
62 |
+
prevent singularities.
|
63 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
64 |
+
Choose from `cosine` or `exp`
|
65 |
+
Returns:
|
66 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
67 |
+
"""
|
68 |
+
if alpha_transform_type == "cosine":
|
69 |
+
|
70 |
+
def alpha_bar_fn(t):
|
71 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
72 |
+
|
73 |
+
elif alpha_transform_type == "exp":
|
74 |
+
|
75 |
+
def alpha_bar_fn(t):
|
76 |
+
return math.exp(t * -12.0)
|
77 |
+
|
78 |
+
else:
|
79 |
+
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
80 |
+
|
81 |
+
betas = []
|
82 |
+
for i in range(num_diffusion_timesteps):
|
83 |
+
t1 = i / num_diffusion_timesteps
|
84 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
85 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
86 |
+
return torch.tensor(betas, dtype=torch.float32)
|
87 |
+
|
88 |
+
|
89 |
+
def rescale_zero_terminal_snr(betas):
|
90 |
+
"""
|
91 |
+
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
92 |
+
Args:
|
93 |
+
betas (`torch.FloatTensor`):
|
94 |
+
the betas that the scheduler is being initialized with.
|
95 |
+
Returns:
|
96 |
+
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
97 |
+
"""
|
98 |
+
# Convert betas to alphas_bar_sqrt
|
99 |
+
alphas = 1.0 - betas
|
100 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
101 |
+
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
102 |
+
|
103 |
+
# Store old values.
|
104 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
105 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
106 |
+
|
107 |
+
# Shift so the last timestep is zero.
|
108 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
109 |
+
|
110 |
+
# Scale so the first timestep is back to the old value.
|
111 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
112 |
+
|
113 |
+
# Convert alphas_bar_sqrt to betas
|
114 |
+
alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
|
115 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
116 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
117 |
+
betas = 1 - alphas
|
118 |
+
|
119 |
+
return betas
|
120 |
+
|
121 |
+
|
122 |
+
class LCMScheduler(SchedulerMixin, ConfigMixin):
|
123 |
+
"""
|
124 |
+
`LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
|
125 |
+
non-Markovian guidance.
|
126 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
127 |
+
methods the library implements for all schedulers such as loading and saving.
|
128 |
+
Args:
|
129 |
+
num_train_timesteps (`int`, defaults to 1000):
|
130 |
+
The number of diffusion steps to train the model.
|
131 |
+
beta_start (`float`, defaults to 0.0001):
|
132 |
+
The starting `beta` value of inference.
|
133 |
+
beta_end (`float`, defaults to 0.02):
|
134 |
+
The final `beta` value.
|
135 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
136 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
137 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
138 |
+
trained_betas (`np.ndarray`, *optional*):
|
139 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
140 |
+
clip_sample (`bool`, defaults to `True`):
|
141 |
+
Clip the predicted sample for numerical stability.
|
142 |
+
clip_sample_range (`float`, defaults to 1.0):
|
143 |
+
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
144 |
+
set_alpha_to_one (`bool`, defaults to `True`):
|
145 |
+
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
146 |
+
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
147 |
+
otherwise it uses the alpha value at step 0.
|
148 |
+
steps_offset (`int`, defaults to 0):
|
149 |
+
An offset added to the inference steps. You can use a combination of `offset=1` and
|
150 |
+
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
151 |
+
Diffusion.
|
152 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
153 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
154 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
155 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
156 |
+
thresholding (`bool`, defaults to `False`):
|
157 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
158 |
+
as Stable Diffusion.
|
159 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
160 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
161 |
+
sample_max_value (`float`, defaults to 1.0):
|
162 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
163 |
+
timestep_spacing (`str`, defaults to `"leading"`):
|
164 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
165 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
166 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
167 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
168 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
169 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
170 |
+
"""
|
171 |
+
|
172 |
+
# _compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
173 |
+
order = 1
|
174 |
+
|
175 |
+
@register_to_config
|
176 |
+
def __init__(
|
177 |
+
self,
|
178 |
+
num_train_timesteps: int = 1000,
|
179 |
+
beta_start: float = 0.0001,
|
180 |
+
beta_end: float = 0.02,
|
181 |
+
beta_schedule: str = "linear",
|
182 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
183 |
+
clip_sample: bool = True,
|
184 |
+
set_alpha_to_one: bool = True,
|
185 |
+
steps_offset: int = 0,
|
186 |
+
prediction_type: str = "epsilon",
|
187 |
+
thresholding: bool = False,
|
188 |
+
dynamic_thresholding_ratio: float = 0.995,
|
189 |
+
clip_sample_range: float = 1.0,
|
190 |
+
sample_max_value: float = 1.0,
|
191 |
+
timestep_spacing: str = "leading",
|
192 |
+
rescale_betas_zero_snr: bool = False,
|
193 |
+
):
|
194 |
+
if trained_betas is not None:
|
195 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
196 |
+
elif beta_schedule == "linear":
|
197 |
+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
198 |
+
elif beta_schedule == "scaled_linear":
|
199 |
+
# this schedule is very specific to the latent diffusion model.
|
200 |
+
self.betas = (
|
201 |
+
torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
202 |
+
)
|
203 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
204 |
+
# Glide cosine schedule
|
205 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
206 |
+
else:
|
207 |
+
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
208 |
+
|
209 |
+
# Rescale for zero SNR
|
210 |
+
if rescale_betas_zero_snr:
|
211 |
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
212 |
+
|
213 |
+
self.alphas = 1.0 - self.betas
|
214 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
215 |
+
|
216 |
+
# At every step in ddim, we are looking into the previous alphas_cumprod
|
217 |
+
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
218 |
+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
219 |
+
# whether we use the final alpha of the "non-previous" one.
|
220 |
+
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
221 |
+
|
222 |
+
# standard deviation of the initial noise distribution
|
223 |
+
self.init_noise_sigma = 1.0
|
224 |
+
|
225 |
+
# setable values
|
226 |
+
self.num_inference_steps = None
|
227 |
+
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
228 |
+
|
229 |
+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
230 |
+
"""
|
231 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
232 |
+
current timestep.
|
233 |
+
Args:
|
234 |
+
sample (`torch.FloatTensor`):
|
235 |
+
The input sample.
|
236 |
+
timestep (`int`, *optional*):
|
237 |
+
The current timestep in the diffusion chain.
|
238 |
+
Returns:
|
239 |
+
`torch.FloatTensor`:
|
240 |
+
A scaled input sample.
|
241 |
+
"""
|
242 |
+
return sample
|
243 |
+
|
244 |
+
def _get_variance(self, timestep, prev_timestep):
|
245 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
246 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
247 |
+
beta_prod_t = 1 - alpha_prod_t
|
248 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
249 |
+
|
250 |
+
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
251 |
+
|
252 |
+
return variance
|
253 |
+
|
254 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
255 |
+
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
256 |
+
"""
|
257 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
258 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
259 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
260 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
261 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
262 |
+
https://arxiv.org/abs/2205.11487
|
263 |
+
"""
|
264 |
+
dtype = sample.dtype
|
265 |
+
batch_size, channels, height, width = sample.shape
|
266 |
+
|
267 |
+
if dtype not in (torch.float32, torch.float64):
|
268 |
+
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
269 |
+
|
270 |
+
# Flatten sample for doing quantile calculation along each image
|
271 |
+
sample = sample.reshape(batch_size, channels * height * width)
|
272 |
+
|
273 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
274 |
+
|
275 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
276 |
+
s = torch.clamp(
|
277 |
+
s, min=1, max=self.config.sample_max_value
|
278 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
279 |
+
|
280 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
281 |
+
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
282 |
+
|
283 |
+
sample = sample.reshape(batch_size, channels, height, width)
|
284 |
+
sample = sample.to(dtype)
|
285 |
+
|
286 |
+
return sample
|
287 |
+
|
288 |
+
def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):
|
289 |
+
"""
|
290 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
291 |
+
Args:
|
292 |
+
num_inference_steps (`int`):
|
293 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
294 |
+
"""
|
295 |
+
|
296 |
+
if num_inference_steps > self.config.num_train_timesteps:
|
297 |
+
raise ValueError(
|
298 |
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
299 |
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
300 |
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
301 |
+
)
|
302 |
+
|
303 |
+
self.num_inference_steps = num_inference_steps
|
304 |
+
|
305 |
+
# LCM Timesteps Setting: # Linear Spacing
|
306 |
+
c = self.config.num_train_timesteps // lcm_origin_steps
|
307 |
+
lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 # LCM Training Steps Schedule
|
308 |
+
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
|
309 |
+
timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
|
310 |
+
|
311 |
+
self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
|
312 |
+
|
313 |
+
def get_scalings_for_boundary_condition_discrete(self, t):
|
314 |
+
self.sigma_data = 0.5 # Default: 0.5
|
315 |
+
|
316 |
+
# By dividing 0.1: This is almost a delta function at t=0.
|
317 |
+
c_skip = self.sigma_data ** 2 / ((t / 0.1) ** 2 + self.sigma_data ** 2)
|
318 |
+
c_out = ((t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data ** 2) ** 0.5)
|
319 |
+
return c_skip, c_out
|
320 |
+
|
321 |
+
def step(
|
322 |
+
self,
|
323 |
+
model_output: torch.FloatTensor,
|
324 |
+
timeindex: int,
|
325 |
+
timestep: int,
|
326 |
+
sample: torch.FloatTensor,
|
327 |
+
eta: float = 0.0,
|
328 |
+
use_clipped_model_output: bool = False,
|
329 |
+
generator=None,
|
330 |
+
variance_noise: Optional[torch.FloatTensor] = None,
|
331 |
+
return_dict: bool = True,
|
332 |
+
) -> Union[LCMSchedulerOutput, Tuple]:
|
333 |
+
"""
|
334 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
335 |
+
process from the learned model outputs (most often the predicted noise).
|
336 |
+
Args:
|
337 |
+
model_output (`torch.FloatTensor`):
|
338 |
+
The direct output from learned diffusion model.
|
339 |
+
timestep (`float`):
|
340 |
+
The current discrete timestep in the diffusion chain.
|
341 |
+
sample (`torch.FloatTensor`):
|
342 |
+
A current instance of a sample created by the diffusion process.
|
343 |
+
eta (`float`):
|
344 |
+
The weight of noise for added noise in diffusion step.
|
345 |
+
use_clipped_model_output (`bool`, defaults to `False`):
|
346 |
+
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
|
347 |
+
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
|
348 |
+
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
|
349 |
+
`use_clipped_model_output` has no effect.
|
350 |
+
generator (`torch.Generator`, *optional*):
|
351 |
+
A random number generator.
|
352 |
+
variance_noise (`torch.FloatTensor`):
|
353 |
+
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
354 |
+
itself. Useful for methods such as [`CycleDiffusion`].
|
355 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
356 |
+
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
|
357 |
+
Returns:
|
358 |
+
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
|
359 |
+
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
|
360 |
+
tuple is returned where the first element is the sample tensor.
|
361 |
+
"""
|
362 |
+
if self.num_inference_steps is None:
|
363 |
+
raise ValueError(
|
364 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
365 |
+
)
|
366 |
+
|
367 |
+
# 1. get previous step value
|
368 |
+
prev_timeindex = timeindex + 1
|
369 |
+
if prev_timeindex < len(self.timesteps):
|
370 |
+
prev_timestep = self.timesteps[prev_timeindex]
|
371 |
+
else:
|
372 |
+
prev_timestep = timestep
|
373 |
+
|
374 |
+
# 2. compute alphas, betas
|
375 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
376 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
377 |
+
|
378 |
+
beta_prod_t = 1 - alpha_prod_t
|
379 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
380 |
+
|
381 |
+
# 3. Get scalings for boundary conditions
|
382 |
+
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
|
383 |
+
|
384 |
+
# 4. Different Parameterization:
|
385 |
+
parameterization = self.config.prediction_type
|
386 |
+
|
387 |
+
if parameterization == "epsilon": # noise-prediction
|
388 |
+
pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
|
389 |
+
|
390 |
+
elif parameterization == "sample": # x-prediction
|
391 |
+
pred_x0 = model_output
|
392 |
+
|
393 |
+
elif parameterization == "v_prediction": # v-prediction
|
394 |
+
pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
|
395 |
+
|
396 |
+
# 4. Denoise model output using boundary conditions
|
397 |
+
denoised = c_out * pred_x0 + c_skip * sample
|
398 |
+
|
399 |
+
# 5. Sample z ~ N(0, I), For MultiStep Inference
|
400 |
+
# Noise is not used for one-step sampling.
|
401 |
+
if len(self.timesteps) > 1:
|
402 |
+
noise = torch.randn(model_output.shape).to(model_output.device)
|
403 |
+
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
|
404 |
+
else:
|
405 |
+
prev_sample = denoised
|
406 |
+
|
407 |
+
if not return_dict:
|
408 |
+
return (prev_sample, denoised)
|
409 |
+
|
410 |
+
return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
|
411 |
+
|
412 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
413 |
+
def add_noise(
|
414 |
+
self,
|
415 |
+
original_samples: torch.FloatTensor,
|
416 |
+
noise: torch.FloatTensor,
|
417 |
+
timesteps: torch.IntTensor,
|
418 |
+
) -> torch.FloatTensor:
|
419 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
420 |
+
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
421 |
+
timesteps = timesteps.to(original_samples.device)
|
422 |
+
|
423 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
424 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
425 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
426 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
427 |
+
|
428 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
429 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
430 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
431 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
432 |
+
|
433 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
434 |
+
return noisy_samples
|
435 |
+
|
436 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
|
437 |
+
def get_velocity(
|
438 |
+
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
439 |
+
) -> torch.FloatTensor:
|
440 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
441 |
+
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
442 |
+
timesteps = timesteps.to(sample.device)
|
443 |
+
|
444 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
445 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
446 |
+
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
447 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
448 |
+
|
449 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
450 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
451 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
452 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
453 |
+
|
454 |
+
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
455 |
+
return velocity
|
456 |
+
|
457 |
+
def __len__(self):
|
458 |
+
return self.config.num_train_timesteps
|
459 |
+
|
diffusion/model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .nets import *
|
diffusion/model/builder.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmcv import Registry
|
2 |
+
|
3 |
+
from diffusion.model.utils import set_grad_checkpoint
|
4 |
+
|
5 |
+
MODELS = Registry('models')
|
6 |
+
|
7 |
+
|
8 |
+
def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs):
|
9 |
+
if isinstance(cfg, str):
|
10 |
+
cfg = dict(type=cfg)
|
11 |
+
model = MODELS.build(cfg, default_args=kwargs)
|
12 |
+
if use_grad_checkpoint:
|
13 |
+
set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step)
|
14 |
+
return model
|
diffusion/model/diffusion_utils.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch as th
|
8 |
+
|
9 |
+
|
10 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
11 |
+
"""
|
12 |
+
Compute the KL divergence between two gaussians.
|
13 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
14 |
+
scalars, among other use cases.
|
15 |
+
"""
|
16 |
+
tensor = None
|
17 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
18 |
+
if isinstance(obj, th.Tensor):
|
19 |
+
tensor = obj
|
20 |
+
break
|
21 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
22 |
+
|
23 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
24 |
+
# Tensors, but it does not work for th.exp().
|
25 |
+
logvar1, logvar2 = [
|
26 |
+
x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
|
27 |
+
for x in (logvar1, logvar2)
|
28 |
+
]
|
29 |
+
|
30 |
+
return 0.5 * (
|
31 |
+
-1.0
|
32 |
+
+ logvar2
|
33 |
+
- logvar1
|
34 |
+
+ th.exp(logvar1 - logvar2)
|
35 |
+
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
def approx_standard_normal_cdf(x):
|
40 |
+
"""
|
41 |
+
A fast approximation of the cumulative distribution function of the
|
42 |
+
standard normal.
|
43 |
+
"""
|
44 |
+
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
45 |
+
|
46 |
+
|
47 |
+
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
|
48 |
+
"""
|
49 |
+
Compute the log-likelihood of a continuous Gaussian distribution.
|
50 |
+
:param x: the targets
|
51 |
+
:param means: the Gaussian mean Tensor.
|
52 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
53 |
+
:return: a tensor like x of log probabilities (in nats).
|
54 |
+
"""
|
55 |
+
centered_x = x - means
|
56 |
+
inv_stdv = th.exp(-log_scales)
|
57 |
+
normalized_x = centered_x * inv_stdv
|
58 |
+
log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
|
59 |
+
return log_probs
|
60 |
+
|
61 |
+
|
62 |
+
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
63 |
+
"""
|
64 |
+
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
65 |
+
given image.
|
66 |
+
:param x: the target images. It is assumed that this was uint8 values,
|
67 |
+
rescaled to the range [-1, 1].
|
68 |
+
:param means: the Gaussian mean Tensor.
|
69 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
70 |
+
:return: a tensor like x of log probabilities (in nats).
|
71 |
+
"""
|
72 |
+
assert x.shape == means.shape == log_scales.shape
|
73 |
+
centered_x = x - means
|
74 |
+
inv_stdv = th.exp(-log_scales)
|
75 |
+
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
76 |
+
cdf_plus = approx_standard_normal_cdf(plus_in)
|
77 |
+
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
78 |
+
cdf_min = approx_standard_normal_cdf(min_in)
|
79 |
+
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
80 |
+
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
81 |
+
cdf_delta = cdf_plus - cdf_min
|
82 |
+
log_probs = th.where(
|
83 |
+
x < -0.999,
|
84 |
+
log_cdf_plus,
|
85 |
+
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
86 |
+
)
|
87 |
+
assert log_probs.shape == x.shape
|
88 |
+
return log_probs
|
diffusion/model/dpm_solver.py
ADDED
@@ -0,0 +1,1337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
|
5 |
+
class NoiseScheduleVP:
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
schedule='discrete',
|
9 |
+
betas=None,
|
10 |
+
alphas_cumprod=None,
|
11 |
+
continuous_beta_0=0.1,
|
12 |
+
continuous_beta_1=20.,
|
13 |
+
dtype=torch.float32,
|
14 |
+
):
|
15 |
+
"""Create a wrapper class for the forward SDE (VP type).
|
16 |
+
|
17 |
+
***
|
18 |
+
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
19 |
+
We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
|
20 |
+
***
|
21 |
+
|
22 |
+
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
|
23 |
+
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
|
24 |
+
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
|
25 |
+
|
26 |
+
log_alpha_t = self.marginal_log_mean_coeff(t)
|
27 |
+
sigma_t = self.marginal_std(t)
|
28 |
+
lambda_t = self.marginal_lambda(t)
|
29 |
+
|
30 |
+
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
|
31 |
+
|
32 |
+
t = self.inverse_lambda(lambda_t)
|
33 |
+
|
34 |
+
===============================================================
|
35 |
+
|
36 |
+
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
|
37 |
+
|
38 |
+
1. For discrete-time DPMs:
|
39 |
+
|
40 |
+
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
|
41 |
+
t_i = (i + 1) / N
|
42 |
+
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
|
43 |
+
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
|
47 |
+
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
|
48 |
+
|
49 |
+
Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
|
50 |
+
|
51 |
+
**Important**: Please pay special attention for the args for `alphas_cumprod`:
|
52 |
+
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
|
53 |
+
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
|
54 |
+
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
|
55 |
+
alpha_{t_n} = \sqrt{\hat{alpha_n}},
|
56 |
+
and
|
57 |
+
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
|
58 |
+
|
59 |
+
|
60 |
+
2. For continuous-time DPMs:
|
61 |
+
|
62 |
+
We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise
|
63 |
+
schedule are the default settings in Yang Song's ScoreSDE:
|
64 |
+
|
65 |
+
Args:
|
66 |
+
beta_min: A `float` number. The smallest beta for the linear schedule.
|
67 |
+
beta_max: A `float` number. The largest beta for the linear schedule.
|
68 |
+
T: A `float` number. The ending time of the forward process.
|
69 |
+
|
70 |
+
===============================================================
|
71 |
+
|
72 |
+
Args:
|
73 |
+
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
|
74 |
+
'linear' for continuous-time DPMs.
|
75 |
+
Returns:
|
76 |
+
A wrapper object of the forward SDE (VP type).
|
77 |
+
|
78 |
+
===============================================================
|
79 |
+
|
80 |
+
Example:
|
81 |
+
|
82 |
+
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
|
83 |
+
>>> ns = NoiseScheduleVP('discrete', betas=betas)
|
84 |
+
|
85 |
+
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
|
86 |
+
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
87 |
+
|
88 |
+
# For continuous-time DPMs (VPSDE), linear schedule:
|
89 |
+
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
|
90 |
+
|
91 |
+
"""
|
92 |
+
|
93 |
+
if schedule not in ['discrete', 'linear']:
|
94 |
+
raise ValueError(
|
95 |
+
"Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule))
|
96 |
+
|
97 |
+
self.schedule = schedule
|
98 |
+
if schedule == 'discrete':
|
99 |
+
if betas is not None:
|
100 |
+
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
|
101 |
+
else:
|
102 |
+
assert alphas_cumprod is not None
|
103 |
+
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
104 |
+
self.T = 1.
|
105 |
+
self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype)
|
106 |
+
self.total_N = self.log_alpha_array.shape[1]
|
107 |
+
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
|
108 |
+
else:
|
109 |
+
self.T = 1.
|
110 |
+
self.total_N = 1000
|
111 |
+
self.beta_0 = continuous_beta_0
|
112 |
+
self.beta_1 = continuous_beta_1
|
113 |
+
|
114 |
+
def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
|
115 |
+
"""
|
116 |
+
For some beta schedules such as cosine schedule, the log-SNR has numerical isssues.
|
117 |
+
We clip the log-SNR near t=T within -5.1 to ensure the stability.
|
118 |
+
Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.
|
119 |
+
"""
|
120 |
+
log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
|
121 |
+
lambs = log_alphas - log_sigmas
|
122 |
+
idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
|
123 |
+
if idx > 0:
|
124 |
+
log_alphas = log_alphas[:-idx]
|
125 |
+
return log_alphas
|
126 |
+
|
127 |
+
def marginal_log_mean_coeff(self, t):
|
128 |
+
"""
|
129 |
+
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
130 |
+
"""
|
131 |
+
if self.schedule == 'discrete':
|
132 |
+
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
|
133 |
+
self.log_alpha_array.to(t.device)).reshape((-1))
|
134 |
+
elif self.schedule == 'linear':
|
135 |
+
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
136 |
+
|
137 |
+
def marginal_alpha(self, t):
|
138 |
+
"""
|
139 |
+
Compute alpha_t of a given continuous-time label t in [0, T].
|
140 |
+
"""
|
141 |
+
return torch.exp(self.marginal_log_mean_coeff(t))
|
142 |
+
|
143 |
+
def marginal_std(self, t):
|
144 |
+
"""
|
145 |
+
Compute sigma_t of a given continuous-time label t in [0, T].
|
146 |
+
"""
|
147 |
+
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
148 |
+
|
149 |
+
def marginal_lambda(self, t):
|
150 |
+
"""
|
151 |
+
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
152 |
+
"""
|
153 |
+
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
154 |
+
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
155 |
+
return log_mean_coeff - log_std
|
156 |
+
|
157 |
+
def inverse_lambda(self, lamb):
|
158 |
+
"""
|
159 |
+
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
160 |
+
"""
|
161 |
+
if self.schedule == 'linear':
|
162 |
+
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
163 |
+
Delta = self.beta_0 ** 2 + tmp
|
164 |
+
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
165 |
+
elif self.schedule == 'discrete':
|
166 |
+
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
167 |
+
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
|
168 |
+
torch.flip(self.t_array.to(lamb.device), [1]))
|
169 |
+
return t.reshape((-1,))
|
170 |
+
|
171 |
+
|
172 |
+
def model_wrapper(
|
173 |
+
model,
|
174 |
+
noise_schedule,
|
175 |
+
model_type="noise",
|
176 |
+
model_kwargs={},
|
177 |
+
guidance_type="uncond",
|
178 |
+
condition=None,
|
179 |
+
unconditional_condition=None,
|
180 |
+
guidance_scale=1.,
|
181 |
+
classifier_fn=None,
|
182 |
+
classifier_kwargs={},
|
183 |
+
):
|
184 |
+
"""Create a wrapper function for the noise prediction model.
|
185 |
+
|
186 |
+
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
187 |
+
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
|
188 |
+
|
189 |
+
We support four types of the diffusion model by setting `model_type`:
|
190 |
+
|
191 |
+
1. "noise": noise prediction model. (Trained by predicting noise).
|
192 |
+
|
193 |
+
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
|
194 |
+
|
195 |
+
3. "v": velocity prediction model. (Trained by predicting the velocity).
|
196 |
+
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
|
197 |
+
|
198 |
+
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
|
199 |
+
arXiv preprint arXiv:2202.00512 (2022).
|
200 |
+
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
201 |
+
arXiv preprint arXiv:2210.02303 (2022).
|
202 |
+
|
203 |
+
4. "score": marginal score function. (Trained by denoising score matching).
|
204 |
+
Note that the score function and the noise prediction model follows a simple relationship:
|
205 |
+
```
|
206 |
+
noise(x_t, t) = -sigma_t * score(x_t, t)
|
207 |
+
```
|
208 |
+
|
209 |
+
We support three types of guided sampling by DPMs by setting `guidance_type`:
|
210 |
+
1. "uncond": unconditional sampling by DPMs.
|
211 |
+
The input `model` has the following format:
|
212 |
+
``
|
213 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
214 |
+
``
|
215 |
+
|
216 |
+
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
|
217 |
+
The input `model` has the following format:
|
218 |
+
``
|
219 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
220 |
+
``
|
221 |
+
|
222 |
+
The input `classifier_fn` has the following format:
|
223 |
+
``
|
224 |
+
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
|
225 |
+
``
|
226 |
+
|
227 |
+
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
|
228 |
+
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
229 |
+
|
230 |
+
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
|
231 |
+
The input `model` has the following format:
|
232 |
+
``
|
233 |
+
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
234 |
+
``
|
235 |
+
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
236 |
+
|
237 |
+
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
238 |
+
arXiv preprint arXiv:2207.12598 (2022).
|
239 |
+
|
240 |
+
|
241 |
+
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
242 |
+
or continuous-time labels (i.e. epsilon to T).
|
243 |
+
|
244 |
+
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
|
245 |
+
``
|
246 |
+
def model_fn(x, t_continuous) -> noise:
|
247 |
+
t_input = get_model_input_time(t_continuous)
|
248 |
+
return noise_pred(model, x, t_input, **model_kwargs)
|
249 |
+
``
|
250 |
+
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
251 |
+
|
252 |
+
===============================================================
|
253 |
+
|
254 |
+
Args:
|
255 |
+
model: A diffusion model with the corresponding format described above.
|
256 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
257 |
+
model_type: A `str`. The parameterization type of the diffusion model.
|
258 |
+
"noise" or "x_start" or "v" or "score".
|
259 |
+
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
260 |
+
guidance_type: A `str`. The type of the guidance for sampling.
|
261 |
+
"uncond" or "classifier" or "classifier-free".
|
262 |
+
condition: A pytorch tensor. The condition for the guided sampling.
|
263 |
+
Only used for "classifier" or "classifier-free" guidance type.
|
264 |
+
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
|
265 |
+
Only used for "classifier-free" guidance type.
|
266 |
+
guidance_scale: A `float`. The scale for the guided sampling.
|
267 |
+
classifier_fn: A classifier function. Only used for the classifier guidance.
|
268 |
+
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
|
269 |
+
Returns:
|
270 |
+
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
271 |
+
"""
|
272 |
+
|
273 |
+
def get_model_input_time(t_continuous):
|
274 |
+
"""
|
275 |
+
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
276 |
+
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
277 |
+
For continuous-time DPMs, we just use `t_continuous`.
|
278 |
+
"""
|
279 |
+
if noise_schedule.schedule == 'discrete':
|
280 |
+
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
|
281 |
+
else:
|
282 |
+
return t_continuous
|
283 |
+
|
284 |
+
def noise_pred_fn(x, t_continuous, cond=None):
|
285 |
+
t_input = get_model_input_time(t_continuous)
|
286 |
+
if cond is None:
|
287 |
+
output = model(x, t_input, **model_kwargs)
|
288 |
+
else:
|
289 |
+
output = model(x, t_input, cond, **model_kwargs)
|
290 |
+
if model_type == "noise":
|
291 |
+
return output
|
292 |
+
elif model_type == "x_start":
|
293 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
294 |
+
return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
|
295 |
+
elif model_type == "v":
|
296 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
297 |
+
return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
|
298 |
+
elif model_type == "score":
|
299 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
300 |
+
return -expand_dims(sigma_t, x.dim()) * output
|
301 |
+
|
302 |
+
def cond_grad_fn(x, t_input):
|
303 |
+
"""
|
304 |
+
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
305 |
+
"""
|
306 |
+
with torch.enable_grad():
|
307 |
+
x_in = x.detach().requires_grad_(True)
|
308 |
+
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
309 |
+
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
310 |
+
|
311 |
+
def model_fn(x, t_continuous):
|
312 |
+
"""
|
313 |
+
The noise predicition model function that is used for DPM-Solver.
|
314 |
+
"""
|
315 |
+
if guidance_type == "uncond":
|
316 |
+
return noise_pred_fn(x, t_continuous)
|
317 |
+
elif guidance_type == "classifier":
|
318 |
+
assert classifier_fn is not None
|
319 |
+
t_input = get_model_input_time(t_continuous)
|
320 |
+
cond_grad = cond_grad_fn(x, t_input)
|
321 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
322 |
+
noise = noise_pred_fn(x, t_continuous)
|
323 |
+
return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
|
324 |
+
elif guidance_type == "classifier-free":
|
325 |
+
if guidance_scale == 1. or unconditional_condition is None:
|
326 |
+
return noise_pred_fn(x, t_continuous, cond=condition)
|
327 |
+
else:
|
328 |
+
x_in = torch.cat([x] * 2)
|
329 |
+
t_in = torch.cat([t_continuous] * 2)
|
330 |
+
c_in = torch.cat([unconditional_condition, condition])
|
331 |
+
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
332 |
+
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
333 |
+
|
334 |
+
assert model_type in ["noise", "x_start", "v", "score"]
|
335 |
+
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
336 |
+
return model_fn
|
337 |
+
|
338 |
+
|
339 |
+
class DPM_Solver:
|
340 |
+
def __init__(
|
341 |
+
self,
|
342 |
+
model_fn,
|
343 |
+
noise_schedule,
|
344 |
+
algorithm_type="dpmsolver++",
|
345 |
+
correcting_x0_fn=None,
|
346 |
+
correcting_xt_fn=None,
|
347 |
+
thresholding_max_val=1.,
|
348 |
+
dynamic_thresholding_ratio=0.995,
|
349 |
+
):
|
350 |
+
"""Construct a DPM-Solver.
|
351 |
+
|
352 |
+
We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
|
353 |
+
|
354 |
+
We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
|
355 |
+
can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
|
356 |
+
dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
|
357 |
+
DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
|
358 |
+
DPMs (such as stable-diffusion).
|
359 |
+
|
360 |
+
To support advanced algorithms in image-to-image applications, we also support corrector functions for
|
361 |
+
both x0 and xt.
|
362 |
+
|
363 |
+
Args:
|
364 |
+
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
|
365 |
+
``
|
366 |
+
def model_fn(x, t_continuous):
|
367 |
+
return noise
|
368 |
+
``
|
369 |
+
The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
|
370 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
371 |
+
algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
|
372 |
+
correcting_x0_fn: A `str` or a function with the following format:
|
373 |
+
```
|
374 |
+
def correcting_x0_fn(x0, t):
|
375 |
+
x0_new = ...
|
376 |
+
return x0_new
|
377 |
+
```
|
378 |
+
This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
|
379 |
+
```
|
380 |
+
x0_pred = data_pred_model(xt, t)
|
381 |
+
if correcting_x0_fn is not None:
|
382 |
+
x0_pred = correcting_x0_fn(x0_pred, t)
|
383 |
+
xt_1 = update(x0_pred, xt, t)
|
384 |
+
```
|
385 |
+
If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
|
386 |
+
correcting_xt_fn: A function with the following format:
|
387 |
+
```
|
388 |
+
def correcting_xt_fn(xt, t, step):
|
389 |
+
x_new = ...
|
390 |
+
return x_new
|
391 |
+
```
|
392 |
+
This function is to correct the intermediate samples xt at each sampling step. e.g.,
|
393 |
+
```
|
394 |
+
xt = ...
|
395 |
+
xt = correcting_xt_fn(xt, t, step)
|
396 |
+
```
|
397 |
+
thresholding_max_val: A `float`. The max value for thresholding.
|
398 |
+
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
|
399 |
+
dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
|
400 |
+
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
|
401 |
+
|
402 |
+
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
|
403 |
+
Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
|
404 |
+
with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
|
405 |
+
"""
|
406 |
+
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
|
407 |
+
self.noise_schedule = noise_schedule
|
408 |
+
assert algorithm_type in ["dpmsolver", "dpmsolver++"]
|
409 |
+
self.algorithm_type = algorithm_type
|
410 |
+
if correcting_x0_fn == "dynamic_thresholding":
|
411 |
+
self.correcting_x0_fn = self.dynamic_thresholding_fn
|
412 |
+
else:
|
413 |
+
self.correcting_x0_fn = correcting_x0_fn
|
414 |
+
self.correcting_xt_fn = correcting_xt_fn
|
415 |
+
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
|
416 |
+
self.thresholding_max_val = thresholding_max_val
|
417 |
+
|
418 |
+
def dynamic_thresholding_fn(self, x0, t):
|
419 |
+
"""
|
420 |
+
The dynamic thresholding method.
|
421 |
+
"""
|
422 |
+
dims = x0.dim()
|
423 |
+
p = self.dynamic_thresholding_ratio
|
424 |
+
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
425 |
+
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
426 |
+
x0 = torch.clamp(x0, -s, s) / s
|
427 |
+
return x0
|
428 |
+
|
429 |
+
def noise_prediction_fn(self, x, t):
|
430 |
+
"""
|
431 |
+
Return the noise prediction model.
|
432 |
+
"""
|
433 |
+
return self.model(x, t)
|
434 |
+
|
435 |
+
def data_prediction_fn(self, x, t):
|
436 |
+
"""
|
437 |
+
Return the data prediction model (with corrector).
|
438 |
+
"""
|
439 |
+
noise = self.noise_prediction_fn(x, t)
|
440 |
+
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
441 |
+
x0 = (x - sigma_t * noise) / alpha_t
|
442 |
+
if self.correcting_x0_fn is not None:
|
443 |
+
x0 = self.correcting_x0_fn(x0, t)
|
444 |
+
return x0
|
445 |
+
|
446 |
+
def model_fn(self, x, t):
|
447 |
+
"""
|
448 |
+
Convert the model to the noise prediction model or the data prediction model.
|
449 |
+
"""
|
450 |
+
if self.algorithm_type == "dpmsolver++":
|
451 |
+
return self.data_prediction_fn(x, t)
|
452 |
+
else:
|
453 |
+
return self.noise_prediction_fn(x, t)
|
454 |
+
|
455 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
456 |
+
"""Compute the intermediate time steps for sampling.
|
457 |
+
|
458 |
+
Args:
|
459 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
460 |
+
- 'logSNR': uniform logSNR for the time steps.
|
461 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
|
462 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
|
463 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
464 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
465 |
+
N: A `int`. The total number of the spacing of the time steps.
|
466 |
+
device: A torch device.
|
467 |
+
Returns:
|
468 |
+
A pytorch tensor of the time steps, with the shape (N + 1,).
|
469 |
+
"""
|
470 |
+
if skip_type == 'logSNR':
|
471 |
+
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
472 |
+
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
473 |
+
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
474 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
475 |
+
elif skip_type == 'time_uniform':
|
476 |
+
return torch.linspace(t_T, t_0, N + 1).to(device)
|
477 |
+
elif skip_type == 'time_quadratic':
|
478 |
+
t_order = 2
|
479 |
+
t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
|
480 |
+
return t
|
481 |
+
else:
|
482 |
+
raise ValueError(
|
483 |
+
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
484 |
+
|
485 |
+
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
486 |
+
"""
|
487 |
+
Get the order of each step for sampling by the singlestep DPM-Solver.
|
488 |
+
|
489 |
+
We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
|
490 |
+
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
|
491 |
+
- If order == 1:
|
492 |
+
We take `steps` of DPM-Solver-1 (i.e. DDIM).
|
493 |
+
- If order == 2:
|
494 |
+
- Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
|
495 |
+
- If steps % 2 == 0, we use K steps of DPM-Solver-2.
|
496 |
+
- If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
|
497 |
+
- If order == 3:
|
498 |
+
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
|
499 |
+
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
|
500 |
+
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
|
501 |
+
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
|
502 |
+
|
503 |
+
============================================
|
504 |
+
Args:
|
505 |
+
order: A `int`. The max order for the solver (2 or 3).
|
506 |
+
steps: A `int`. The total number of function evaluations (NFE).
|
507 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
508 |
+
- 'logSNR': uniform logSNR for the time steps.
|
509 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
|
510 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
|
511 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
512 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
513 |
+
device: A torch device.
|
514 |
+
Returns:
|
515 |
+
orders: A list of the solver order of each step.
|
516 |
+
"""
|
517 |
+
if order == 3:
|
518 |
+
K = steps // 3 + 1
|
519 |
+
if steps % 3 == 0:
|
520 |
+
orders = [3, ] * (K - 2) + [2, 1]
|
521 |
+
elif steps % 3 == 1:
|
522 |
+
orders = [3, ] * (K - 1) + [1]
|
523 |
+
else:
|
524 |
+
orders = [3, ] * (K - 1) + [2]
|
525 |
+
elif order == 2:
|
526 |
+
if steps % 2 == 0:
|
527 |
+
K = steps // 2
|
528 |
+
orders = [2, ] * K
|
529 |
+
else:
|
530 |
+
K = steps // 2 + 1
|
531 |
+
orders = [2, ] * (K - 1) + [1]
|
532 |
+
elif order == 1:
|
533 |
+
K = 1
|
534 |
+
orders = [1, ] * steps
|
535 |
+
else:
|
536 |
+
raise ValueError("'order' must be '1' or '2' or '3'.")
|
537 |
+
if skip_type == 'logSNR':
|
538 |
+
# To reproduce the results in DPM-Solver paper
|
539 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
540 |
+
else:
|
541 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
|
542 |
+
torch.cumsum(torch.tensor([0, ] + orders), 0).to(device)]
|
543 |
+
return timesteps_outer, orders
|
544 |
+
|
545 |
+
def denoise_to_zero_fn(self, x, s):
|
546 |
+
"""
|
547 |
+
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
548 |
+
"""
|
549 |
+
return self.data_prediction_fn(x, s)
|
550 |
+
|
551 |
+
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
|
552 |
+
"""
|
553 |
+
DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
|
554 |
+
|
555 |
+
Args:
|
556 |
+
x: A pytorch tensor. The initial value at time `s`.
|
557 |
+
s: A pytorch tensor. The starting time, with the shape (1,).
|
558 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
559 |
+
model_s: A pytorch tensor. The model function evaluated at time `s`.
|
560 |
+
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
|
561 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s`.
|
562 |
+
Returns:
|
563 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
564 |
+
"""
|
565 |
+
ns = self.noise_schedule
|
566 |
+
dims = x.dim()
|
567 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
568 |
+
h = lambda_t - lambda_s
|
569 |
+
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
|
570 |
+
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
|
571 |
+
alpha_t = torch.exp(log_alpha_t)
|
572 |
+
|
573 |
+
if self.algorithm_type == "dpmsolver++":
|
574 |
+
phi_1 = torch.expm1(-h)
|
575 |
+
if model_s is None:
|
576 |
+
model_s = self.model_fn(x, s)
|
577 |
+
x_t = (
|
578 |
+
sigma_t / sigma_s * x
|
579 |
+
- alpha_t * phi_1 * model_s
|
580 |
+
)
|
581 |
+
if return_intermediate:
|
582 |
+
return x_t, {'model_s': model_s}
|
583 |
+
else:
|
584 |
+
return x_t
|
585 |
+
else:
|
586 |
+
phi_1 = torch.expm1(h)
|
587 |
+
if model_s is None:
|
588 |
+
model_s = self.model_fn(x, s)
|
589 |
+
x_t = (
|
590 |
+
torch.exp(log_alpha_t - log_alpha_s) * x
|
591 |
+
- (sigma_t * phi_1) * model_s
|
592 |
+
)
|
593 |
+
if return_intermediate:
|
594 |
+
return x_t, {'model_s': model_s}
|
595 |
+
else:
|
596 |
+
return x_t
|
597 |
+
|
598 |
+
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
|
599 |
+
solver_type='dpmsolver'):
|
600 |
+
"""
|
601 |
+
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
|
602 |
+
|
603 |
+
Args:
|
604 |
+
x: A pytorch tensor. The initial value at time `s`.
|
605 |
+
s: A pytorch tensor. The starting time, with the shape (1,).
|
606 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
607 |
+
r1: A `float`. The hyperparameter of the second-order solver.
|
608 |
+
model_s: A pytorch tensor. The model function evaluated at time `s`.
|
609 |
+
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
|
610 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
|
611 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
612 |
+
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
|
613 |
+
Returns:
|
614 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
615 |
+
"""
|
616 |
+
if solver_type not in ['dpmsolver', 'taylor']:
|
617 |
+
raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
|
618 |
+
if r1 is None:
|
619 |
+
r1 = 0.5
|
620 |
+
ns = self.noise_schedule
|
621 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
622 |
+
h = lambda_t - lambda_s
|
623 |
+
lambda_s1 = lambda_s + r1 * h
|
624 |
+
s1 = ns.inverse_lambda(lambda_s1)
|
625 |
+
log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
|
626 |
+
s1), ns.marginal_log_mean_coeff(t)
|
627 |
+
sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
|
628 |
+
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
|
629 |
+
|
630 |
+
if self.algorithm_type == "dpmsolver++":
|
631 |
+
phi_11 = torch.expm1(-r1 * h)
|
632 |
+
phi_1 = torch.expm1(-h)
|
633 |
+
|
634 |
+
if model_s is None:
|
635 |
+
model_s = self.model_fn(x, s)
|
636 |
+
x_s1 = (
|
637 |
+
(sigma_s1 / sigma_s) * x
|
638 |
+
- (alpha_s1 * phi_11) * model_s
|
639 |
+
)
|
640 |
+
model_s1 = self.model_fn(x_s1, s1)
|
641 |
+
if solver_type == 'dpmsolver':
|
642 |
+
x_t = (
|
643 |
+
(sigma_t / sigma_s) * x
|
644 |
+
- (alpha_t * phi_1) * model_s
|
645 |
+
- (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
|
646 |
+
)
|
647 |
+
elif solver_type == 'taylor':
|
648 |
+
x_t = (
|
649 |
+
(sigma_t / sigma_s) * x
|
650 |
+
- (alpha_t * phi_1) * model_s
|
651 |
+
+ (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s)
|
652 |
+
)
|
653 |
+
else:
|
654 |
+
phi_11 = torch.expm1(r1 * h)
|
655 |
+
phi_1 = torch.expm1(h)
|
656 |
+
|
657 |
+
if model_s is None:
|
658 |
+
model_s = self.model_fn(x, s)
|
659 |
+
x_s1 = (
|
660 |
+
torch.exp(log_alpha_s1 - log_alpha_s) * x
|
661 |
+
- (sigma_s1 * phi_11) * model_s
|
662 |
+
)
|
663 |
+
model_s1 = self.model_fn(x_s1, s1)
|
664 |
+
if solver_type == 'dpmsolver':
|
665 |
+
x_t = (
|
666 |
+
torch.exp(log_alpha_t - log_alpha_s) * x
|
667 |
+
- (sigma_t * phi_1) * model_s
|
668 |
+
- (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
|
669 |
+
)
|
670 |
+
elif solver_type == 'taylor':
|
671 |
+
x_t = (
|
672 |
+
torch.exp(log_alpha_t - log_alpha_s) * x
|
673 |
+
- (sigma_t * phi_1) * model_s
|
674 |
+
- (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)
|
675 |
+
)
|
676 |
+
if return_intermediate:
|
677 |
+
return x_t, {'model_s': model_s, 'model_s1': model_s1}
|
678 |
+
else:
|
679 |
+
return x_t
|
680 |
+
|
681 |
+
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
|
682 |
+
return_intermediate=False, solver_type='dpmsolver'):
|
683 |
+
"""
|
684 |
+
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
|
685 |
+
|
686 |
+
Args:
|
687 |
+
x: A pytorch tensor. The initial value at time `s`.
|
688 |
+
s: A pytorch tensor. The starting time, with the shape (1,).
|
689 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
690 |
+
r1: A `float`. The hyperparameter of the third-order solver.
|
691 |
+
r2: A `float`. The hyperparameter of the third-order solver.
|
692 |
+
model_s: A pytorch tensor. The model function evaluated at time `s`.
|
693 |
+
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
|
694 |
+
model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
|
695 |
+
If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
|
696 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
|
697 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
698 |
+
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
|
699 |
+
Returns:
|
700 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
701 |
+
"""
|
702 |
+
if solver_type not in ['dpmsolver', 'taylor']:
|
703 |
+
raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
|
704 |
+
if r1 is None:
|
705 |
+
r1 = 1. / 3.
|
706 |
+
if r2 is None:
|
707 |
+
r2 = 2. / 3.
|
708 |
+
ns = self.noise_schedule
|
709 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
710 |
+
h = lambda_t - lambda_s
|
711 |
+
lambda_s1 = lambda_s + r1 * h
|
712 |
+
lambda_s2 = lambda_s + r2 * h
|
713 |
+
s1 = ns.inverse_lambda(lambda_s1)
|
714 |
+
s2 = ns.inverse_lambda(lambda_s2)
|
715 |
+
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
|
716 |
+
s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
|
717 |
+
sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
|
718 |
+
s2), ns.marginal_std(t)
|
719 |
+
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
|
720 |
+
|
721 |
+
if self.algorithm_type == "dpmsolver++":
|
722 |
+
phi_11 = torch.expm1(-r1 * h)
|
723 |
+
phi_12 = torch.expm1(-r2 * h)
|
724 |
+
phi_1 = torch.expm1(-h)
|
725 |
+
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
|
726 |
+
phi_2 = phi_1 / h + 1.
|
727 |
+
phi_3 = phi_2 / h - 0.5
|
728 |
+
|
729 |
+
if model_s is None:
|
730 |
+
model_s = self.model_fn(x, s)
|
731 |
+
if model_s1 is None:
|
732 |
+
x_s1 = (
|
733 |
+
(sigma_s1 / sigma_s) * x
|
734 |
+
- (alpha_s1 * phi_11) * model_s
|
735 |
+
)
|
736 |
+
model_s1 = self.model_fn(x_s1, s1)
|
737 |
+
x_s2 = (
|
738 |
+
(sigma_s2 / sigma_s) * x
|
739 |
+
- (alpha_s2 * phi_12) * model_s
|
740 |
+
+ r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
|
741 |
+
)
|
742 |
+
model_s2 = self.model_fn(x_s2, s2)
|
743 |
+
if solver_type == 'dpmsolver':
|
744 |
+
x_t = (
|
745 |
+
(sigma_t / sigma_s) * x
|
746 |
+
- (alpha_t * phi_1) * model_s
|
747 |
+
+ (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
|
748 |
+
)
|
749 |
+
elif solver_type == 'taylor':
|
750 |
+
D1_0 = (1. / r1) * (model_s1 - model_s)
|
751 |
+
D1_1 = (1. / r2) * (model_s2 - model_s)
|
752 |
+
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
753 |
+
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
754 |
+
x_t = (
|
755 |
+
(sigma_t / sigma_s) * x
|
756 |
+
- (alpha_t * phi_1) * model_s
|
757 |
+
+ (alpha_t * phi_2) * D1
|
758 |
+
- (alpha_t * phi_3) * D2
|
759 |
+
)
|
760 |
+
else:
|
761 |
+
phi_11 = torch.expm1(r1 * h)
|
762 |
+
phi_12 = torch.expm1(r2 * h)
|
763 |
+
phi_1 = torch.expm1(h)
|
764 |
+
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
|
765 |
+
phi_2 = phi_1 / h - 1.
|
766 |
+
phi_3 = phi_2 / h - 0.5
|
767 |
+
|
768 |
+
if model_s is None:
|
769 |
+
model_s = self.model_fn(x, s)
|
770 |
+
if model_s1 is None:
|
771 |
+
x_s1 = (
|
772 |
+
(torch.exp(log_alpha_s1 - log_alpha_s)) * x
|
773 |
+
- (sigma_s1 * phi_11) * model_s
|
774 |
+
)
|
775 |
+
model_s1 = self.model_fn(x_s1, s1)
|
776 |
+
x_s2 = (
|
777 |
+
(torch.exp(log_alpha_s2 - log_alpha_s)) * x
|
778 |
+
- (sigma_s2 * phi_12) * model_s
|
779 |
+
- r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
|
780 |
+
)
|
781 |
+
model_s2 = self.model_fn(x_s2, s2)
|
782 |
+
if solver_type == 'dpmsolver':
|
783 |
+
x_t = (
|
784 |
+
(torch.exp(log_alpha_t - log_alpha_s)) * x
|
785 |
+
- (sigma_t * phi_1) * model_s
|
786 |
+
- (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
|
787 |
+
)
|
788 |
+
elif solver_type == 'taylor':
|
789 |
+
D1_0 = (1. / r1) * (model_s1 - model_s)
|
790 |
+
D1_1 = (1. / r2) * (model_s2 - model_s)
|
791 |
+
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
792 |
+
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
793 |
+
x_t = (
|
794 |
+
(torch.exp(log_alpha_t - log_alpha_s)) * x
|
795 |
+
- (sigma_t * phi_1) * model_s
|
796 |
+
- (sigma_t * phi_2) * D1
|
797 |
+
- (sigma_t * phi_3) * D2
|
798 |
+
)
|
799 |
+
|
800 |
+
if return_intermediate:
|
801 |
+
return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
|
802 |
+
else:
|
803 |
+
return x_t
|
804 |
+
|
805 |
+
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
|
806 |
+
"""
|
807 |
+
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
|
808 |
+
|
809 |
+
Args:
|
810 |
+
x: A pytorch tensor. The initial value at time `s`.
|
811 |
+
model_prev_list: A list of pytorch tensor. The previous computed model values.
|
812 |
+
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
|
813 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
814 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
815 |
+
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
|
816 |
+
Returns:
|
817 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
818 |
+
"""
|
819 |
+
if solver_type not in ['dpmsolver', 'taylor']:
|
820 |
+
raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
|
821 |
+
ns = self.noise_schedule
|
822 |
+
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
|
823 |
+
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
|
824 |
+
lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
|
825 |
+
t_prev_0), ns.marginal_lambda(t)
|
826 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
827 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
828 |
+
alpha_t = torch.exp(log_alpha_t)
|
829 |
+
|
830 |
+
h_0 = lambda_prev_0 - lambda_prev_1
|
831 |
+
h = lambda_t - lambda_prev_0
|
832 |
+
r0 = h_0 / h
|
833 |
+
D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
|
834 |
+
if self.algorithm_type == "dpmsolver++":
|
835 |
+
phi_1 = torch.expm1(-h)
|
836 |
+
if solver_type == 'dpmsolver':
|
837 |
+
x_t = (
|
838 |
+
(sigma_t / sigma_prev_0) * x
|
839 |
+
- (alpha_t * phi_1) * model_prev_0
|
840 |
+
- 0.5 * (alpha_t * phi_1) * D1_0
|
841 |
+
)
|
842 |
+
elif solver_type == 'taylor':
|
843 |
+
x_t = (
|
844 |
+
(sigma_t / sigma_prev_0) * x
|
845 |
+
- (alpha_t * phi_1) * model_prev_0
|
846 |
+
+ (alpha_t * (phi_1 / h + 1.)) * D1_0
|
847 |
+
)
|
848 |
+
else:
|
849 |
+
phi_1 = torch.expm1(h)
|
850 |
+
if solver_type == 'dpmsolver':
|
851 |
+
x_t = (
|
852 |
+
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
853 |
+
- (sigma_t * phi_1) * model_prev_0
|
854 |
+
- 0.5 * (sigma_t * phi_1) * D1_0
|
855 |
+
)
|
856 |
+
elif solver_type == 'taylor':
|
857 |
+
x_t = (
|
858 |
+
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
859 |
+
- (sigma_t * phi_1) * model_prev_0
|
860 |
+
- (sigma_t * (phi_1 / h - 1.)) * D1_0
|
861 |
+
)
|
862 |
+
return x_t
|
863 |
+
|
864 |
+
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
|
865 |
+
"""
|
866 |
+
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
|
867 |
+
|
868 |
+
Args:
|
869 |
+
x: A pytorch tensor. The initial value at time `s`.
|
870 |
+
model_prev_list: A list of pytorch tensor. The previous computed model values.
|
871 |
+
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
|
872 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
873 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
874 |
+
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
|
875 |
+
Returns:
|
876 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
877 |
+
"""
|
878 |
+
ns = self.noise_schedule
|
879 |
+
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
|
880 |
+
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
|
881 |
+
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
|
882 |
+
t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
|
883 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
884 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
885 |
+
alpha_t = torch.exp(log_alpha_t)
|
886 |
+
|
887 |
+
h_1 = lambda_prev_1 - lambda_prev_2
|
888 |
+
h_0 = lambda_prev_0 - lambda_prev_1
|
889 |
+
h = lambda_t - lambda_prev_0
|
890 |
+
r0, r1 = h_0 / h, h_1 / h
|
891 |
+
D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
|
892 |
+
D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
|
893 |
+
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
|
894 |
+
D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
|
895 |
+
if self.algorithm_type == "dpmsolver++":
|
896 |
+
phi_1 = torch.expm1(-h)
|
897 |
+
phi_2 = phi_1 / h + 1.
|
898 |
+
phi_3 = phi_2 / h - 0.5
|
899 |
+
x_t = (
|
900 |
+
(sigma_t / sigma_prev_0) * x
|
901 |
+
- (alpha_t * phi_1) * model_prev_0
|
902 |
+
+ (alpha_t * phi_2) * D1
|
903 |
+
- (alpha_t * phi_3) * D2
|
904 |
+
)
|
905 |
+
else:
|
906 |
+
phi_1 = torch.expm1(h)
|
907 |
+
phi_2 = phi_1 / h - 1.
|
908 |
+
phi_3 = phi_2 / h - 0.5
|
909 |
+
x_t = (
|
910 |
+
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
911 |
+
- (sigma_t * phi_1) * model_prev_0
|
912 |
+
- (sigma_t * phi_2) * D1
|
913 |
+
- (sigma_t * phi_3) * D2
|
914 |
+
)
|
915 |
+
return x_t
|
916 |
+
|
917 |
+
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None,
|
918 |
+
r2=None):
|
919 |
+
"""
|
920 |
+
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
|
921 |
+
|
922 |
+
Args:
|
923 |
+
x: A pytorch tensor. The initial value at time `s`.
|
924 |
+
s: A pytorch tensor. The starting time, with the shape (1,).
|
925 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
926 |
+
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
|
927 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
|
928 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
929 |
+
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
|
930 |
+
r1: A `float`. The hyperparameter of the second-order or third-order solver.
|
931 |
+
r2: A `float`. The hyperparameter of the third-order solver.
|
932 |
+
Returns:
|
933 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
934 |
+
"""
|
935 |
+
if order == 1:
|
936 |
+
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
|
937 |
+
elif order == 2:
|
938 |
+
return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
|
939 |
+
solver_type=solver_type, r1=r1)
|
940 |
+
elif order == 3:
|
941 |
+
return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
|
942 |
+
solver_type=solver_type, r1=r1, r2=r2)
|
943 |
+
else:
|
944 |
+
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
945 |
+
|
946 |
+
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
|
947 |
+
"""
|
948 |
+
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
|
949 |
+
|
950 |
+
Args:
|
951 |
+
x: A pytorch tensor. The initial value at time `s`.
|
952 |
+
model_prev_list: A list of pytorch tensor. The previous computed model values.
|
953 |
+
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
|
954 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
955 |
+
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
|
956 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
957 |
+
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
|
958 |
+
Returns:
|
959 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
960 |
+
"""
|
961 |
+
if order == 1:
|
962 |
+
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
|
963 |
+
elif order == 2:
|
964 |
+
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
|
965 |
+
elif order == 3:
|
966 |
+
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
|
967 |
+
else:
|
968 |
+
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
969 |
+
|
970 |
+
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
|
971 |
+
solver_type='dpmsolver'):
|
972 |
+
"""
|
973 |
+
The adaptive step size solver based on singlestep DPM-Solver.
|
974 |
+
|
975 |
+
Args:
|
976 |
+
x: A pytorch tensor. The initial value at time `t_T`.
|
977 |
+
order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
|
978 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
979 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
980 |
+
h_init: A `float`. The initial step size (for logSNR).
|
981 |
+
atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
|
982 |
+
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
|
983 |
+
theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
|
984 |
+
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
|
985 |
+
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
|
986 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
987 |
+
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
|
988 |
+
Returns:
|
989 |
+
x_0: A pytorch tensor. The approximated solution at time `t_0`.
|
990 |
+
|
991 |
+
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
|
992 |
+
"""
|
993 |
+
ns = self.noise_schedule
|
994 |
+
s = t_T * torch.ones((1,)).to(x)
|
995 |
+
lambda_s = ns.marginal_lambda(s)
|
996 |
+
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
|
997 |
+
h = h_init * torch.ones_like(s).to(x)
|
998 |
+
x_prev = x
|
999 |
+
nfe = 0
|
1000 |
+
if order == 2:
|
1001 |
+
r1 = 0.5
|
1002 |
+
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
|
1003 |
+
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
|
1004 |
+
solver_type=solver_type,
|
1005 |
+
**kwargs)
|
1006 |
+
elif order == 3:
|
1007 |
+
r1, r2 = 1. / 3., 2. / 3.
|
1008 |
+
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
|
1009 |
+
return_intermediate=True,
|
1010 |
+
solver_type=solver_type)
|
1011 |
+
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
|
1012 |
+
solver_type=solver_type,
|
1013 |
+
**kwargs)
|
1014 |
+
else:
|
1015 |
+
raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
|
1016 |
+
while torch.abs((s - t_0)).mean() > t_err:
|
1017 |
+
t = ns.inverse_lambda(lambda_s + h)
|
1018 |
+
x_lower, lower_noise_kwargs = lower_update(x, s, t)
|
1019 |
+
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
|
1020 |
+
delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
|
1021 |
+
norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
|
1022 |
+
E = norm_fn((x_higher - x_lower) / delta).max()
|
1023 |
+
if torch.all(E <= 1.):
|
1024 |
+
x = x_higher
|
1025 |
+
s = t
|
1026 |
+
x_prev = x_lower
|
1027 |
+
lambda_s = ns.marginal_lambda(s)
|
1028 |
+
h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
|
1029 |
+
nfe += order
|
1030 |
+
print('adaptive solver nfe', nfe)
|
1031 |
+
return x
|
1032 |
+
|
1033 |
+
def add_noise(self, x, t, noise=None):
|
1034 |
+
"""
|
1035 |
+
Compute the noised input xt = alpha_t * x + sigma_t * noise.
|
1036 |
+
|
1037 |
+
Args:
|
1038 |
+
x: A `torch.Tensor` with shape `(batch_size, *shape)`.
|
1039 |
+
t: A `torch.Tensor` with shape `(t_size,)`.
|
1040 |
+
Returns:
|
1041 |
+
xt with shape `(t_size, batch_size, *shape)`.
|
1042 |
+
"""
|
1043 |
+
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
1044 |
+
if noise is None:
|
1045 |
+
noise = torch.randn((t.shape[0], *x.shape), device=x.device)
|
1046 |
+
x = x.reshape((-1, *x.shape))
|
1047 |
+
xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
|
1048 |
+
if t.shape[0] == 1:
|
1049 |
+
return xt.squeeze(0)
|
1050 |
+
else:
|
1051 |
+
return xt
|
1052 |
+
|
1053 |
+
def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
|
1054 |
+
method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
|
1055 |
+
atol=0.0078, rtol=0.05, return_intermediate=False,
|
1056 |
+
):
|
1057 |
+
"""
|
1058 |
+
Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
|
1059 |
+
For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
|
1060 |
+
"""
|
1061 |
+
t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start
|
1062 |
+
t_T = self.noise_schedule.T if t_end is None else t_end
|
1063 |
+
assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
|
1064 |
+
return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type,
|
1065 |
+
method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero,
|
1066 |
+
solver_type=solver_type,
|
1067 |
+
atol=atol, rtol=rtol, return_intermediate=return_intermediate)
|
1068 |
+
|
1069 |
+
def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
|
1070 |
+
method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
|
1071 |
+
atol=0.0078, rtol=0.05, return_intermediate=False,
|
1072 |
+
):
|
1073 |
+
"""
|
1074 |
+
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
|
1075 |
+
|
1076 |
+
=====================================================
|
1077 |
+
|
1078 |
+
We support the following algorithms for both noise prediction model and data prediction model:
|
1079 |
+
- 'singlestep':
|
1080 |
+
Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
|
1081 |
+
We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
|
1082 |
+
The total number of function evaluations (NFE) == `steps`.
|
1083 |
+
Given a fixed NFE == `steps`, the sampling procedure is:
|
1084 |
+
- If `order` == 1:
|
1085 |
+
- Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
|
1086 |
+
- If `order` == 2:
|
1087 |
+
- Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
|
1088 |
+
- If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
|
1089 |
+
- If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
|
1090 |
+
- If `order` == 3:
|
1091 |
+
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
|
1092 |
+
- If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
|
1093 |
+
- If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
|
1094 |
+
- If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
|
1095 |
+
- 'multistep':
|
1096 |
+
Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
|
1097 |
+
We initialize the first `order` values by lower order multistep solvers.
|
1098 |
+
Given a fixed NFE == `steps`, the sampling procedure is:
|
1099 |
+
Denote K = steps.
|
1100 |
+
- If `order` == 1:
|
1101 |
+
- We use K steps of DPM-Solver-1 (i.e. DDIM).
|
1102 |
+
- If `order` == 2:
|
1103 |
+
- We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
|
1104 |
+
- If `order` == 3:
|
1105 |
+
- We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
|
1106 |
+
- 'singlestep_fixed':
|
1107 |
+
Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
|
1108 |
+
We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
|
1109 |
+
- 'adaptive':
|
1110 |
+
Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
|
1111 |
+
We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
|
1112 |
+
You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
|
1113 |
+
(NFE) and the sample quality.
|
1114 |
+
- If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
|
1115 |
+
- If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
|
1116 |
+
|
1117 |
+
=====================================================
|
1118 |
+
|
1119 |
+
Some advices for choosing the algorithm:
|
1120 |
+
- For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
|
1121 |
+
Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
|
1122 |
+
e.g., DPM-Solver:
|
1123 |
+
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
|
1124 |
+
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
|
1125 |
+
skip_type='time_uniform', method='singlestep')
|
1126 |
+
e.g., DPM-Solver++:
|
1127 |
+
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
|
1128 |
+
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
|
1129 |
+
skip_type='time_uniform', method='singlestep')
|
1130 |
+
- For **guided sampling with large guidance scale** by DPMs:
|
1131 |
+
Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
|
1132 |
+
e.g.
|
1133 |
+
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
|
1134 |
+
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
|
1135 |
+
skip_type='time_uniform', method='multistep')
|
1136 |
+
|
1137 |
+
We support three types of `skip_type`:
|
1138 |
+
- 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
|
1139 |
+
- 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
|
1140 |
+
- 'time_quadratic': quadratic time for the time steps.
|
1141 |
+
|
1142 |
+
=====================================================
|
1143 |
+
Args:
|
1144 |
+
x: A pytorch tensor. The initial value at time `t_start`
|
1145 |
+
e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
|
1146 |
+
steps: A `int`. The total number of function evaluations (NFE).
|
1147 |
+
t_start: A `float`. The starting time of the sampling.
|
1148 |
+
If `T` is None, we use self.noise_schedule.T (default is 1.0).
|
1149 |
+
t_end: A `float`. The ending time of the sampling.
|
1150 |
+
If `t_end` is None, we use 1. / self.noise_schedule.total_N.
|
1151 |
+
e.g. if total_N == 1000, we have `t_end` == 1e-3.
|
1152 |
+
For discrete-time DPMs:
|
1153 |
+
- We recommend `t_end` == 1. / self.noise_schedule.total_N.
|
1154 |
+
For continuous-time DPMs:
|
1155 |
+
- We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
|
1156 |
+
order: A `int`. The order of DPM-Solver.
|
1157 |
+
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
|
1158 |
+
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
|
1159 |
+
denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
|
1160 |
+
Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
|
1161 |
+
|
1162 |
+
This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
|
1163 |
+
score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
|
1164 |
+
for diffusion models sampling by diffusion SDEs for low-resolutional images
|
1165 |
+
(such as CIFAR-10). However, we observed that such trick does not matter for
|
1166 |
+
high-resolutional images. As it needs an additional NFE, we do not recommend
|
1167 |
+
it for high-resolutional images.
|
1168 |
+
lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
|
1169 |
+
Only valid for `method=multistep` and `steps < 15`. We empirically find that
|
1170 |
+
this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
|
1171 |
+
(especially for steps <= 10). So we recommend to set it to be `True`.
|
1172 |
+
solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
|
1173 |
+
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
|
1174 |
+
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
|
1175 |
+
return_intermediate: A `bool`. Whether to save the xt at each step.
|
1176 |
+
When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
|
1177 |
+
Returns:
|
1178 |
+
x_end: A pytorch tensor. The approximated solution at time `t_end`.
|
1179 |
+
|
1180 |
+
"""
|
1181 |
+
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
1182 |
+
t_T = self.noise_schedule.T if t_start is None else t_start
|
1183 |
+
assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
|
1184 |
+
if return_intermediate:
|
1185 |
+
assert method in ['multistep', 'singlestep',
|
1186 |
+
'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
|
1187 |
+
if self.correcting_xt_fn is not None:
|
1188 |
+
assert method in ['multistep', 'singlestep',
|
1189 |
+
'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
|
1190 |
+
device = x.device
|
1191 |
+
intermediates = []
|
1192 |
+
with torch.no_grad():
|
1193 |
+
if method == 'adaptive':
|
1194 |
+
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
|
1195 |
+
solver_type=solver_type)
|
1196 |
+
elif method == 'multistep':
|
1197 |
+
assert steps >= order
|
1198 |
+
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
1199 |
+
assert timesteps.shape[0] - 1 == steps
|
1200 |
+
# Init the initial values.
|
1201 |
+
step = 0
|
1202 |
+
t = timesteps[step]
|
1203 |
+
t_prev_list = [t]
|
1204 |
+
model_prev_list = [self.model_fn(x, t)]
|
1205 |
+
if self.correcting_xt_fn is not None:
|
1206 |
+
x = self.correcting_xt_fn(x, t, step)
|
1207 |
+
if return_intermediate:
|
1208 |
+
intermediates.append(x)
|
1209 |
+
# Init the first `order` values by lower order multistep DPM-Solver.
|
1210 |
+
for step in range(1, order):
|
1211 |
+
t = timesteps[step]
|
1212 |
+
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step,
|
1213 |
+
solver_type=solver_type)
|
1214 |
+
if self.correcting_xt_fn is not None:
|
1215 |
+
x = self.correcting_xt_fn(x, t, step)
|
1216 |
+
if return_intermediate:
|
1217 |
+
intermediates.append(x)
|
1218 |
+
t_prev_list.append(t)
|
1219 |
+
model_prev_list.append(self.model_fn(x, t))
|
1220 |
+
# Compute the remaining values by `order`-th order multistep DPM-Solver.
|
1221 |
+
for step in tqdm(range(order, steps + 1)):
|
1222 |
+
t = timesteps[step]
|
1223 |
+
# We only use lower order for steps < 10
|
1224 |
+
# if lower_order_final and steps < 10:
|
1225 |
+
if lower_order_final: # recommended by Shuchen Xue
|
1226 |
+
step_order = min(order, steps + 1 - step)
|
1227 |
+
else:
|
1228 |
+
step_order = order
|
1229 |
+
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order,
|
1230 |
+
solver_type=solver_type)
|
1231 |
+
if self.correcting_xt_fn is not None:
|
1232 |
+
x = self.correcting_xt_fn(x, t, step)
|
1233 |
+
if return_intermediate:
|
1234 |
+
intermediates.append(x)
|
1235 |
+
for i in range(order - 1):
|
1236 |
+
t_prev_list[i] = t_prev_list[i + 1]
|
1237 |
+
model_prev_list[i] = model_prev_list[i + 1]
|
1238 |
+
t_prev_list[-1] = t
|
1239 |
+
# We do not need to evaluate the final model value.
|
1240 |
+
if step < steps:
|
1241 |
+
model_prev_list[-1] = self.model_fn(x, t)
|
1242 |
+
elif method in ['singlestep', 'singlestep_fixed']:
|
1243 |
+
if method == 'singlestep':
|
1244 |
+
timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps,
|
1245 |
+
order=order,
|
1246 |
+
skip_type=skip_type,
|
1247 |
+
t_T=t_T, t_0=t_0,
|
1248 |
+
device=device)
|
1249 |
+
elif method == 'singlestep_fixed':
|
1250 |
+
K = steps // order
|
1251 |
+
orders = [order, ] * K
|
1252 |
+
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
|
1253 |
+
for step, order in enumerate(orders):
|
1254 |
+
s, t = timesteps_outer[step], timesteps_outer[step + 1]
|
1255 |
+
timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order,
|
1256 |
+
device=device)
|
1257 |
+
lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
|
1258 |
+
h = lambda_inner[-1] - lambda_inner[0]
|
1259 |
+
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
|
1260 |
+
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
|
1261 |
+
x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
|
1262 |
+
if self.correcting_xt_fn is not None:
|
1263 |
+
x = self.correcting_xt_fn(x, t, step)
|
1264 |
+
if return_intermediate:
|
1265 |
+
intermediates.append(x)
|
1266 |
+
else:
|
1267 |
+
raise ValueError("Got wrong method {}".format(method))
|
1268 |
+
if denoise_to_zero:
|
1269 |
+
t = torch.ones((1,)).to(device) * t_0
|
1270 |
+
x = self.denoise_to_zero_fn(x, t)
|
1271 |
+
if self.correcting_xt_fn is not None:
|
1272 |
+
x = self.correcting_xt_fn(x, t, step + 1)
|
1273 |
+
if return_intermediate:
|
1274 |
+
intermediates.append(x)
|
1275 |
+
if return_intermediate:
|
1276 |
+
return x, intermediates
|
1277 |
+
else:
|
1278 |
+
return x
|
1279 |
+
|
1280 |
+
|
1281 |
+
#############################################################
|
1282 |
+
# other utility functions
|
1283 |
+
#############################################################
|
1284 |
+
|
1285 |
+
def interpolate_fn(x, xp, yp):
|
1286 |
+
"""
|
1287 |
+
A piecewise linear function y = f(x), using xp and yp as keypoints.
|
1288 |
+
We implement f(x) in a differentiable way (i.e. applicable for autograd).
|
1289 |
+
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
|
1290 |
+
|
1291 |
+
Args:
|
1292 |
+
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
|
1293 |
+
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
|
1294 |
+
yp: PyTorch tensor with shape [C, K].
|
1295 |
+
Returns:
|
1296 |
+
The function values f(x), with shape [N, C].
|
1297 |
+
"""
|
1298 |
+
N, K = x.shape[0], xp.shape[1]
|
1299 |
+
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
|
1300 |
+
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
|
1301 |
+
x_idx = torch.argmin(x_indices, dim=2)
|
1302 |
+
cand_start_idx = x_idx - 1
|
1303 |
+
start_idx = torch.where(
|
1304 |
+
torch.eq(x_idx, 0),
|
1305 |
+
torch.tensor(1, device=x.device),
|
1306 |
+
torch.where(
|
1307 |
+
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
1308 |
+
),
|
1309 |
+
)
|
1310 |
+
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
1311 |
+
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
1312 |
+
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
1313 |
+
start_idx2 = torch.where(
|
1314 |
+
torch.eq(x_idx, 0),
|
1315 |
+
torch.tensor(0, device=x.device),
|
1316 |
+
torch.where(
|
1317 |
+
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
1318 |
+
),
|
1319 |
+
)
|
1320 |
+
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
1321 |
+
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
1322 |
+
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
1323 |
+
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
1324 |
+
return cand
|
1325 |
+
|
1326 |
+
|
1327 |
+
def expand_dims(v, dims):
|
1328 |
+
"""
|
1329 |
+
Expand the tensor `v` to the dim `dims`.
|
1330 |
+
|
1331 |
+
Args:
|
1332 |
+
`v`: a PyTorch tensor with shape [N].
|
1333 |
+
`dim`: a `int`.
|
1334 |
+
Returns:
|
1335 |
+
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
1336 |
+
"""
|
1337 |
+
return v[(...,) + (None,) * (dims - 1)]
|
diffusion/model/edm_sample.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
from diffusion.model.utils import *
|
6 |
+
|
7 |
+
|
8 |
+
# ----------------------------------------------------------------------------
|
9 |
+
# Proposed EDM sampler (Algorithm 2).
|
10 |
+
|
11 |
+
def edm_sampler(
|
12 |
+
net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like,
|
13 |
+
num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
|
14 |
+
S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs
|
15 |
+
):
|
16 |
+
# Adjust noise levels based on what's supported by the network.
|
17 |
+
sigma_min = max(sigma_min, net.sigma_min)
|
18 |
+
sigma_max = min(sigma_max, net.sigma_max)
|
19 |
+
|
20 |
+
# Time step discretization.
|
21 |
+
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
|
22 |
+
t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
|
23 |
+
sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
|
24 |
+
t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
|
25 |
+
|
26 |
+
# Main sampling loop.
|
27 |
+
x_next = latents.to(torch.float64) * t_steps[0]
|
28 |
+
for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1
|
29 |
+
x_cur = x_next
|
30 |
+
|
31 |
+
# Increase noise temporarily.
|
32 |
+
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
|
33 |
+
t_hat = net.round_sigma(t_cur + gamma * t_cur)
|
34 |
+
x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
|
35 |
+
|
36 |
+
# Euler step.
|
37 |
+
denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
|
38 |
+
d_cur = (x_hat - denoised) / t_hat
|
39 |
+
x_next = x_hat + (t_next - t_hat) * d_cur
|
40 |
+
|
41 |
+
# Apply 2nd order correction.
|
42 |
+
if i < num_steps - 1:
|
43 |
+
denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
|
44 |
+
d_prime = (x_next - denoised) / t_next
|
45 |
+
x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
|
46 |
+
|
47 |
+
return x_next
|
48 |
+
|
49 |
+
|
50 |
+
# ----------------------------------------------------------------------------
|
51 |
+
# Generalized ablation sampler, representing the superset of all sampling
|
52 |
+
# methods discussed in the paper.
|
53 |
+
|
54 |
+
def ablation_sampler(
|
55 |
+
net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
|
56 |
+
num_steps=18, sigma_min=None, sigma_max=None, rho=7,
|
57 |
+
solver='heun', discretization='edm', schedule='linear', scaling='none',
|
58 |
+
epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
|
59 |
+
S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
|
60 |
+
):
|
61 |
+
assert solver in ['euler', 'heun']
|
62 |
+
assert discretization in ['vp', 've', 'iddpm', 'edm']
|
63 |
+
assert schedule in ['vp', 've', 'linear']
|
64 |
+
assert scaling in ['vp', 'none']
|
65 |
+
|
66 |
+
# Helper functions for VP & VE noise level schedules.
|
67 |
+
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
|
68 |
+
vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
|
69 |
+
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
|
70 |
+
sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
|
71 |
+
ve_sigma = lambda t: t.sqrt()
|
72 |
+
ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
|
73 |
+
ve_sigma_inv = lambda sigma: sigma ** 2
|
74 |
+
|
75 |
+
# Select default noise level range based on the specified time step discretization.
|
76 |
+
if sigma_min is None:
|
77 |
+
vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
|
78 |
+
sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
|
79 |
+
if sigma_max is None:
|
80 |
+
vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
|
81 |
+
sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
|
82 |
+
|
83 |
+
# Adjust noise levels based on what's supported by the network.
|
84 |
+
sigma_min = max(sigma_min, net.sigma_min)
|
85 |
+
sigma_max = min(sigma_max, net.sigma_max)
|
86 |
+
|
87 |
+
# Compute corresponding betas for VP.
|
88 |
+
vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
|
89 |
+
vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
|
90 |
+
|
91 |
+
# Define time steps in terms of noise level.
|
92 |
+
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
|
93 |
+
if discretization == 'vp':
|
94 |
+
orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
|
95 |
+
sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
|
96 |
+
elif discretization == 've':
|
97 |
+
orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
|
98 |
+
sigma_steps = ve_sigma(orig_t_steps)
|
99 |
+
elif discretization == 'iddpm':
|
100 |
+
u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
|
101 |
+
alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
|
102 |
+
for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
|
103 |
+
u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
|
104 |
+
u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
|
105 |
+
sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
|
106 |
+
else:
|
107 |
+
assert discretization == 'edm'
|
108 |
+
sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
|
109 |
+
sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
|
110 |
+
|
111 |
+
# Define noise level schedule.
|
112 |
+
if schedule == 'vp':
|
113 |
+
sigma = vp_sigma(vp_beta_d, vp_beta_min)
|
114 |
+
sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
|
115 |
+
sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
|
116 |
+
elif schedule == 've':
|
117 |
+
sigma = ve_sigma
|
118 |
+
sigma_deriv = ve_sigma_deriv
|
119 |
+
sigma_inv = ve_sigma_inv
|
120 |
+
else:
|
121 |
+
assert schedule == 'linear'
|
122 |
+
sigma = lambda t: t
|
123 |
+
sigma_deriv = lambda t: 1
|
124 |
+
sigma_inv = lambda sigma: sigma
|
125 |
+
|
126 |
+
# Define scaling schedule.
|
127 |
+
if scaling == 'vp':
|
128 |
+
s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
|
129 |
+
s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
|
130 |
+
else:
|
131 |
+
assert scaling == 'none'
|
132 |
+
s = lambda t: 1
|
133 |
+
s_deriv = lambda t: 0
|
134 |
+
|
135 |
+
# Compute final time steps based on the corresponding noise levels.
|
136 |
+
t_steps = sigma_inv(net.round_sigma(sigma_steps))
|
137 |
+
t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
|
138 |
+
|
139 |
+
# Main sampling loop.
|
140 |
+
t_next = t_steps[0]
|
141 |
+
x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
|
142 |
+
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
|
143 |
+
x_cur = x_next
|
144 |
+
|
145 |
+
# Increase noise temporarily.
|
146 |
+
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
|
147 |
+
t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
|
148 |
+
x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
|
149 |
+
t_hat) * S_noise * randn_like(x_cur)
|
150 |
+
|
151 |
+
# Euler step.
|
152 |
+
h = t_next - t_hat
|
153 |
+
denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(
|
154 |
+
torch.float64)
|
155 |
+
d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
|
156 |
+
t_hat) / sigma(t_hat) * denoised
|
157 |
+
x_prime = x_hat + alpha * h * d_cur
|
158 |
+
t_prime = t_hat + alpha * h
|
159 |
+
|
160 |
+
# Apply 2nd order correction.
|
161 |
+
if solver == 'euler' or i == num_steps - 1:
|
162 |
+
x_next = x_hat + h * d_cur
|
163 |
+
else:
|
164 |
+
assert solver == 'heun'
|
165 |
+
denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(
|
166 |
+
torch.float64)
|
167 |
+
d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
|
168 |
+
t_prime) * s(t_prime) / sigma(t_prime) * denoised
|
169 |
+
x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
|
170 |
+
|
171 |
+
return x_next
|
diffusion/model/gaussian_diffusion.py
ADDED
@@ -0,0 +1,1041 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
|
7 |
+
import enum
|
8 |
+
import math
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch as th
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
15 |
+
|
16 |
+
|
17 |
+
def mean_flat(tensor):
|
18 |
+
"""
|
19 |
+
Take the mean over all non-batch dimensions.
|
20 |
+
"""
|
21 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
22 |
+
|
23 |
+
|
24 |
+
class ModelMeanType(enum.Enum):
|
25 |
+
"""
|
26 |
+
Which type of output the model predicts.
|
27 |
+
"""
|
28 |
+
|
29 |
+
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
30 |
+
START_X = enum.auto() # the model predicts x_0
|
31 |
+
EPSILON = enum.auto() # the model predicts epsilon
|
32 |
+
|
33 |
+
|
34 |
+
class ModelVarType(enum.Enum):
|
35 |
+
"""
|
36 |
+
What is used as the model's output variance.
|
37 |
+
The LEARNED_RANGE option has been added to allow the model to predict
|
38 |
+
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
39 |
+
"""
|
40 |
+
|
41 |
+
LEARNED = enum.auto()
|
42 |
+
FIXED_SMALL = enum.auto()
|
43 |
+
FIXED_LARGE = enum.auto()
|
44 |
+
LEARNED_RANGE = enum.auto()
|
45 |
+
|
46 |
+
|
47 |
+
class LossType(enum.Enum):
|
48 |
+
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
49 |
+
RESCALED_MSE = (
|
50 |
+
enum.auto()
|
51 |
+
) # use raw MSE loss (with RESCALED_KL when learning variances)
|
52 |
+
KL = enum.auto() # use the variational lower-bound
|
53 |
+
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
54 |
+
|
55 |
+
def is_vb(self):
|
56 |
+
return self == LossType.KL or self == LossType.RESCALED_KL
|
57 |
+
|
58 |
+
|
59 |
+
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
60 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
61 |
+
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
62 |
+
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
63 |
+
return betas
|
64 |
+
|
65 |
+
|
66 |
+
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
67 |
+
"""
|
68 |
+
This is the deprecated API for creating beta schedules.
|
69 |
+
See get_named_beta_schedule() for the new library of schedules.
|
70 |
+
"""
|
71 |
+
if beta_schedule == "quad":
|
72 |
+
betas = (
|
73 |
+
np.linspace(
|
74 |
+
beta_start ** 0.5,
|
75 |
+
beta_end ** 0.5,
|
76 |
+
num_diffusion_timesteps,
|
77 |
+
dtype=np.float64,
|
78 |
+
)
|
79 |
+
** 2
|
80 |
+
)
|
81 |
+
elif beta_schedule == "linear":
|
82 |
+
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
83 |
+
elif beta_schedule == "warmup10":
|
84 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
85 |
+
elif beta_schedule == "warmup50":
|
86 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
87 |
+
elif beta_schedule == "const":
|
88 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
89 |
+
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
90 |
+
betas = 1.0 / np.linspace(
|
91 |
+
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
raise NotImplementedError(beta_schedule)
|
95 |
+
assert betas.shape == (num_diffusion_timesteps,)
|
96 |
+
return betas
|
97 |
+
|
98 |
+
|
99 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
100 |
+
"""
|
101 |
+
Get a pre-defined beta schedule for the given name.
|
102 |
+
The beta schedule library consists of beta schedules which remain similar
|
103 |
+
in the limit of num_diffusion_timesteps.
|
104 |
+
Beta schedules may be added, but should not be removed or changed once
|
105 |
+
they are committed to maintain backwards compatibility.
|
106 |
+
"""
|
107 |
+
if schedule_name == "linear":
|
108 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
109 |
+
# diffusion steps.
|
110 |
+
scale = 1000 / num_diffusion_timesteps
|
111 |
+
return get_beta_schedule(
|
112 |
+
"linear",
|
113 |
+
beta_start=scale * 0.0001,
|
114 |
+
beta_end=scale * 0.02,
|
115 |
+
num_diffusion_timesteps=num_diffusion_timesteps,
|
116 |
+
)
|
117 |
+
elif schedule_name == "squaredcos_cap_v2":
|
118 |
+
return betas_for_alpha_bar(
|
119 |
+
num_diffusion_timesteps,
|
120 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
124 |
+
|
125 |
+
|
126 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
127 |
+
"""
|
128 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
129 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
130 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
131 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
132 |
+
produces the cumulative product of (1-beta) up to that
|
133 |
+
part of the diffusion process.
|
134 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
135 |
+
prevent singularities.
|
136 |
+
"""
|
137 |
+
betas = []
|
138 |
+
for i in range(num_diffusion_timesteps):
|
139 |
+
t1 = i / num_diffusion_timesteps
|
140 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
141 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
142 |
+
return np.array(betas)
|
143 |
+
|
144 |
+
|
145 |
+
class GaussianDiffusion:
|
146 |
+
"""
|
147 |
+
Utilities for training and sampling diffusion models.
|
148 |
+
Original ported from this codebase:
|
149 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
150 |
+
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
151 |
+
starting at T and going to 1.
|
152 |
+
"""
|
153 |
+
|
154 |
+
def __init__(
|
155 |
+
self,
|
156 |
+
*,
|
157 |
+
betas,
|
158 |
+
model_mean_type,
|
159 |
+
model_var_type,
|
160 |
+
loss_type,
|
161 |
+
snr=False,
|
162 |
+
return_startx=False,
|
163 |
+
):
|
164 |
+
|
165 |
+
self.model_mean_type = model_mean_type
|
166 |
+
self.model_var_type = model_var_type
|
167 |
+
self.loss_type = loss_type
|
168 |
+
self.snr = snr
|
169 |
+
self.return_startx = return_startx
|
170 |
+
|
171 |
+
# Use float64 for accuracy.
|
172 |
+
betas = np.array(betas, dtype=np.float64)
|
173 |
+
self.betas = betas
|
174 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
175 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
176 |
+
|
177 |
+
self.num_timesteps = int(betas.shape[0])
|
178 |
+
|
179 |
+
alphas = 1.0 - betas
|
180 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
181 |
+
|
182 |
+
if False:
|
183 |
+
target_resolution = 128 # 1024:128; 512:64; 256:32;
|
184 |
+
reference_resolution = 64 # Reference resolution (e.g., 64x64)
|
185 |
+
scaling_factor = (target_resolution / reference_resolution) ** 2
|
186 |
+
print('scaling_factor', scaling_factor)
|
187 |
+
|
188 |
+
# Adjust alphas and betas according to the scaling factor
|
189 |
+
alpha_cumprod_snr_shift = self.alphas_cumprod / (scaling_factor * (1 - self.alphas_cumprod) + self.alphas_cumprod)
|
190 |
+
alpha_cuspord_rmove1 = np.concatenate([np.ones([1]), alpha_cumprod_snr_shift[:999]])
|
191 |
+
alpha_snr_shift = alpha_cumprod_snr_shift / alpha_cuspord_rmove1
|
192 |
+
|
193 |
+
betas_snr_shift = 1 - alpha_snr_shift
|
194 |
+
|
195 |
+
# Update the class attributes with adjusted values
|
196 |
+
snr_ref = (self.alphas_cumprod / (1 - self.alphas_cumprod))
|
197 |
+
snr_cur = (alpha_cumprod_snr_shift / (1 - alpha_cumprod_snr_shift))
|
198 |
+
|
199 |
+
self.betas = betas_snr_shift
|
200 |
+
self.alphas_cumprod = np.cumprod(alpha_snr_shift, axis=0)
|
201 |
+
|
202 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
203 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
204 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
205 |
+
|
206 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
207 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
208 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
209 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
210 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
211 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
212 |
+
|
213 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
214 |
+
self.posterior_variance = (
|
215 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
216 |
+
)
|
217 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
218 |
+
self.posterior_log_variance_clipped = np.log(
|
219 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
220 |
+
) if len(self.posterior_variance) > 1 else np.array([])
|
221 |
+
|
222 |
+
self.posterior_mean_coef1 = (
|
223 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
224 |
+
)
|
225 |
+
self.posterior_mean_coef2 = (
|
226 |
+
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
227 |
+
)
|
228 |
+
|
229 |
+
def q_mean_variance(self, x_start, t):
|
230 |
+
"""
|
231 |
+
Get the distribution q(x_t | x_0).
|
232 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
233 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
234 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
235 |
+
"""
|
236 |
+
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
237 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
238 |
+
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
239 |
+
return mean, variance, log_variance
|
240 |
+
|
241 |
+
def q_sample(self, x_start, t, noise=None):
|
242 |
+
"""
|
243 |
+
Diffuse the data for a given number of diffusion steps.
|
244 |
+
In other words, sample from q(x_t | x_0).
|
245 |
+
:param x_start: the initial data batch.
|
246 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
247 |
+
:param noise: if specified, the split-out normal noise.
|
248 |
+
:return: A noisy version of x_start.
|
249 |
+
"""
|
250 |
+
if noise is None:
|
251 |
+
noise = th.randn_like(x_start)
|
252 |
+
assert noise.shape == x_start.shape
|
253 |
+
return (
|
254 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
255 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
256 |
+
)
|
257 |
+
|
258 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
259 |
+
"""
|
260 |
+
Compute the mean and variance of the diffusion posterior:
|
261 |
+
q(x_{t-1} | x_t, x_0)
|
262 |
+
"""
|
263 |
+
assert x_start.shape == x_t.shape
|
264 |
+
posterior_mean = (
|
265 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
266 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
267 |
+
)
|
268 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
269 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
270 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
271 |
+
)
|
272 |
+
assert (
|
273 |
+
posterior_mean.shape[0]
|
274 |
+
== posterior_variance.shape[0]
|
275 |
+
== posterior_log_variance_clipped.shape[0]
|
276 |
+
== x_start.shape[0]
|
277 |
+
)
|
278 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
279 |
+
|
280 |
+
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
|
281 |
+
"""
|
282 |
+
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
283 |
+
the initial x, x_0.
|
284 |
+
:param model: the model, which takes a signal and a batch of timesteps
|
285 |
+
as input.
|
286 |
+
:param x: the [N x C x ...] tensor at time t.
|
287 |
+
:param t: a 1-D Tensor of timesteps.
|
288 |
+
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
289 |
+
:param denoised_fn: if not None, a function which applies to the
|
290 |
+
x_start prediction before it is used to sample. Applies before
|
291 |
+
clip_denoised.
|
292 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
293 |
+
pass to the model. This can be used for conditioning.
|
294 |
+
:return: a dict with the following keys:
|
295 |
+
- 'mean': the model mean output.
|
296 |
+
- 'variance': the model variance output.
|
297 |
+
- 'log_variance': the log of 'variance'.
|
298 |
+
- 'pred_xstart': the prediction for x_0.
|
299 |
+
"""
|
300 |
+
if model_kwargs is None:
|
301 |
+
model_kwargs = {}
|
302 |
+
|
303 |
+
B, C = x.shape[:2]
|
304 |
+
assert t.shape == (B,)
|
305 |
+
model_output = model(x, t, **model_kwargs)
|
306 |
+
if isinstance(model_output, tuple):
|
307 |
+
model_output, extra = model_output
|
308 |
+
else:
|
309 |
+
extra = None
|
310 |
+
|
311 |
+
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
312 |
+
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
313 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
314 |
+
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
|
315 |
+
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
316 |
+
# The model_var_values is [-1, 1] for [min_var, max_var].
|
317 |
+
frac = (model_var_values + 1) / 2
|
318 |
+
model_log_variance = frac * max_log + (1 - frac) * min_log
|
319 |
+
model_variance = th.exp(model_log_variance)
|
320 |
+
elif self.model_var_type in [ModelVarType.FIXED_LARGE, ModelVarType.FIXED_SMALL]:
|
321 |
+
model_variance, model_log_variance = {
|
322 |
+
# for fixedlarge, we set the initial (log-)variance like so
|
323 |
+
# to get a better decoder log likelihood.
|
324 |
+
ModelVarType.FIXED_LARGE: (
|
325 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
326 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
327 |
+
),
|
328 |
+
ModelVarType.FIXED_SMALL: (
|
329 |
+
self.posterior_variance,
|
330 |
+
self.posterior_log_variance_clipped,
|
331 |
+
),
|
332 |
+
}[self.model_var_type]
|
333 |
+
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
334 |
+
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
335 |
+
else:
|
336 |
+
model_variance = th.zeros_like(model_output)
|
337 |
+
model_log_variance = th.zeros_like(model_output)
|
338 |
+
|
339 |
+
def process_xstart(x):
|
340 |
+
if denoised_fn is not None:
|
341 |
+
x = denoised_fn(x)
|
342 |
+
if clip_denoised:
|
343 |
+
return x.clamp(-1, 1)
|
344 |
+
return x
|
345 |
+
|
346 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
347 |
+
pred_xstart = process_xstart(model_output)
|
348 |
+
else:
|
349 |
+
pred_xstart = process_xstart(
|
350 |
+
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
351 |
+
)
|
352 |
+
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
353 |
+
|
354 |
+
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
355 |
+
return {
|
356 |
+
"mean": model_mean,
|
357 |
+
"variance": model_variance,
|
358 |
+
"log_variance": model_log_variance,
|
359 |
+
"pred_xstart": pred_xstart,
|
360 |
+
"extra": extra,
|
361 |
+
}
|
362 |
+
|
363 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
364 |
+
assert x_t.shape == eps.shape
|
365 |
+
return (
|
366 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
367 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
368 |
+
)
|
369 |
+
|
370 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
371 |
+
return (
|
372 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
373 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
374 |
+
|
375 |
+
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
376 |
+
"""
|
377 |
+
Compute the mean for the previous step, given a function cond_fn that
|
378 |
+
computes the gradient of a conditional log probability with respect to
|
379 |
+
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
380 |
+
condition on y.
|
381 |
+
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
382 |
+
"""
|
383 |
+
gradient = cond_fn(x, t, **model_kwargs)
|
384 |
+
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
385 |
+
return new_mean
|
386 |
+
|
387 |
+
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
388 |
+
"""
|
389 |
+
Compute what the p_mean_variance output would have been, should the
|
390 |
+
model's score function be conditioned by cond_fn.
|
391 |
+
See condition_mean() for details on cond_fn.
|
392 |
+
Unlike condition_mean(), this instead uses the conditioning strategy
|
393 |
+
from Song et al (2020).
|
394 |
+
"""
|
395 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
396 |
+
|
397 |
+
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
398 |
+
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
|
399 |
+
|
400 |
+
out = p_mean_var.copy()
|
401 |
+
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
402 |
+
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
|
403 |
+
return out
|
404 |
+
|
405 |
+
def p_sample(
|
406 |
+
self,
|
407 |
+
model,
|
408 |
+
x,
|
409 |
+
t,
|
410 |
+
clip_denoised=True,
|
411 |
+
denoised_fn=None,
|
412 |
+
cond_fn=None,
|
413 |
+
model_kwargs=None,
|
414 |
+
):
|
415 |
+
"""
|
416 |
+
Sample x_{t-1} from the model at the given timestep.
|
417 |
+
:param model: the model to sample from.
|
418 |
+
:param x: the current tensor at x_{t-1}.
|
419 |
+
:param t: the value of t, starting at 0 for the first diffusion step.
|
420 |
+
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
421 |
+
:param denoised_fn: if not None, a function which applies to the
|
422 |
+
x_start prediction before it is used to sample.
|
423 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
424 |
+
similarly to the model.
|
425 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
426 |
+
pass to the model. This can be used for conditioning.
|
427 |
+
:return: a dict containing the following keys:
|
428 |
+
- 'sample': a random sample from the model.
|
429 |
+
- 'pred_xstart': a prediction of x_0.
|
430 |
+
"""
|
431 |
+
out = self.p_mean_variance(
|
432 |
+
model,
|
433 |
+
x,
|
434 |
+
t,
|
435 |
+
clip_denoised=clip_denoised,
|
436 |
+
denoised_fn=denoised_fn,
|
437 |
+
model_kwargs=model_kwargs,
|
438 |
+
)
|
439 |
+
noise = th.randn_like(x)
|
440 |
+
nonzero_mask = (
|
441 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
442 |
+
) # no noise when t == 0
|
443 |
+
if cond_fn is not None:
|
444 |
+
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
445 |
+
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
446 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
447 |
+
|
448 |
+
def p_sample_loop(
|
449 |
+
self,
|
450 |
+
model,
|
451 |
+
shape,
|
452 |
+
noise=None,
|
453 |
+
clip_denoised=True,
|
454 |
+
denoised_fn=None,
|
455 |
+
cond_fn=None,
|
456 |
+
model_kwargs=None,
|
457 |
+
device=None,
|
458 |
+
progress=False,
|
459 |
+
):
|
460 |
+
"""
|
461 |
+
Generate samples from the model.
|
462 |
+
:param model: the model module.
|
463 |
+
:param shape: the shape of the samples, (N, C, H, W).
|
464 |
+
:param noise: if specified, the noise from the encoder to sample.
|
465 |
+
Should be of the same shape as `shape`.
|
466 |
+
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
467 |
+
:param denoised_fn: if not None, a function which applies to the
|
468 |
+
x_start prediction before it is used to sample.
|
469 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
470 |
+
similarly to the model.
|
471 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
472 |
+
pass to the model. This can be used for conditioning.
|
473 |
+
:param device: if specified, the device to create the samples on.
|
474 |
+
If not specified, use a model parameter's device.
|
475 |
+
:param progress: if True, show a tqdm progress bar.
|
476 |
+
:return: a non-differentiable batch of samples.
|
477 |
+
"""
|
478 |
+
final = None
|
479 |
+
for sample in self.p_sample_loop_progressive(
|
480 |
+
model,
|
481 |
+
shape,
|
482 |
+
noise=noise,
|
483 |
+
clip_denoised=clip_denoised,
|
484 |
+
denoised_fn=denoised_fn,
|
485 |
+
cond_fn=cond_fn,
|
486 |
+
model_kwargs=model_kwargs,
|
487 |
+
device=device,
|
488 |
+
progress=progress,
|
489 |
+
):
|
490 |
+
final = sample
|
491 |
+
return final["sample"]
|
492 |
+
|
493 |
+
def p_sample_loop_progressive(
|
494 |
+
self,
|
495 |
+
model,
|
496 |
+
shape,
|
497 |
+
noise=None,
|
498 |
+
clip_denoised=True,
|
499 |
+
denoised_fn=None,
|
500 |
+
cond_fn=None,
|
501 |
+
model_kwargs=None,
|
502 |
+
device=None,
|
503 |
+
progress=False,
|
504 |
+
):
|
505 |
+
"""
|
506 |
+
Generate samples from the model and yield intermediate samples from
|
507 |
+
each timestep of diffusion.
|
508 |
+
Arguments are the same as p_sample_loop().
|
509 |
+
Returns a generator over dicts, where each dict is the return value of
|
510 |
+
p_sample().
|
511 |
+
"""
|
512 |
+
if device is None:
|
513 |
+
device = next(model.parameters()).device
|
514 |
+
assert isinstance(shape, (tuple, list))
|
515 |
+
if noise is not None:
|
516 |
+
img = noise
|
517 |
+
else:
|
518 |
+
img = th.randn(*shape, device=device)
|
519 |
+
indices = list(range(self.num_timesteps))[::-1]
|
520 |
+
|
521 |
+
if progress:
|
522 |
+
# Lazy import so that we don't depend on tqdm.
|
523 |
+
from tqdm.auto import tqdm
|
524 |
+
|
525 |
+
indices = tqdm(indices)
|
526 |
+
|
527 |
+
for i in indices:
|
528 |
+
t = th.tensor([i] * shape[0], device=device)
|
529 |
+
with th.no_grad():
|
530 |
+
out = self.p_sample(
|
531 |
+
model,
|
532 |
+
img,
|
533 |
+
t,
|
534 |
+
clip_denoised=clip_denoised,
|
535 |
+
denoised_fn=denoised_fn,
|
536 |
+
cond_fn=cond_fn,
|
537 |
+
model_kwargs=model_kwargs,
|
538 |
+
)
|
539 |
+
yield out
|
540 |
+
img = out["sample"]
|
541 |
+
|
542 |
+
def ddim_sample(
|
543 |
+
self,
|
544 |
+
model,
|
545 |
+
x,
|
546 |
+
t,
|
547 |
+
clip_denoised=True,
|
548 |
+
denoised_fn=None,
|
549 |
+
cond_fn=None,
|
550 |
+
model_kwargs=None,
|
551 |
+
eta=0.0,
|
552 |
+
):
|
553 |
+
"""
|
554 |
+
Sample x_{t-1} from the model using DDIM.
|
555 |
+
Same usage as p_sample().
|
556 |
+
"""
|
557 |
+
out = self.p_mean_variance(
|
558 |
+
model,
|
559 |
+
x,
|
560 |
+
t,
|
561 |
+
clip_denoised=clip_denoised,
|
562 |
+
denoised_fn=denoised_fn,
|
563 |
+
model_kwargs=model_kwargs,
|
564 |
+
)
|
565 |
+
if cond_fn is not None:
|
566 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
567 |
+
|
568 |
+
# Usually our model outputs epsilon, but we re-derive it
|
569 |
+
# in case we used x_start or x_prev prediction.
|
570 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
571 |
+
|
572 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
573 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
574 |
+
sigma = (
|
575 |
+
eta
|
576 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
577 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
578 |
+
)
|
579 |
+
# Equation 12.
|
580 |
+
noise = th.randn_like(x)
|
581 |
+
mean_pred = (
|
582 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
583 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
584 |
+
)
|
585 |
+
nonzero_mask = (
|
586 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
587 |
+
) # no noise when t == 0
|
588 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
589 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
590 |
+
|
591 |
+
def ddim_reverse_sample(
|
592 |
+
self,
|
593 |
+
model,
|
594 |
+
x,
|
595 |
+
t,
|
596 |
+
clip_denoised=True,
|
597 |
+
denoised_fn=None,
|
598 |
+
cond_fn=None,
|
599 |
+
model_kwargs=None,
|
600 |
+
eta=0.0,
|
601 |
+
):
|
602 |
+
"""
|
603 |
+
Sample x_{t+1} from the model using DDIM reverse ODE.
|
604 |
+
"""
|
605 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
606 |
+
out = self.p_mean_variance(
|
607 |
+
model,
|
608 |
+
x,
|
609 |
+
t,
|
610 |
+
clip_denoised=clip_denoised,
|
611 |
+
denoised_fn=denoised_fn,
|
612 |
+
model_kwargs=model_kwargs,
|
613 |
+
)
|
614 |
+
if cond_fn is not None:
|
615 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
616 |
+
# Usually our model outputs epsilon, but we re-derive it
|
617 |
+
# in case we used x_start or x_prev prediction.
|
618 |
+
eps = (
|
619 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
620 |
+
- out["pred_xstart"]
|
621 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
622 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
623 |
+
|
624 |
+
# Equation 12. reversed
|
625 |
+
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
|
626 |
+
|
627 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
628 |
+
|
629 |
+
def ddim_sample_loop(
|
630 |
+
self,
|
631 |
+
model,
|
632 |
+
shape,
|
633 |
+
noise=None,
|
634 |
+
clip_denoised=True,
|
635 |
+
denoised_fn=None,
|
636 |
+
cond_fn=None,
|
637 |
+
model_kwargs=None,
|
638 |
+
device=None,
|
639 |
+
progress=False,
|
640 |
+
eta=0.0,
|
641 |
+
):
|
642 |
+
"""
|
643 |
+
Generate samples from the model using DDIM.
|
644 |
+
Same usage as p_sample_loop().
|
645 |
+
"""
|
646 |
+
final = None
|
647 |
+
for sample in self.ddim_sample_loop_progressive(
|
648 |
+
model,
|
649 |
+
shape,
|
650 |
+
noise=noise,
|
651 |
+
clip_denoised=clip_denoised,
|
652 |
+
denoised_fn=denoised_fn,
|
653 |
+
cond_fn=cond_fn,
|
654 |
+
model_kwargs=model_kwargs,
|
655 |
+
device=device,
|
656 |
+
progress=progress,
|
657 |
+
eta=eta,
|
658 |
+
):
|
659 |
+
final = sample
|
660 |
+
return final["sample"]
|
661 |
+
|
662 |
+
def ddim_sample_loop_progressive(
|
663 |
+
self,
|
664 |
+
model,
|
665 |
+
shape,
|
666 |
+
noise=None,
|
667 |
+
clip_denoised=True,
|
668 |
+
denoised_fn=None,
|
669 |
+
cond_fn=None,
|
670 |
+
model_kwargs=None,
|
671 |
+
device=None,
|
672 |
+
progress=False,
|
673 |
+
eta=0.0,
|
674 |
+
):
|
675 |
+
"""
|
676 |
+
Use DDIM to sample from the model and yield intermediate samples from
|
677 |
+
each timestep of DDIM.
|
678 |
+
Same usage as p_sample_loop_progressive().
|
679 |
+
"""
|
680 |
+
if device is None:
|
681 |
+
device = next(model.parameters()).device
|
682 |
+
assert isinstance(shape, (tuple, list))
|
683 |
+
if noise is not None:
|
684 |
+
img = noise
|
685 |
+
else:
|
686 |
+
img = th.randn(*shape, device=device)
|
687 |
+
indices = list(range(self.num_timesteps))[::-1]
|
688 |
+
|
689 |
+
if progress:
|
690 |
+
# Lazy import so that we don't depend on tqdm.
|
691 |
+
from tqdm.auto import tqdm
|
692 |
+
|
693 |
+
indices = tqdm(indices)
|
694 |
+
|
695 |
+
for i in indices:
|
696 |
+
t = th.tensor([i] * shape[0], device=device)
|
697 |
+
with th.no_grad():
|
698 |
+
out = self.ddim_sample(
|
699 |
+
model,
|
700 |
+
img,
|
701 |
+
t,
|
702 |
+
clip_denoised=clip_denoised,
|
703 |
+
denoised_fn=denoised_fn,
|
704 |
+
cond_fn=cond_fn,
|
705 |
+
model_kwargs=model_kwargs,
|
706 |
+
eta=eta,
|
707 |
+
)
|
708 |
+
yield out
|
709 |
+
img = out["sample"]
|
710 |
+
|
711 |
+
def _vb_terms_bpd(
|
712 |
+
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
|
713 |
+
):
|
714 |
+
"""
|
715 |
+
Get a term for the variational lower-bound.
|
716 |
+
The resulting units are bits (rather than nats, as one might expect).
|
717 |
+
This allows for comparison to other papers.
|
718 |
+
:return: a dict with the following keys:
|
719 |
+
- 'output': a shape [N] tensor of NLLs or KLs.
|
720 |
+
- 'pred_xstart': the x_0 predictions.
|
721 |
+
"""
|
722 |
+
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
723 |
+
x_start=x_start, x_t=x_t, t=t
|
724 |
+
)
|
725 |
+
out = self.p_mean_variance(
|
726 |
+
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
|
727 |
+
)
|
728 |
+
kl = normal_kl(
|
729 |
+
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
730 |
+
)
|
731 |
+
kl = mean_flat(kl) / np.log(2.0)
|
732 |
+
|
733 |
+
decoder_nll = -discretized_gaussian_log_likelihood(
|
734 |
+
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
735 |
+
)
|
736 |
+
assert decoder_nll.shape == x_start.shape
|
737 |
+
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
738 |
+
|
739 |
+
# At the first timestep return the decoder NLL,
|
740 |
+
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
741 |
+
output = th.where((t == 0), decoder_nll, kl)
|
742 |
+
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
743 |
+
|
744 |
+
def training_losses(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
|
745 |
+
"""
|
746 |
+
Compute training losses for a single timestep.
|
747 |
+
:param model: the model to evaluate loss on.
|
748 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
749 |
+
:param t: a batch of timestep indices.
|
750 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
751 |
+
pass to the model. This can be used for conditioning.
|
752 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
753 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
754 |
+
Some mean or variance settings may also have other keys.
|
755 |
+
"""
|
756 |
+
t = timestep
|
757 |
+
if model_kwargs is None:
|
758 |
+
model_kwargs = {}
|
759 |
+
if skip_noise:
|
760 |
+
x_t = x_start
|
761 |
+
else:
|
762 |
+
if noise is None:
|
763 |
+
noise = th.randn_like(x_start)
|
764 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
765 |
+
|
766 |
+
terms = {}
|
767 |
+
|
768 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
769 |
+
terms["loss"] = self._vb_terms_bpd(
|
770 |
+
model=model,
|
771 |
+
x_start=x_start,
|
772 |
+
x_t=x_t,
|
773 |
+
t=t,
|
774 |
+
clip_denoised=False,
|
775 |
+
model_kwargs=model_kwargs,
|
776 |
+
)["output"]
|
777 |
+
if self.loss_type == LossType.RESCALED_KL:
|
778 |
+
terms["loss"] *= self.num_timesteps
|
779 |
+
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
780 |
+
model_output = model(x_t, t, **model_kwargs)
|
781 |
+
if isinstance(model_output, dict) and model_output.get('x', None) is not None:
|
782 |
+
output = model_output['x']
|
783 |
+
else:
|
784 |
+
output = model_output
|
785 |
+
|
786 |
+
if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
|
787 |
+
B, C = x_t.shape[:2]
|
788 |
+
assert output.shape == (B, C * 2, *x_t.shape[2:])
|
789 |
+
output = th.split(output, C, dim=1)[0]
|
790 |
+
return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t
|
791 |
+
|
792 |
+
if self.model_var_type in [
|
793 |
+
ModelVarType.LEARNED,
|
794 |
+
ModelVarType.LEARNED_RANGE,
|
795 |
+
]:
|
796 |
+
B, C = x_t.shape[:2]
|
797 |
+
assert output.shape == (B, C * 2, *x_t.shape[2:])
|
798 |
+
output, model_var_values = th.split(output, C, dim=1)
|
799 |
+
# Learn the variance using the variational bound, but don't let it affect our mean prediction.
|
800 |
+
frozen_out = th.cat([output.detach(), model_var_values], dim=1)
|
801 |
+
terms["vb"] = self._vb_terms_bpd(
|
802 |
+
model=lambda *args, r=frozen_out, **kwargs: r,
|
803 |
+
x_start=x_start,
|
804 |
+
x_t=x_t,
|
805 |
+
t=t,
|
806 |
+
clip_denoised=False,
|
807 |
+
)["output"]
|
808 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
809 |
+
# Divide by 1000 for equivalence with initial implementation.
|
810 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
811 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
812 |
+
|
813 |
+
target = {
|
814 |
+
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
|
815 |
+
x_start=x_start, x_t=x_t, t=t
|
816 |
+
)[0],
|
817 |
+
ModelMeanType.START_X: x_start,
|
818 |
+
ModelMeanType.EPSILON: noise,
|
819 |
+
}[self.model_mean_type]
|
820 |
+
assert output.shape == target.shape == x_start.shape
|
821 |
+
if self.snr:
|
822 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
823 |
+
pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
|
824 |
+
pred_startx = output
|
825 |
+
elif self.model_mean_type == ModelMeanType.EPSILON:
|
826 |
+
pred_noise = output
|
827 |
+
pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
|
828 |
+
# terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
|
829 |
+
# terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)
|
830 |
+
|
831 |
+
t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32]
|
832 |
+
# best
|
833 |
+
target = th.where(t > 249, noise, x_start)
|
834 |
+
output = th.where(t > 249, pred_noise, pred_startx)
|
835 |
+
loss = (target - output) ** 2
|
836 |
+
if model_kwargs.get('mask_ratio', False) and model_kwargs['mask_ratio'] > 0:
|
837 |
+
assert 'mask' in model_output
|
838 |
+
loss = F.avg_pool2d(loss.mean(dim=1), model.model.module.patch_size).flatten(1)
|
839 |
+
mask = model_output['mask']
|
840 |
+
unmask = 1 - mask
|
841 |
+
terms['mse'] = mean_flat(loss * unmask) * unmask.shape[1]/unmask.sum(1)
|
842 |
+
if model_kwargs['mask_loss_coef'] > 0:
|
843 |
+
terms['mae'] = model_kwargs['mask_loss_coef'] * mean_flat(loss * mask) * mask.shape[1]/mask.sum(1)
|
844 |
+
else:
|
845 |
+
terms["mse"] = mean_flat(loss)
|
846 |
+
if "vb" in terms:
|
847 |
+
terms["loss"] = terms["mse"] + terms["vb"]
|
848 |
+
else:
|
849 |
+
terms["loss"] = terms["mse"]
|
850 |
+
if "mae" in terms:
|
851 |
+
terms["loss"] = terms["loss"] + terms["mae"]
|
852 |
+
else:
|
853 |
+
raise NotImplementedError(self.loss_type)
|
854 |
+
|
855 |
+
return terms
|
856 |
+
|
857 |
+
def training_losses_diffusers(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
|
858 |
+
"""
|
859 |
+
Compute training losses for a single timestep.
|
860 |
+
:param model: the model to evaluate loss on.
|
861 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
862 |
+
:param t: a batch of timestep indices.
|
863 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
864 |
+
pass to the model. This can be used for conditioning.
|
865 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
866 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
867 |
+
Some mean or variance settings may also have other keys.
|
868 |
+
"""
|
869 |
+
t = timestep
|
870 |
+
if model_kwargs is None:
|
871 |
+
model_kwargs = {}
|
872 |
+
if skip_noise:
|
873 |
+
x_t = x_start
|
874 |
+
else:
|
875 |
+
if noise is None:
|
876 |
+
noise = th.randn_like(x_start)
|
877 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
878 |
+
|
879 |
+
terms = {}
|
880 |
+
|
881 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
882 |
+
terms["loss"] = self._vb_terms_bpd(
|
883 |
+
model=model,
|
884 |
+
x_start=x_start,
|
885 |
+
x_t=x_t,
|
886 |
+
t=t,
|
887 |
+
clip_denoised=False,
|
888 |
+
model_kwargs=model_kwargs,
|
889 |
+
)["output"]
|
890 |
+
if self.loss_type == LossType.RESCALED_KL:
|
891 |
+
terms["loss"] *= self.num_timesteps
|
892 |
+
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
893 |
+
output = model(x_t, timestep=t, **model_kwargs, return_dict=False)[0]
|
894 |
+
|
895 |
+
if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
|
896 |
+
B, C = x_t.shape[:2]
|
897 |
+
assert output.shape == (B, C * 2, *x_t.shape[2:])
|
898 |
+
output = th.split(output, C, dim=1)[0]
|
899 |
+
return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t
|
900 |
+
|
901 |
+
if self.model_var_type in [
|
902 |
+
ModelVarType.LEARNED,
|
903 |
+
ModelVarType.LEARNED_RANGE,
|
904 |
+
]:
|
905 |
+
B, C = x_t.shape[:2]
|
906 |
+
assert output.shape == (B, C * 2, *x_t.shape[2:])
|
907 |
+
output, model_var_values = th.split(output, C, dim=1)
|
908 |
+
# Learn the variance using the variational bound, but don't let it affect our mean prediction.
|
909 |
+
frozen_out = th.cat([output.detach(), model_var_values], dim=1)
|
910 |
+
terms["vb"] = self._vb_terms_bpd(
|
911 |
+
model=lambda *args, r=frozen_out, **kwargs: r,
|
912 |
+
x_start=x_start,
|
913 |
+
x_t=x_t,
|
914 |
+
t=t,
|
915 |
+
clip_denoised=False,
|
916 |
+
)["output"]
|
917 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
918 |
+
# Divide by 1000 for equivalence with initial implementation.
|
919 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
920 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
921 |
+
|
922 |
+
target = {
|
923 |
+
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
|
924 |
+
x_start=x_start, x_t=x_t, t=t
|
925 |
+
)[0],
|
926 |
+
ModelMeanType.START_X: x_start,
|
927 |
+
ModelMeanType.EPSILON: noise,
|
928 |
+
}[self.model_mean_type]
|
929 |
+
assert output.shape == target.shape == x_start.shape
|
930 |
+
if self.snr:
|
931 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
932 |
+
pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
|
933 |
+
pred_startx = output
|
934 |
+
elif self.model_mean_type == ModelMeanType.EPSILON:
|
935 |
+
pred_noise = output
|
936 |
+
pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
|
937 |
+
# terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
|
938 |
+
# terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)
|
939 |
+
|
940 |
+
t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32]
|
941 |
+
# best
|
942 |
+
target = th.where(t > 249, noise, x_start)
|
943 |
+
output = th.where(t > 249, pred_noise, pred_startx)
|
944 |
+
loss = (target - output) ** 2
|
945 |
+
terms["mse"] = mean_flat(loss)
|
946 |
+
if "vb" in terms:
|
947 |
+
terms["loss"] = terms["mse"] + terms["vb"]
|
948 |
+
else:
|
949 |
+
terms["loss"] = terms["mse"]
|
950 |
+
if "mae" in terms:
|
951 |
+
terms["loss"] = terms["loss"] + terms["mae"]
|
952 |
+
else:
|
953 |
+
raise NotImplementedError(self.loss_type)
|
954 |
+
|
955 |
+
return terms
|
956 |
+
|
957 |
+
def _prior_bpd(self, x_start):
|
958 |
+
"""
|
959 |
+
Get the prior KL term for the variational lower-bound, measured in
|
960 |
+
bits-per-dim.
|
961 |
+
This term can't be optimized, as it only depends on the encoder.
|
962 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
963 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
964 |
+
"""
|
965 |
+
batch_size = x_start.shape[0]
|
966 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
967 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
968 |
+
kl_prior = normal_kl(
|
969 |
+
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
970 |
+
)
|
971 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
972 |
+
|
973 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
974 |
+
"""
|
975 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
976 |
+
as well as other related quantities.
|
977 |
+
:param model: the model to evaluate loss on.
|
978 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
979 |
+
:param clip_denoised: if True, clip denoised samples.
|
980 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
981 |
+
pass to the model. This can be used for conditioning.
|
982 |
+
:return: a dict containing the following keys:
|
983 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
984 |
+
- prior_bpd: the prior term in the lower-bound.
|
985 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
986 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
987 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
988 |
+
"""
|
989 |
+
device = x_start.device
|
990 |
+
batch_size = x_start.shape[0]
|
991 |
+
|
992 |
+
vb = []
|
993 |
+
xstart_mse = []
|
994 |
+
mse = []
|
995 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
996 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
997 |
+
noise = th.randn_like(x_start)
|
998 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
999 |
+
# Calculate VLB term at the current timestep
|
1000 |
+
with th.no_grad():
|
1001 |
+
out = self._vb_terms_bpd(
|
1002 |
+
model,
|
1003 |
+
x_start=x_start,
|
1004 |
+
x_t=x_t,
|
1005 |
+
t=t_batch,
|
1006 |
+
clip_denoised=clip_denoised,
|
1007 |
+
model_kwargs=model_kwargs,
|
1008 |
+
)
|
1009 |
+
vb.append(out["output"])
|
1010 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
1011 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
1012 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
1013 |
+
|
1014 |
+
vb = th.stack(vb, dim=1)
|
1015 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
1016 |
+
mse = th.stack(mse, dim=1)
|
1017 |
+
|
1018 |
+
prior_bpd = self._prior_bpd(x_start)
|
1019 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
1020 |
+
return {
|
1021 |
+
"total_bpd": total_bpd,
|
1022 |
+
"prior_bpd": prior_bpd,
|
1023 |
+
"vb": vb,
|
1024 |
+
"xstart_mse": xstart_mse,
|
1025 |
+
"mse": mse,
|
1026 |
+
}
|
1027 |
+
|
1028 |
+
|
1029 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
1030 |
+
"""
|
1031 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
1032 |
+
:param arr: the 1-D numpy array.
|
1033 |
+
:param timesteps: a tensor of indices into the array to extract.
|
1034 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
1035 |
+
dimension equal to the length of timesteps.
|
1036 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
1037 |
+
"""
|
1038 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
1039 |
+
while len(res.shape) < len(broadcast_shape):
|
1040 |
+
res = res[..., None]
|
1041 |
+
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
diffusion/model/llava/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from diffusion.model.llava.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
|
diffusion/model/llava/llava_mpt.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
import warnings
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
from torch.nn import CrossEntropyLoss
|
23 |
+
|
24 |
+
import math
|
25 |
+
|
26 |
+
from transformers import AutoConfig, AutoModelForCausalLM, CLIPVisionModel, CLIPImageProcessor
|
27 |
+
|
28 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
29 |
+
|
30 |
+
from diffusion.model.llava.mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
|
31 |
+
|
32 |
+
|
33 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
34 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
35 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
36 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
37 |
+
|
38 |
+
|
39 |
+
class LlavaMPTConfig(MPTConfig):
|
40 |
+
model_type = "llava_mpt"
|
41 |
+
|
42 |
+
|
43 |
+
class LlavaMPTModel(MPTModel):
|
44 |
+
config_class = LlavaMPTConfig
|
45 |
+
|
46 |
+
def __init__(self, config: MPTConfig, mm_vision_tower=None, mm_hidden_size=None):
|
47 |
+
super(LlavaMPTModel, self).__init__(config)
|
48 |
+
|
49 |
+
if hasattr(config, "mm_vision_tower"):
|
50 |
+
# HACK: for FSDP
|
51 |
+
self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
|
52 |
+
# self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)
|
53 |
+
|
54 |
+
if hasattr(config, "use_mm_proj"):
|
55 |
+
self.mm_projector = nn.Linear(config.mm_hidden_size, config.d_model)
|
56 |
+
|
57 |
+
def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
|
58 |
+
pretrain_mm_mlp_adapter=None, tune_mm_mlp_adapter=False):
|
59 |
+
self.config.mm_vision_tower = vision_tower
|
60 |
+
|
61 |
+
image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
|
62 |
+
|
63 |
+
if not hasattr(self, 'vision_tower'):
|
64 |
+
vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
|
65 |
+
else:
|
66 |
+
vision_tower = self.vision_tower[0]
|
67 |
+
vision_tower.requires_grad_(False)
|
68 |
+
vision_tower = vision_tower.to(torch.float16)
|
69 |
+
self.vision_tower = [vision_tower]
|
70 |
+
|
71 |
+
vision_config = vision_tower.config
|
72 |
+
num_patches = (vision_config.image_size // vision_config.patch_size) ** 2
|
73 |
+
|
74 |
+
self.config.use_mm_proj = True
|
75 |
+
self.config.mm_hidden_size = vision_config.hidden_size
|
76 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
77 |
+
|
78 |
+
if not hasattr(self, 'mm_projector'):
|
79 |
+
self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.d_model)
|
80 |
+
|
81 |
+
if pretrain_mm_mlp_adapter is not None:
|
82 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
83 |
+
self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items() if 'mm_projector' in k})
|
84 |
+
|
85 |
+
return dict(
|
86 |
+
image_processor=image_processor,
|
87 |
+
image_token_len=num_patches,
|
88 |
+
vision_config=vision_config
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
|
92 |
+
|
93 |
+
# HACK: replace back original embeddings for LLaVA pretraining
|
94 |
+
orig_embeds_params = getattr(self, 'orig_embeds_params', None)
|
95 |
+
# if orig_embeds_params is not None:
|
96 |
+
# orig_embeds_params = orig_embeds_params[0]
|
97 |
+
# with torch.no_grad():
|
98 |
+
# self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data
|
99 |
+
|
100 |
+
inputs_embeds = self.wte(input_ids)
|
101 |
+
|
102 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
103 |
+
if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
|
104 |
+
# TODO: this is a modified multimodal LLM -- Haotian Liu
|
105 |
+
vision_tower = vision_tower[0] # HACK: for FSDP
|
106 |
+
with torch.no_grad():
|
107 |
+
if type(images) is list:
|
108 |
+
# variable length images
|
109 |
+
image_features = []
|
110 |
+
for image in images:
|
111 |
+
image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
|
112 |
+
select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
|
113 |
+
select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
|
114 |
+
image_feature = select_hidden_state[:, 1:]
|
115 |
+
image_features.append(image_feature)
|
116 |
+
else:
|
117 |
+
image_forward_outs = vision_tower(images, output_hidden_states=True)
|
118 |
+
select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
|
119 |
+
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
|
120 |
+
image_features = select_hidden_state[:, 1:]
|
121 |
+
if type(images) is list:
|
122 |
+
image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features]
|
123 |
+
else:
|
124 |
+
image_features = self.mm_projector(image_features)
|
125 |
+
dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
|
126 |
+
dummy_image_features = self.mm_projector(dummy_image_features)
|
127 |
+
|
128 |
+
new_input_embeds = []
|
129 |
+
cur_image_idx = 0
|
130 |
+
for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
|
131 |
+
if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
|
132 |
+
# multimodal LLM, but the current sample is not multimodal
|
133 |
+
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
|
134 |
+
new_input_embeds.append(cur_input_embeds)
|
135 |
+
continue
|
136 |
+
if vision_tower.config.use_im_start_end:
|
137 |
+
cur_image_features = image_features[cur_image_idx]
|
138 |
+
num_patches = cur_image_features.shape[0]
|
139 |
+
if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
|
140 |
+
raise ValueError("The number of image start tokens and image end tokens should be the same.")
|
141 |
+
image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0]
|
142 |
+
for image_start_token_pos in image_start_tokens:
|
143 |
+
cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
|
144 |
+
num_patches = cur_image_features.shape[0]
|
145 |
+
if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
|
146 |
+
raise ValueError("The image end token should follow the image start token.")
|
147 |
+
if orig_embeds_params is not None:
|
148 |
+
cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
|
149 |
+
else:
|
150 |
+
cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
|
151 |
+
cur_image_idx += 1
|
152 |
+
new_input_embeds.append(cur_new_input_embeds)
|
153 |
+
else:
|
154 |
+
cur_image_features = image_features[cur_image_idx]
|
155 |
+
num_patches = cur_image_features.shape[0]
|
156 |
+
if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
|
157 |
+
raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
|
158 |
+
masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0]
|
159 |
+
mask_index_start = masked_indices[0]
|
160 |
+
if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
|
161 |
+
raise ValueError("The image patch tokens should be consecutive.")
|
162 |
+
if orig_embeds_params is not None:
|
163 |
+
cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
|
164 |
+
else:
|
165 |
+
cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
|
166 |
+
new_input_embeds.append(cur_new_input_embeds)
|
167 |
+
inputs_embeds = torch.stack(new_input_embeds, dim=0)
|
168 |
+
|
169 |
+
return super(LlavaMPTModel, self).forward(input_ids=None, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, tok_emb=inputs_embeds)
|
170 |
+
|
171 |
+
|
172 |
+
class LlavaMPTForCausalLM(MPTForCausalLM):
|
173 |
+
config_class = LlavaMPTConfig
|
174 |
+
supports_gradient_checkpointing = True
|
175 |
+
|
176 |
+
def __init__(self, config):
|
177 |
+
super(MPTForCausalLM, self).__init__(config)
|
178 |
+
|
179 |
+
if not config.tie_word_embeddings:
|
180 |
+
raise ValueError('MPTForCausalLM only supports tied word embeddings')
|
181 |
+
self.transformer = LlavaMPTModel(config)
|
182 |
+
self.logit_scale = None
|
183 |
+
if config.logit_scale is not None:
|
184 |
+
logit_scale = config.logit_scale
|
185 |
+
if isinstance(logit_scale, str):
|
186 |
+
if logit_scale == 'inv_sqrt_d_model':
|
187 |
+
logit_scale = 1 / math.sqrt(config.d_model)
|
188 |
+
else:
|
189 |
+
raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
|
190 |
+
self.logit_scale = logit_scale
|
191 |
+
|
192 |
+
def get_model(self):
|
193 |
+
return self.transformer
|
194 |
+
|
195 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
196 |
+
if isinstance(module, LlavaMPTModel):
|
197 |
+
module.gradient_checkpointing = value
|
198 |
+
|
199 |
+
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
|
200 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
201 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
202 |
+
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, images=images)
|
203 |
+
logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
|
204 |
+
if self.logit_scale is not None:
|
205 |
+
if self.logit_scale == 0:
|
206 |
+
warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
|
207 |
+
logits *= self.logit_scale
|
208 |
+
loss = None
|
209 |
+
if labels is not None:
|
210 |
+
labels = torch.roll(labels, shifts=-1)
|
211 |
+
labels[:, -1] = -100
|
212 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
|
213 |
+
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
|
214 |
+
|
215 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
216 |
+
if inputs_embeds is not None:
|
217 |
+
raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
|
218 |
+
attention_mask = kwargs['attention_mask'].bool()
|
219 |
+
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
|
220 |
+
raise NotImplementedError('MPT does not support generation with right padding.')
|
221 |
+
if self.transformer.attn_uses_sequence_id and self.training:
|
222 |
+
sequence_id = torch.zeros_like(input_ids[:1])
|
223 |
+
else:
|
224 |
+
sequence_id = None
|
225 |
+
if past_key_values is not None:
|
226 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
227 |
+
if self.transformer.prefix_lm:
|
228 |
+
prefix_mask = torch.ones_like(attention_mask)
|
229 |
+
if kwargs.get('use_cache') == False:
|
230 |
+
raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
|
231 |
+
else:
|
232 |
+
prefix_mask = None
|
233 |
+
return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)}
|
234 |
+
|
235 |
+
def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
|
236 |
+
tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
|
237 |
+
vision_config = self.get_model().vision_tower[0].config
|
238 |
+
vision_config.use_im_start_end = mm_use_im_start_end
|
239 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
240 |
+
self.resize_token_embeddings(len(tokenizer))
|
241 |
+
|
242 |
+
if mm_use_im_start_end:
|
243 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
244 |
+
self.resize_token_embeddings(len(tokenizer))
|
245 |
+
vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
|
246 |
+
|
247 |
+
if num_new_tokens > 0:
|
248 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
249 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
250 |
+
|
251 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
252 |
+
dim=0, keepdim=True)
|
253 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
254 |
+
dim=0, keepdim=True)
|
255 |
+
|
256 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
257 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
258 |
+
|
259 |
+
if tune_mm_mlp_adapter:
|
260 |
+
self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
|
261 |
+
for p in self.get_input_embeddings().parameters():
|
262 |
+
p.requires_grad = True
|
263 |
+
for p in self.get_output_embeddings().parameters():
|
264 |
+
p.requires_grad = False
|
265 |
+
|
266 |
+
if pretrain_mm_mlp_adapter:
|
267 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
268 |
+
embed_tokens_weight = mm_projector_weights['transformer.wte.weight']
|
269 |
+
assert num_new_tokens == 2
|
270 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
271 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
272 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
273 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
274 |
+
else:
|
275 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
276 |
+
|
277 |
+
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
|
278 |
+
|
279 |
+
AutoConfig.register("llava_mpt", LlavaMPTConfig)
|
280 |
+
AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
|
diffusion/model/llava/mpt/attention.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Attention layers."""
|
2 |
+
import math
|
3 |
+
import warnings
|
4 |
+
from typing import Optional
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from einops import rearrange
|
8 |
+
from torch import nn
|
9 |
+
from .norm import LPLayerNorm
|
10 |
+
|
11 |
+
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
|
12 |
+
if original_is_causal and num_query_tokens != num_key_tokens:
|
13 |
+
if num_query_tokens != 1:
|
14 |
+
raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
|
15 |
+
else:
|
16 |
+
return False
|
17 |
+
return original_is_causal
|
18 |
+
|
19 |
+
def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
|
20 |
+
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
|
21 |
+
k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
|
22 |
+
v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
|
23 |
+
min_val = torch.finfo(q.dtype).min
|
24 |
+
(b, _, s_q, d) = q.shape
|
25 |
+
s_k = k.size(-1)
|
26 |
+
if softmax_scale is None:
|
27 |
+
softmax_scale = 1 / math.sqrt(d)
|
28 |
+
attn_weight = q.matmul(k) * softmax_scale
|
29 |
+
if attn_bias is not None:
|
30 |
+
if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
|
31 |
+
raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
|
32 |
+
attn_weight = attn_weight + attn_bias
|
33 |
+
if key_padding_mask is not None:
|
34 |
+
if attn_bias is not None:
|
35 |
+
warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
|
36 |
+
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
37 |
+
if is_causal:
|
38 |
+
s = max(s_q, s_k)
|
39 |
+
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
|
40 |
+
causal_mask = causal_mask.tril()
|
41 |
+
causal_mask = causal_mask.to(torch.bool)
|
42 |
+
causal_mask = ~causal_mask
|
43 |
+
causal_mask = causal_mask[-s_q:, -s_k:]
|
44 |
+
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
|
45 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
46 |
+
if dropout_p:
|
47 |
+
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
|
48 |
+
out = attn_weight.matmul(v)
|
49 |
+
out = rearrange(out, 'b h s d -> b s (h d)')
|
50 |
+
if needs_weights:
|
51 |
+
return (out, attn_weight)
|
52 |
+
return (out, None)
|
53 |
+
|
54 |
+
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|
55 |
+
for tensor in tensors:
|
56 |
+
if tensor.dtype not in valid_dtypes:
|
57 |
+
raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
|
58 |
+
if not tensor.is_cuda:
|
59 |
+
raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
|
60 |
+
|
61 |
+
def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
|
62 |
+
try:
|
63 |
+
from flash_attn import bert_padding, flash_attn_interface
|
64 |
+
except:
|
65 |
+
raise RuntimeError('Please install flash-attn==1.0.3.post0')
|
66 |
+
check_valid_inputs(query, key, value)
|
67 |
+
if attn_bias is not None:
|
68 |
+
raise NotImplementedError(f'attn_bias not implemented for flash attn.')
|
69 |
+
(batch_size, seqlen) = query.shape[:2]
|
70 |
+
if key_padding_mask is None:
|
71 |
+
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
|
72 |
+
query_padding_mask = key_padding_mask[:, -query.size(1):]
|
73 |
+
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
|
74 |
+
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
|
75 |
+
(key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
|
76 |
+
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
|
77 |
+
(value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
|
78 |
+
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
|
79 |
+
if multiquery:
|
80 |
+
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
|
81 |
+
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
|
82 |
+
dropout_p = dropout_p if training else 0.0
|
83 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
84 |
+
output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
|
85 |
+
output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
|
86 |
+
return (output, None)
|
87 |
+
|
88 |
+
def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
|
89 |
+
try:
|
90 |
+
from flash_attn import flash_attn_triton
|
91 |
+
except:
|
92 |
+
raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
|
93 |
+
check_valid_inputs(query, key, value)
|
94 |
+
if dropout_p:
|
95 |
+
raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
|
96 |
+
if needs_weights:
|
97 |
+
raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
|
98 |
+
if key_padding_mask is not None:
|
99 |
+
warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
|
100 |
+
(b_size, s_k) = key_padding_mask.shape[:2]
|
101 |
+
if attn_bias is None:
|
102 |
+
attn_bias = query.new_zeros(b_size, 1, 1, s_k)
|
103 |
+
attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
|
104 |
+
query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
|
105 |
+
key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
|
106 |
+
value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
|
107 |
+
if multiquery:
|
108 |
+
key = key.expand(*key.shape[:2], n_heads, key.size(-1))
|
109 |
+
value = value.expand(*value.shape[:2], n_heads, value.size(-1))
|
110 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
111 |
+
attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
|
112 |
+
output = attn_output.view(*attn_output.shape[:2], -1)
|
113 |
+
return (output, None)
|
114 |
+
|
115 |
+
class MultiheadAttention(nn.Module):
|
116 |
+
"""Multi-head self attention.
|
117 |
+
|
118 |
+
Using torch or triton attention implemetation enables user to also use
|
119 |
+
additive bias.
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
|
123 |
+
super().__init__()
|
124 |
+
self.attn_impl = attn_impl
|
125 |
+
self.clip_qkv = clip_qkv
|
126 |
+
self.qk_ln = qk_ln
|
127 |
+
self.d_model = d_model
|
128 |
+
self.n_heads = n_heads
|
129 |
+
self.softmax_scale = softmax_scale
|
130 |
+
if self.softmax_scale is None:
|
131 |
+
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
|
132 |
+
self.attn_dropout_p = attn_pdrop
|
133 |
+
self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
|
134 |
+
fuse_splits = (d_model, 2 * d_model)
|
135 |
+
self.Wqkv._fused = (0, fuse_splits)
|
136 |
+
if self.qk_ln:
|
137 |
+
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
|
138 |
+
self.q_ln = layernorm_class(self.d_model, device=device)
|
139 |
+
self.k_ln = layernorm_class(self.d_model, device=device)
|
140 |
+
if self.attn_impl == 'flash':
|
141 |
+
self.attn_fn = flash_attn_fn
|
142 |
+
elif self.attn_impl == 'triton':
|
143 |
+
self.attn_fn = triton_flash_attn_fn
|
144 |
+
warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
|
145 |
+
elif self.attn_impl == 'torch':
|
146 |
+
self.attn_fn = scaled_multihead_dot_product_attention
|
147 |
+
if torch.cuda.is_available():
|
148 |
+
warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
|
149 |
+
else:
|
150 |
+
raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
|
151 |
+
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
|
152 |
+
self.out_proj._is_residual = True
|
153 |
+
|
154 |
+
def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
|
155 |
+
qkv = self.Wqkv(x)
|
156 |
+
if self.clip_qkv:
|
157 |
+
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
158 |
+
(query, key, value) = qkv.chunk(3, dim=2)
|
159 |
+
key_padding_mask = attention_mask
|
160 |
+
if self.qk_ln:
|
161 |
+
dtype = query.dtype
|
162 |
+
query = self.q_ln(query).to(dtype)
|
163 |
+
key = self.k_ln(key).to(dtype)
|
164 |
+
if past_key_value is not None:
|
165 |
+
if len(past_key_value) != 0:
|
166 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
167 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
168 |
+
past_key_value = (key, value)
|
169 |
+
if attn_bias is not None:
|
170 |
+
attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
|
171 |
+
(context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
|
172 |
+
return (self.out_proj(context), attn_weights, past_key_value)
|
173 |
+
|
174 |
+
class MultiQueryAttention(nn.Module):
|
175 |
+
"""Multi-Query self attention.
|
176 |
+
|
177 |
+
Using torch or triton attention implemetation enables user to also use
|
178 |
+
additive bias.
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
|
182 |
+
super().__init__()
|
183 |
+
self.attn_impl = attn_impl
|
184 |
+
self.clip_qkv = clip_qkv
|
185 |
+
self.qk_ln = qk_ln
|
186 |
+
self.d_model = d_model
|
187 |
+
self.n_heads = n_heads
|
188 |
+
self.head_dim = d_model // n_heads
|
189 |
+
self.softmax_scale = softmax_scale
|
190 |
+
if self.softmax_scale is None:
|
191 |
+
self.softmax_scale = 1 / math.sqrt(self.head_dim)
|
192 |
+
self.attn_dropout_p = attn_pdrop
|
193 |
+
self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
|
194 |
+
fuse_splits = (d_model, d_model + self.head_dim)
|
195 |
+
self.Wqkv._fused = (0, fuse_splits)
|
196 |
+
if self.qk_ln:
|
197 |
+
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
|
198 |
+
self.q_ln = layernorm_class(d_model, device=device)
|
199 |
+
self.k_ln = layernorm_class(self.head_dim, device=device)
|
200 |
+
if self.attn_impl == 'flash':
|
201 |
+
self.attn_fn = flash_attn_fn
|
202 |
+
elif self.attn_impl == 'triton':
|
203 |
+
self.attn_fn = triton_flash_attn_fn
|
204 |
+
warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
|
205 |
+
elif self.attn_impl == 'torch':
|
206 |
+
self.attn_fn = scaled_multihead_dot_product_attention
|
207 |
+
if torch.cuda.is_available():
|
208 |
+
warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
|
209 |
+
else:
|
210 |
+
raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
|
211 |
+
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
|
212 |
+
self.out_proj._is_residual = True
|
213 |
+
|
214 |
+
def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
|
215 |
+
qkv = self.Wqkv(x)
|
216 |
+
if self.clip_qkv:
|
217 |
+
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
218 |
+
(query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
|
219 |
+
key_padding_mask = attention_mask
|
220 |
+
if self.qk_ln:
|
221 |
+
dtype = query.dtype
|
222 |
+
query = self.q_ln(query).to(dtype)
|
223 |
+
key = self.k_ln(key).to(dtype)
|
224 |
+
if past_key_value is not None:
|
225 |
+
if len(past_key_value) != 0:
|
226 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
227 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
228 |
+
past_key_value = (key, value)
|
229 |
+
if attn_bias is not None:
|
230 |
+
attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
|
231 |
+
(context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
|
232 |
+
return (self.out_proj(context), attn_weights, past_key_value)
|
233 |
+
|
234 |
+
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
|
235 |
+
if attn_impl == 'flash':
|
236 |
+
return None
|
237 |
+
elif attn_impl in ['torch', 'triton']:
|
238 |
+
if alibi:
|
239 |
+
if (prefix_lm or not causal) or use_sequence_id:
|
240 |
+
return (1, n_heads, seq_len, seq_len)
|
241 |
+
return (1, n_heads, 1, seq_len)
|
242 |
+
elif prefix_lm or use_sequence_id:
|
243 |
+
return (1, 1, seq_len, seq_len)
|
244 |
+
return None
|
245 |
+
else:
|
246 |
+
raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
|
247 |
+
|
248 |
+
def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
|
249 |
+
if attn_impl == 'flash':
|
250 |
+
return None
|
251 |
+
elif attn_impl in ['torch', 'triton']:
|
252 |
+
if alibi:
|
253 |
+
(device, dtype) = (attn_bias.device, attn_bias.dtype)
|
254 |
+
attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
|
255 |
+
return attn_bias
|
256 |
+
else:
|
257 |
+
raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
|
258 |
+
|
259 |
+
def gen_slopes(n_heads, alibi_bias_max=8, device=None):
|
260 |
+
_n_heads = 2 ** math.ceil(math.log2(n_heads))
|
261 |
+
m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
|
262 |
+
m = m.mul(alibi_bias_max / _n_heads)
|
263 |
+
slopes = 1.0 / torch.pow(2, m)
|
264 |
+
if _n_heads != n_heads:
|
265 |
+
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
|
266 |
+
return slopes.view(1, n_heads, 1, 1)
|
267 |
+
|
268 |
+
def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
|
269 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
|
270 |
+
if full:
|
271 |
+
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
|
272 |
+
alibi_bias = alibi_bias.abs().mul(-1)
|
273 |
+
slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
|
274 |
+
alibi_bias = alibi_bias * slopes
|
275 |
+
return alibi_bias.to(dtype=dtype)
|
276 |
+
ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
|
diffusion/model/llava/mpt/blocks.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""GPT Blocks used for the GPT Model."""
|
2 |
+
from typing import Dict, Optional, Tuple
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from .attention import ATTN_CLASS_REGISTRY
|
6 |
+
from .norm import NORM_CLASS_REGISTRY
|
7 |
+
|
8 |
+
class MPTMLP(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
|
11 |
+
super().__init__()
|
12 |
+
self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
|
13 |
+
self.act = nn.GELU(approximate='none')
|
14 |
+
self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
|
15 |
+
self.down_proj._is_residual = True
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
return self.down_proj(self.act(self.up_proj(x)))
|
19 |
+
|
20 |
+
class MPTBlock(nn.Module):
|
21 |
+
|
22 |
+
def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
|
23 |
+
del kwargs
|
24 |
+
super().__init__()
|
25 |
+
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
26 |
+
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
|
27 |
+
self.norm_1 = norm_class(d_model, device=device)
|
28 |
+
self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
|
29 |
+
self.norm_2 = norm_class(d_model, device=device)
|
30 |
+
self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
|
31 |
+
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
|
32 |
+
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
|
33 |
+
|
34 |
+
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
|
35 |
+
a = self.norm_1(x)
|
36 |
+
(b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
|
37 |
+
x = x + self.resid_attn_dropout(b)
|
38 |
+
m = self.norm_2(x)
|
39 |
+
n = self.ffn(m)
|
40 |
+
x = x + self.resid_ffn_dropout(n)
|
41 |
+
return (x, past_key_value)
|
diffusion/model/llava/mpt/configuration_mpt.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A HuggingFace-style model configuration."""
|
2 |
+
from typing import Dict, Optional, Union
|
3 |
+
from transformers import PretrainedConfig
|
4 |
+
attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
|
5 |
+
init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
|
6 |
+
|
7 |
+
class MPTConfig(PretrainedConfig):
|
8 |
+
model_type = 'mpt'
|
9 |
+
|
10 |
+
def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
|
11 |
+
"""The MPT configuration class.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
d_model (int): The size of the embedding dimension of the model.
|
15 |
+
n_heads (int): The number of attention heads.
|
16 |
+
n_layers (int): The number of layers in the model.
|
17 |
+
expansion_ratio (int): The ratio of the up/down scale in the MLP.
|
18 |
+
max_seq_len (int): The maximum sequence length of the model.
|
19 |
+
vocab_size (int): The size of the vocabulary.
|
20 |
+
resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
|
21 |
+
emb_pdrop (float): The dropout probability for the embedding layer.
|
22 |
+
learned_pos_emb (bool): Whether to use learned positional embeddings
|
23 |
+
attn_config (Dict): A dictionary used to configure the model's attention module:
|
24 |
+
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
|
25 |
+
attn_pdrop (float): The dropout probability for the attention layers.
|
26 |
+
attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
|
27 |
+
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
|
28 |
+
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
|
29 |
+
this value.
|
30 |
+
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
|
31 |
+
use the default scale of ``1/sqrt(d_keys)``.
|
32 |
+
prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
|
33 |
+
extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
|
34 |
+
can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
|
35 |
+
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
|
36 |
+
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
|
37 |
+
which sub-sequence each token belongs to.
|
38 |
+
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
|
39 |
+
alibi (bool): Whether to use the alibi bias instead of position embeddings.
|
40 |
+
alibi_bias_max (int): The maximum value of the alibi bias.
|
41 |
+
init_device (str): The device to use for parameter initialization.
|
42 |
+
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
|
43 |
+
no_bias (bool): Whether to use bias in all layers.
|
44 |
+
verbose (int): The verbosity level. 0 is silent.
|
45 |
+
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
|
46 |
+
norm_type (str): choose type of norm to use
|
47 |
+
multiquery_attention (bool): Whether to use multiquery attention implementation.
|
48 |
+
use_cache (bool): Whether or not the model should return the last key/values attentions
|
49 |
+
init_config (Dict): A dictionary used to configure the model initialization:
|
50 |
+
init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
|
51 |
+
'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
|
52 |
+
'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
|
53 |
+
init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
|
54 |
+
emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
|
55 |
+
emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
|
56 |
+
used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
|
57 |
+
init_std (float): The standard deviation of the normal distribution used to initialize the model,
|
58 |
+
if using the baseline_ parameter initialization scheme.
|
59 |
+
init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
|
60 |
+
fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
|
61 |
+
init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
|
62 |
+
---
|
63 |
+
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
|
64 |
+
"""
|
65 |
+
self.d_model = d_model
|
66 |
+
self.n_heads = n_heads
|
67 |
+
self.n_layers = n_layers
|
68 |
+
self.expansion_ratio = expansion_ratio
|
69 |
+
self.max_seq_len = max_seq_len
|
70 |
+
self.vocab_size = vocab_size
|
71 |
+
self.resid_pdrop = resid_pdrop
|
72 |
+
self.emb_pdrop = emb_pdrop
|
73 |
+
self.learned_pos_emb = learned_pos_emb
|
74 |
+
self.attn_config = attn_config
|
75 |
+
self.init_device = init_device
|
76 |
+
self.logit_scale = logit_scale
|
77 |
+
self.no_bias = no_bias
|
78 |
+
self.verbose = verbose
|
79 |
+
self.embedding_fraction = embedding_fraction
|
80 |
+
self.norm_type = norm_type
|
81 |
+
self.use_cache = use_cache
|
82 |
+
self.init_config = init_config
|
83 |
+
if 'name' in kwargs:
|
84 |
+
del kwargs['name']
|
85 |
+
if 'loss_fn' in kwargs:
|
86 |
+
del kwargs['loss_fn']
|
87 |
+
super().__init__(**kwargs)
|
88 |
+
self._validate_config()
|
89 |
+
|
90 |
+
def _set_config_defaults(self, config, config_defaults):
|
91 |
+
for (k, v) in config_defaults.items():
|
92 |
+
if k not in config:
|
93 |
+
config[k] = v
|
94 |
+
return config
|
95 |
+
|
96 |
+
def _validate_config(self):
|
97 |
+
self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
|
98 |
+
self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
|
99 |
+
if self.d_model % self.n_heads != 0:
|
100 |
+
raise ValueError('d_model must be divisible by n_heads')
|
101 |
+
if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
|
102 |
+
raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
|
103 |
+
if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
|
104 |
+
raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
|
105 |
+
if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
|
106 |
+
raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
|
107 |
+
if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
|
108 |
+
raise NotImplementedError('alibi only implemented with torch and triton attention.')
|
109 |
+
if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
|
110 |
+
raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
|
111 |
+
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
|
112 |
+
raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
|
113 |
+
if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
|
114 |
+
raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
|
115 |
+
if self.init_config.get('name', None) is None:
|
116 |
+
raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
|
117 |
+
if not self.learned_pos_emb and (not self.attn_config['alibi']):
|
118 |
+
raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.')
|
diffusion/model/llava/mpt/modeling_mpt.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A simple, flexible implementation of a GPT model.
|
2 |
+
|
3 |
+
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
|
4 |
+
"""
|
5 |
+
import math
|
6 |
+
import warnings
|
7 |
+
from typing import List, Optional, Tuple, Union
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
12 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
13 |
+
from .attention import attn_bias_shape, build_attn_bias
|
14 |
+
from .blocks import MPTBlock
|
15 |
+
from .norm import NORM_CLASS_REGISTRY
|
16 |
+
from .configuration_mpt import MPTConfig
|
17 |
+
from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
|
18 |
+
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
19 |
+
|
20 |
+
from transformers.utils import logging
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
|
23 |
+
class MPTPreTrainedModel(PreTrainedModel):
|
24 |
+
config_class = MPTConfig
|
25 |
+
base_model_prefix = 'model'
|
26 |
+
|
27 |
+
class MPTModel(MPTPreTrainedModel):
|
28 |
+
|
29 |
+
def __init__(self, config: MPTConfig):
|
30 |
+
config._validate_config()
|
31 |
+
super().__init__(config)
|
32 |
+
self.attn_impl = config.attn_config['attn_impl']
|
33 |
+
self.prefix_lm = config.attn_config['prefix_lm']
|
34 |
+
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
|
35 |
+
self.alibi = config.attn_config['alibi']
|
36 |
+
self.alibi_bias_max = config.attn_config['alibi_bias_max']
|
37 |
+
if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
|
38 |
+
norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
|
39 |
+
raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
|
40 |
+
norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
|
41 |
+
self.embedding_fraction = config.embedding_fraction
|
42 |
+
self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
|
43 |
+
if not self.alibi:
|
44 |
+
self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
|
45 |
+
self.emb_drop = nn.Dropout(config.emb_pdrop)
|
46 |
+
self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
|
47 |
+
self.norm_f = norm_class(config.d_model, device=config.init_device)
|
48 |
+
if config.init_device != 'meta':
|
49 |
+
self.apply(self.param_init_fn)
|
50 |
+
self.is_causal = not self.prefix_lm
|
51 |
+
self._attn_bias_initialized = False
|
52 |
+
self.attn_bias = None
|
53 |
+
self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
|
54 |
+
if config.no_bias:
|
55 |
+
for module in self.modules():
|
56 |
+
if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
|
57 |
+
if config.verbose:
|
58 |
+
warnings.warn(f'Removing bias ({module.bias}) from {module}.')
|
59 |
+
module.register_parameter('bias', None)
|
60 |
+
if config.verbose and config.verbose > 2:
|
61 |
+
print(self)
|
62 |
+
if 'verbose' not in self.config.init_config:
|
63 |
+
self.config.init_config['verbose'] = self.config.verbose
|
64 |
+
if self.config.init_config['verbose'] > 1:
|
65 |
+
init_fn_name = self.config.init_config['name']
|
66 |
+
warnings.warn(f'Using {init_fn_name} initialization.')
|
67 |
+
self.gradient_checkpointing = False
|
68 |
+
|
69 |
+
def get_input_embeddings(self):
|
70 |
+
return self.wte
|
71 |
+
|
72 |
+
def set_input_embeddings(self, value):
|
73 |
+
self.wte = value
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
|
77 |
+
if not self._attn_bias_initialized:
|
78 |
+
if self.attn_bias_shape:
|
79 |
+
self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
|
80 |
+
self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)
|
81 |
+
self._attn_bias_initialized = True
|
82 |
+
if self.attn_impl == 'flash':
|
83 |
+
return (self.attn_bias, attention_mask)
|
84 |
+
if self.attn_bias is not None:
|
85 |
+
self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
|
86 |
+
attn_bias = self.attn_bias
|
87 |
+
if self.prefix_lm:
|
88 |
+
assert isinstance(attn_bias, torch.Tensor)
|
89 |
+
assert isinstance(prefix_mask, torch.Tensor)
|
90 |
+
attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
|
91 |
+
if self.attn_uses_sequence_id and sequence_id is not None:
|
92 |
+
assert isinstance(attn_bias, torch.Tensor)
|
93 |
+
attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
|
94 |
+
if attention_mask is not None:
|
95 |
+
s_k = attention_mask.shape[-1]
|
96 |
+
if attn_bias is None:
|
97 |
+
attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
|
98 |
+
else:
|
99 |
+
attn_bias = attn_bias[:, :, :, -s_k:]
|
100 |
+
if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
|
101 |
+
raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
|
102 |
+
min_val = torch.finfo(attn_bias.dtype).min
|
103 |
+
attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
|
104 |
+
return (attn_bias, None)
|
105 |
+
|
106 |
+
def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
|
107 |
+
(s_k, s_q) = attn_bias.shape[-2:]
|
108 |
+
if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
|
109 |
+
raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
|
110 |
+
seq_len = prefix_mask.shape[-1]
|
111 |
+
if seq_len > self.config.max_seq_len:
|
112 |
+
raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
|
113 |
+
attn_bias = attn_bias[..., :seq_len, :seq_len]
|
114 |
+
causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
|
115 |
+
prefix = prefix_mask.view(-1, 1, 1, seq_len)
|
116 |
+
cannot_attend = ~torch.logical_or(causal, prefix.bool())
|
117 |
+
min_val = torch.finfo(attn_bias.dtype).min
|
118 |
+
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
119 |
+
return attn_bias
|
120 |
+
|
121 |
+
def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
|
122 |
+
seq_len = sequence_id.shape[-1]
|
123 |
+
if seq_len > self.config.max_seq_len:
|
124 |
+
raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
|
125 |
+
attn_bias = attn_bias[..., :seq_len, :seq_len]
|
126 |
+
cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
|
127 |
+
min_val = torch.finfo(attn_bias.dtype).min
|
128 |
+
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
129 |
+
return attn_bias
|
130 |
+
|
131 |
+
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, tok_emb: Optional[torch.FloatTensor]=None):
|
132 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
133 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
134 |
+
|
135 |
+
if self.gradient_checkpointing and self.training:
|
136 |
+
if use_cache:
|
137 |
+
logger.warning_once(
|
138 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
139 |
+
)
|
140 |
+
use_cache = False
|
141 |
+
if attention_mask is not None:
|
142 |
+
attention_mask = attention_mask.bool()
|
143 |
+
if prefix_mask is not None:
|
144 |
+
prefix_mask = prefix_mask.bool()
|
145 |
+
if not return_dict:
|
146 |
+
raise NotImplementedError('return_dict False is not implemented yet for MPT')
|
147 |
+
if output_attentions:
|
148 |
+
raise NotImplementedError('output_attentions is not implemented yet for MPT')
|
149 |
+
if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
|
150 |
+
raise NotImplementedError('MPT does not support training with left padding.')
|
151 |
+
if self.prefix_lm and prefix_mask is None:
|
152 |
+
raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
|
153 |
+
if self.training:
|
154 |
+
if self.attn_uses_sequence_id and sequence_id is None:
|
155 |
+
raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
|
156 |
+
elif self.attn_uses_sequence_id is False and sequence_id is not None:
|
157 |
+
warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
|
158 |
+
if input_ids is not None:
|
159 |
+
S = input_ids.size(1)
|
160 |
+
assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
|
161 |
+
tok_emb = self.wte(input_ids)
|
162 |
+
else:
|
163 |
+
assert tok_emb is not None
|
164 |
+
S = tok_emb.size(1)
|
165 |
+
if self.alibi:
|
166 |
+
x = tok_emb
|
167 |
+
else:
|
168 |
+
past_position = 0
|
169 |
+
if past_key_values is not None:
|
170 |
+
if len(past_key_values) != self.config.n_layers:
|
171 |
+
raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
|
172 |
+
past_position = past_key_values[0][0].size(1)
|
173 |
+
if S + past_position > self.config.max_seq_len:
|
174 |
+
raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
|
175 |
+
pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
|
176 |
+
if attention_mask is not None:
|
177 |
+
pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
|
178 |
+
pos_emb = self.wpe(pos)
|
179 |
+
x = tok_emb + pos_emb
|
180 |
+
if self.embedding_fraction == 1:
|
181 |
+
x = self.emb_drop(x)
|
182 |
+
else:
|
183 |
+
x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
|
184 |
+
assert isinstance(self.emb_drop, nn.Module)
|
185 |
+
x = self.emb_drop(x_shrunk)
|
186 |
+
(attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
|
187 |
+
if use_cache and past_key_values is None:
|
188 |
+
past_key_values = [() for _ in range(self.config.n_layers)]
|
189 |
+
all_hidden_states = () if output_hidden_states else None
|
190 |
+
for (b_idx, block) in enumerate(self.blocks):
|
191 |
+
if output_hidden_states:
|
192 |
+
assert all_hidden_states is not None
|
193 |
+
all_hidden_states = all_hidden_states + (x,)
|
194 |
+
past_key_value = past_key_values[b_idx] if past_key_values is not None else None
|
195 |
+
if self.gradient_checkpointing and self.training:
|
196 |
+
(x, past_key_value) = torch.utils.checkpoint.checkpoint(
|
197 |
+
block,
|
198 |
+
x, past_key_value, attn_bias, attention_mask, self.is_causal
|
199 |
+
)
|
200 |
+
else:
|
201 |
+
(x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
|
202 |
+
if past_key_values is not None:
|
203 |
+
past_key_values[b_idx] = past_key_value
|
204 |
+
x = self.norm_f(x)
|
205 |
+
return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
|
206 |
+
|
207 |
+
def param_init_fn(self, module):
|
208 |
+
init_fn_name = self.config.init_config['name']
|
209 |
+
MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
|
210 |
+
|
211 |
+
def fsdp_wrap_fn(self, module):
|
212 |
+
return isinstance(module, MPTBlock)
|
213 |
+
|
214 |
+
def activation_checkpointing_fn(self, module):
|
215 |
+
return isinstance(module, MPTBlock)
|
216 |
+
|
217 |
+
class MPTForCausalLM(MPTPreTrainedModel):
|
218 |
+
|
219 |
+
def __init__(self, config: MPTConfig):
|
220 |
+
super().__init__(config)
|
221 |
+
if not config.tie_word_embeddings:
|
222 |
+
raise ValueError('MPTForCausalLM only supports tied word embeddings')
|
223 |
+
self.transformer = MPTModel(config)
|
224 |
+
self.logit_scale = None
|
225 |
+
if config.logit_scale is not None:
|
226 |
+
logit_scale = config.logit_scale
|
227 |
+
if isinstance(logit_scale, str):
|
228 |
+
if logit_scale == 'inv_sqrt_d_model':
|
229 |
+
logit_scale = 1 / math.sqrt(config.d_model)
|
230 |
+
else:
|
231 |
+
raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
|
232 |
+
self.logit_scale = logit_scale
|
233 |
+
|
234 |
+
def get_input_embeddings(self):
|
235 |
+
return self.transformer.wte
|
236 |
+
|
237 |
+
def set_input_embeddings(self, value):
|
238 |
+
self.transformer.wte = value
|
239 |
+
|
240 |
+
def get_output_embeddings(self):
|
241 |
+
return self.transformer.wte
|
242 |
+
|
243 |
+
def set_output_embeddings(self, new_embeddings):
|
244 |
+
self.transformer.wte = new_embeddings
|
245 |
+
|
246 |
+
def set_decoder(self, decoder):
|
247 |
+
self.transformer = decoder
|
248 |
+
|
249 |
+
def get_decoder(self):
|
250 |
+
return self.transformer
|
251 |
+
|
252 |
+
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
|
253 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
254 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
255 |
+
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
|
256 |
+
logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
|
257 |
+
if self.logit_scale is not None:
|
258 |
+
if self.logit_scale == 0:
|
259 |
+
warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
|
260 |
+
logits *= self.logit_scale
|
261 |
+
loss = None
|
262 |
+
if labels is not None:
|
263 |
+
labels = torch.roll(labels, shifts=-1)
|
264 |
+
labels[:, -1] = -100
|
265 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
|
266 |
+
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
|
267 |
+
|
268 |
+
def param_init_fn(self, module):
|
269 |
+
init_fn_name = self.config.init_config['name']
|
270 |
+
MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
|
271 |
+
|
272 |
+
def fsdp_wrap_fn(self, module):
|
273 |
+
return isinstance(module, MPTBlock)
|
274 |
+
|
275 |
+
def activation_checkpointing_fn(self, module):
|
276 |
+
return isinstance(module, MPTBlock)
|
277 |
+
|
278 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
279 |
+
if inputs_embeds is not None:
|
280 |
+
raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
|
281 |
+
attention_mask = kwargs['attention_mask'].bool()
|
282 |
+
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
|
283 |
+
raise NotImplementedError('MPT does not support generation with right padding.')
|
284 |
+
if self.transformer.attn_uses_sequence_id and self.training:
|
285 |
+
sequence_id = torch.zeros_like(input_ids[:1])
|
286 |
+
else:
|
287 |
+
sequence_id = None
|
288 |
+
if past_key_values is not None:
|
289 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
290 |
+
if self.transformer.prefix_lm:
|
291 |
+
prefix_mask = torch.ones_like(attention_mask)
|
292 |
+
if kwargs.get('use_cache') == False:
|
293 |
+
raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
|
294 |
+
else:
|
295 |
+
prefix_mask = None
|
296 |
+
return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)}
|
297 |
+
|
298 |
+
@staticmethod
|
299 |
+
def _reorder_cache(past_key_values, beam_idx):
|
300 |
+
"""Used by HuggingFace generate when using beam search with kv-caching.
|
301 |
+
|
302 |
+
See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
|
303 |
+
for an example in transformers.
|
304 |
+
"""
|
305 |
+
reordered_past = []
|
306 |
+
for layer_past in past_key_values:
|
307 |
+
reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
|
308 |
+
return reordered_past
|
diffusion/model/llava/mpt/norm.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def _cast_if_autocast_enabled(tensor):
|
4 |
+
if torch.is_autocast_enabled():
|
5 |
+
if tensor.device.type == 'cuda':
|
6 |
+
dtype = torch.get_autocast_gpu_dtype()
|
7 |
+
elif tensor.device.type == 'cpu':
|
8 |
+
dtype = torch.get_autocast_cpu_dtype()
|
9 |
+
else:
|
10 |
+
raise NotImplementedError()
|
11 |
+
return tensor.to(dtype=dtype)
|
12 |
+
return tensor
|
13 |
+
|
14 |
+
class LPLayerNorm(torch.nn.LayerNorm):
|
15 |
+
|
16 |
+
def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
|
17 |
+
super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
|
18 |
+
|
19 |
+
def forward(self, x):
|
20 |
+
module_device = x.device
|
21 |
+
downcast_x = _cast_if_autocast_enabled(x)
|
22 |
+
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
23 |
+
downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
|
24 |
+
with torch.autocast(enabled=False, device_type=module_device.type):
|
25 |
+
return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
|
26 |
+
|
27 |
+
def rms_norm(x, weight=None, eps=1e-05):
|
28 |
+
output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
29 |
+
if weight is not None:
|
30 |
+
return output * weight
|
31 |
+
return output
|
32 |
+
|
33 |
+
class RMSNorm(torch.nn.Module):
|
34 |
+
|
35 |
+
def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
|
36 |
+
super().__init__()
|
37 |
+
self.eps = eps
|
38 |
+
if weight:
|
39 |
+
self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
|
40 |
+
else:
|
41 |
+
self.register_parameter('weight', None)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
|
45 |
+
|
46 |
+
class LPRMSNorm(RMSNorm):
|
47 |
+
|
48 |
+
def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
|
49 |
+
super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
downcast_x = _cast_if_autocast_enabled(x)
|
53 |
+
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
54 |
+
with torch.autocast(enabled=False, device_type=x.device.type):
|
55 |
+
return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
|
56 |
+
NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
|
diffusion/model/llava/mpt/param_init_fns.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from collections.abc import Sequence
|
4 |
+
from functools import partial
|
5 |
+
from typing import Optional, Tuple, Union
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from .norm import NORM_CLASS_REGISTRY
|
9 |
+
|
10 |
+
def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
|
11 |
+
del kwargs
|
12 |
+
if verbose > 1:
|
13 |
+
warnings.warn(f"Initializing network using module's reset_parameters attribute")
|
14 |
+
if hasattr(module, 'reset_parameters'):
|
15 |
+
module.reset_parameters()
|
16 |
+
|
17 |
+
def fused_init_helper_(module: nn.Module, init_fn_):
|
18 |
+
_fused = getattr(module, '_fused', None)
|
19 |
+
if _fused is None:
|
20 |
+
raise RuntimeError(f'Internal logic error')
|
21 |
+
(dim, splits) = _fused
|
22 |
+
splits = (0, *splits, module.weight.size(dim))
|
23 |
+
for (s, e) in zip(splits[:-1], splits[1:]):
|
24 |
+
slice_indices = [slice(None)] * module.weight.ndim
|
25 |
+
slice_indices[dim] = slice(s, e)
|
26 |
+
init_fn_(module.weight[slice_indices])
|
27 |
+
|
28 |
+
def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
|
29 |
+
del kwargs
|
30 |
+
if verbose > 1:
|
31 |
+
warnings.warn(f'If model has bias parameters they are initialized to 0.')
|
32 |
+
init_div_is_residual = init_div_is_residual
|
33 |
+
if init_div_is_residual is False:
|
34 |
+
div_is_residual = 1.0
|
35 |
+
elif init_div_is_residual is True:
|
36 |
+
div_is_residual = math.sqrt(2 * n_layers)
|
37 |
+
elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
|
38 |
+
div_is_residual = init_div_is_residual
|
39 |
+
elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
|
40 |
+
div_is_residual = float(init_div_is_residual)
|
41 |
+
else:
|
42 |
+
div_is_residual = 1.0
|
43 |
+
raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
|
44 |
+
if init_div_is_residual is not False:
|
45 |
+
if verbose > 1:
|
46 |
+
warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
|
47 |
+
if isinstance(module, nn.Linear):
|
48 |
+
if hasattr(module, '_fused'):
|
49 |
+
fused_init_helper_(module, init_fn_)
|
50 |
+
else:
|
51 |
+
init_fn_(module.weight)
|
52 |
+
if module.bias is not None:
|
53 |
+
torch.nn.init.zeros_(module.bias)
|
54 |
+
if init_div_is_residual is not False and getattr(module, '_is_residual', False):
|
55 |
+
with torch.no_grad():
|
56 |
+
module.weight.div_(div_is_residual)
|
57 |
+
elif isinstance(module, nn.Embedding):
|
58 |
+
if emb_init_std is not None:
|
59 |
+
std = emb_init_std
|
60 |
+
if std == 0:
|
61 |
+
warnings.warn(f'Embedding layer initialized to 0.')
|
62 |
+
emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
|
63 |
+
if verbose > 1:
|
64 |
+
warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
|
65 |
+
elif emb_init_uniform_lim is not None:
|
66 |
+
lim = emb_init_uniform_lim
|
67 |
+
if isinstance(lim, Sequence):
|
68 |
+
if len(lim) > 2:
|
69 |
+
raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
|
70 |
+
if lim[0] == lim[1]:
|
71 |
+
warnings.warn(f'Embedding layer initialized to {lim[0]}.')
|
72 |
+
else:
|
73 |
+
if lim == 0:
|
74 |
+
warnings.warn(f'Embedding layer initialized to 0.')
|
75 |
+
lim = [-lim, lim]
|
76 |
+
(a, b) = lim
|
77 |
+
emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
|
78 |
+
if verbose > 1:
|
79 |
+
warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
|
80 |
+
else:
|
81 |
+
emb_init_fn_ = init_fn_
|
82 |
+
emb_init_fn_(module.weight)
|
83 |
+
elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
|
84 |
+
if verbose > 1:
|
85 |
+
warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
|
86 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
87 |
+
torch.nn.init.ones_(module.weight)
|
88 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
89 |
+
torch.nn.init.zeros_(module.bias)
|
90 |
+
elif isinstance(module, nn.MultiheadAttention):
|
91 |
+
if module._qkv_same_embed_dim:
|
92 |
+
assert module.in_proj_weight is not None
|
93 |
+
assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
|
94 |
+
assert d_model is not None
|
95 |
+
_d = d_model
|
96 |
+
splits = (0, _d, 2 * _d, 3 * _d)
|
97 |
+
for (s, e) in zip(splits[:-1], splits[1:]):
|
98 |
+
init_fn_(module.in_proj_weight[s:e])
|
99 |
+
else:
|
100 |
+
assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
|
101 |
+
assert module.in_proj_weight is None
|
102 |
+
init_fn_(module.q_proj_weight)
|
103 |
+
init_fn_(module.k_proj_weight)
|
104 |
+
init_fn_(module.v_proj_weight)
|
105 |
+
if module.in_proj_bias is not None:
|
106 |
+
torch.nn.init.zeros_(module.in_proj_bias)
|
107 |
+
if module.bias_k is not None:
|
108 |
+
torch.nn.init.zeros_(module.bias_k)
|
109 |
+
if module.bias_v is not None:
|
110 |
+
torch.nn.init.zeros_(module.bias_v)
|
111 |
+
init_fn_(module.out_proj.weight)
|
112 |
+
if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
|
113 |
+
with torch.no_grad():
|
114 |
+
module.out_proj.weight.div_(div_is_residual)
|
115 |
+
if module.out_proj.bias is not None:
|
116 |
+
torch.nn.init.zeros_(module.out_proj.bias)
|
117 |
+
else:
|
118 |
+
for _ in module.parameters(recurse=False):
|
119 |
+
raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
|
120 |
+
|
121 |
+
def _normal_init_(std, mean=0.0):
|
122 |
+
return partial(torch.nn.init.normal_, mean=mean, std=std)
|
123 |
+
|
124 |
+
def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
|
125 |
+
del kwargs
|
126 |
+
init_fn_ = _normal_init_(std=std)
|
127 |
+
if verbose > 1:
|
128 |
+
warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
|
129 |
+
generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
130 |
+
|
131 |
+
def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
|
132 |
+
del kwargs
|
133 |
+
if init_std is None:
|
134 |
+
raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
|
135 |
+
_normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
136 |
+
|
137 |
+
def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
|
138 |
+
del kwargs
|
139 |
+
std = math.sqrt(2 / (5 * d_model))
|
140 |
+
_normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
141 |
+
|
142 |
+
def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
|
143 |
+
"""From section 2.3.1 of GPT-NeoX-20B:
|
144 |
+
|
145 |
+
An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
|
146 |
+
see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
|
147 |
+
and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
|
148 |
+
"""
|
149 |
+
del kwargs
|
150 |
+
residual_div = n_layers / math.sqrt(10)
|
151 |
+
if verbose > 1:
|
152 |
+
warnings.warn(f'setting init_div_is_residual to {residual_div}')
|
153 |
+
small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
154 |
+
|
155 |
+
def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
|
156 |
+
del kwargs
|
157 |
+
if verbose > 1:
|
158 |
+
warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
|
159 |
+
kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
|
160 |
+
generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
161 |
+
|
162 |
+
def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
|
163 |
+
del kwargs
|
164 |
+
if verbose > 1:
|
165 |
+
warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
|
166 |
+
kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
|
167 |
+
generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
168 |
+
|
169 |
+
def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
|
170 |
+
del kwargs
|
171 |
+
xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
|
172 |
+
if verbose > 1:
|
173 |
+
warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
|
174 |
+
generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
175 |
+
|
176 |
+
def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
|
177 |
+
xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
|
178 |
+
if verbose > 1:
|
179 |
+
warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
|
180 |
+
generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
181 |
+
MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
|
diffusion/model/nets/PixArt.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# References:
|
8 |
+
# GLIDE: https://github.com/openai/glide-text2im
|
9 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
10 |
+
# --------------------------------------------------------
|
11 |
+
import math
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import os
|
15 |
+
import numpy as np
|
16 |
+
from timm.models.layers import DropPath
|
17 |
+
from timm.models.vision_transformer import PatchEmbed, Mlp
|
18 |
+
|
19 |
+
from diffusion.model.builder import MODELS
|
20 |
+
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
|
21 |
+
from diffusion.model.nets.PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer
|
22 |
+
from diffusion.utils.logger import get_root_logger
|
23 |
+
|
24 |
+
|
25 |
+
class PixArtBlock(nn.Module):
|
26 |
+
"""
|
27 |
+
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None,
|
31 |
+
sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
|
32 |
+
super().__init__()
|
33 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
34 |
+
self.attn = AttentionKVCompress(
|
35 |
+
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
36 |
+
qk_norm=qk_norm, **block_kwargs
|
37 |
+
)
|
38 |
+
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
|
39 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
40 |
+
# to be compatible with lower version pytorch
|
41 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
42 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
|
43 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
44 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
45 |
+
self.sampling = sampling
|
46 |
+
self.sr_ratio = sr_ratio
|
47 |
+
|
48 |
+
def forward(self, x, y, t, mask=None, **kwargs):
|
49 |
+
B, N, C = x.shape
|
50 |
+
|
51 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
52 |
+
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
|
53 |
+
x = x + self.cross_attn(x, y, mask)
|
54 |
+
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
55 |
+
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
#############################################################################
|
60 |
+
# Core PixArt Model #
|
61 |
+
#################################################################################
|
62 |
+
@MODELS.register_module()
|
63 |
+
class PixArt(nn.Module):
|
64 |
+
"""
|
65 |
+
Diffusion model with a Transformer backbone.
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
input_size=32,
|
71 |
+
patch_size=2,
|
72 |
+
in_channels=4,
|
73 |
+
hidden_size=1152,
|
74 |
+
depth=28,
|
75 |
+
num_heads=16,
|
76 |
+
mlp_ratio=4.0,
|
77 |
+
class_dropout_prob=0.1,
|
78 |
+
pred_sigma=True,
|
79 |
+
drop_path: float = 0.,
|
80 |
+
caption_channels=4096,
|
81 |
+
pe_interpolation=1.0,
|
82 |
+
config=None,
|
83 |
+
model_max_length=120,
|
84 |
+
qk_norm=False,
|
85 |
+
kv_compress_config=None,
|
86 |
+
**kwargs,
|
87 |
+
):
|
88 |
+
super().__init__()
|
89 |
+
self.pred_sigma = pred_sigma
|
90 |
+
self.in_channels = in_channels
|
91 |
+
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
92 |
+
self.patch_size = patch_size
|
93 |
+
self.num_heads = num_heads
|
94 |
+
self.pe_interpolation = pe_interpolation
|
95 |
+
self.depth = depth
|
96 |
+
|
97 |
+
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
|
98 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
99 |
+
num_patches = self.x_embedder.num_patches
|
100 |
+
self.base_size = input_size // self.patch_size
|
101 |
+
# Will use fixed sin-cos embedding:
|
102 |
+
self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
|
103 |
+
|
104 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
105 |
+
self.t_block = nn.Sequential(
|
106 |
+
nn.SiLU(),
|
107 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
108 |
+
)
|
109 |
+
self.y_embedder = CaptionEmbedder(
|
110 |
+
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
|
111 |
+
act_layer=approx_gelu, token_num=model_max_length)
|
112 |
+
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
113 |
+
self.kv_compress_config = kv_compress_config
|
114 |
+
if kv_compress_config is None:
|
115 |
+
self.kv_compress_config = {
|
116 |
+
'sampling': None,
|
117 |
+
'scale_factor': 1,
|
118 |
+
'kv_compress_layer': [],
|
119 |
+
}
|
120 |
+
self.blocks = nn.ModuleList([
|
121 |
+
PixArtBlock(
|
122 |
+
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
123 |
+
input_size=(input_size // patch_size, input_size // patch_size),
|
124 |
+
sampling=self.kv_compress_config['sampling'],
|
125 |
+
sr_ratio=int(
|
126 |
+
self.kv_compress_config['scale_factor']
|
127 |
+
) if i in self.kv_compress_config['kv_compress_layer'] else 1,
|
128 |
+
qk_norm=qk_norm,
|
129 |
+
)
|
130 |
+
for i in range(depth)
|
131 |
+
])
|
132 |
+
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
|
133 |
+
|
134 |
+
self.initialize_weights()
|
135 |
+
|
136 |
+
if config:
|
137 |
+
logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
|
138 |
+
logger.warning(f"position embed interpolation: {self.pe_interpolation}, base size: {self.base_size}")
|
139 |
+
logger.warning(f"kv compress config: {self.kv_compress_config}")
|
140 |
+
else:
|
141 |
+
print(f'Warning: position embed interpolation: {self.pe_interpolation}, base size: {self.base_size}')
|
142 |
+
print(f"kv compress config: {self.kv_compress_config}")
|
143 |
+
|
144 |
+
|
145 |
+
def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
|
146 |
+
"""
|
147 |
+
Forward pass of PixArt.
|
148 |
+
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
149 |
+
t: (N,) tensor of diffusion timesteps
|
150 |
+
y: (N, 1, 120, C) tensor of class labels
|
151 |
+
"""
|
152 |
+
x = x.to(self.dtype)
|
153 |
+
timestep = timestep.to(self.dtype)
|
154 |
+
y = y.to(self.dtype)
|
155 |
+
pos_embed = self.pos_embed.to(self.dtype)
|
156 |
+
self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
|
157 |
+
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
158 |
+
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
|
159 |
+
t0 = self.t_block(t)
|
160 |
+
y = self.y_embedder(y, self.training) # (N, 1, L, D)
|
161 |
+
if mask is not None:
|
162 |
+
if mask.shape[0] != y.shape[0]:
|
163 |
+
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
164 |
+
mask = mask.squeeze(1).squeeze(1)
|
165 |
+
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
166 |
+
y_lens = mask.sum(dim=1).tolist()
|
167 |
+
else:
|
168 |
+
y_lens = [y.shape[2]] * y.shape[0]
|
169 |
+
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
170 |
+
for block in self.blocks:
|
171 |
+
x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint
|
172 |
+
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
173 |
+
x = self.unpatchify(x) # (N, out_channels, H, W)
|
174 |
+
return x
|
175 |
+
|
176 |
+
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
|
177 |
+
"""
|
178 |
+
dpm solver donnot need variance prediction
|
179 |
+
"""
|
180 |
+
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
181 |
+
model_out = self.forward(x, timestep, y, mask)
|
182 |
+
return model_out.chunk(2, dim=1)[0]
|
183 |
+
|
184 |
+
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
|
185 |
+
"""
|
186 |
+
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
|
187 |
+
"""
|
188 |
+
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
189 |
+
half = x[: len(x) // 2]
|
190 |
+
combined = torch.cat([half, half], dim=0)
|
191 |
+
model_out = self.forward(combined, timestep, y, mask, kwargs)
|
192 |
+
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
|
193 |
+
eps, rest = model_out[:, :3], model_out[:, 3:]
|
194 |
+
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
195 |
+
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
196 |
+
eps = torch.cat([half_eps, half_eps], dim=0)
|
197 |
+
return torch.cat([eps, rest], dim=1)
|
198 |
+
|
199 |
+
def unpatchify(self, x):
|
200 |
+
"""
|
201 |
+
x: (N, T, patch_size**2 * C)
|
202 |
+
imgs: (N, H, W, C)
|
203 |
+
"""
|
204 |
+
c = self.out_channels
|
205 |
+
p = self.x_embedder.patch_size[0]
|
206 |
+
h = w = int(x.shape[1] ** 0.5)
|
207 |
+
assert h * w == x.shape[1]
|
208 |
+
|
209 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
210 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
211 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
|
212 |
+
return imgs
|
213 |
+
|
214 |
+
def initialize_weights(self):
|
215 |
+
# Initialize transformer layers:
|
216 |
+
def _basic_init(module):
|
217 |
+
if isinstance(module, nn.Linear):
|
218 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
219 |
+
if module.bias is not None:
|
220 |
+
nn.init.constant_(module.bias, 0)
|
221 |
+
|
222 |
+
self.apply(_basic_init)
|
223 |
+
|
224 |
+
# Initialize (and freeze) pos_embed by sin-cos embedding:
|
225 |
+
pos_embed = get_2d_sincos_pos_embed(
|
226 |
+
self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5),
|
227 |
+
pe_interpolation=self.pe_interpolation, base_size=self.base_size
|
228 |
+
)
|
229 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
230 |
+
|
231 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
232 |
+
w = self.x_embedder.proj.weight.data
|
233 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
234 |
+
|
235 |
+
# Initialize timestep embedding MLP:
|
236 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
237 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
238 |
+
nn.init.normal_(self.t_block[1].weight, std=0.02)
|
239 |
+
|
240 |
+
# Initialize caption embedding MLP:
|
241 |
+
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
|
242 |
+
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
|
243 |
+
|
244 |
+
# Zero-out adaLN modulation layers in PixArt blocks:
|
245 |
+
for block in self.blocks:
|
246 |
+
nn.init.constant_(block.cross_attn.proj.weight, 0)
|
247 |
+
nn.init.constant_(block.cross_attn.proj.bias, 0)
|
248 |
+
|
249 |
+
# Zero-out output layers:
|
250 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
251 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
252 |
+
|
253 |
+
@property
|
254 |
+
def dtype(self):
|
255 |
+
return next(self.parameters()).dtype
|
256 |
+
|
257 |
+
|
258 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16):
|
259 |
+
"""
|
260 |
+
grid_size: int of the grid height and width
|
261 |
+
return:
|
262 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
263 |
+
"""
|
264 |
+
if isinstance(grid_size, int):
|
265 |
+
grid_size = to_2tuple(grid_size)
|
266 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / pe_interpolation
|
267 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / pe_interpolation
|
268 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
269 |
+
grid = np.stack(grid, axis=0)
|
270 |
+
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
271 |
+
|
272 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
273 |
+
if cls_token and extra_tokens > 0:
|
274 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
275 |
+
return pos_embed
|
276 |
+
|
277 |
+
|
278 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
279 |
+
assert embed_dim % 2 == 0
|
280 |
+
|
281 |
+
# use half of dimensions to encode grid_h
|
282 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
283 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
284 |
+
|
285 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
286 |
+
return emb
|
287 |
+
|
288 |
+
|
289 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
290 |
+
"""
|
291 |
+
embed_dim: output dimension for each position
|
292 |
+
pos: a list of positions to be encoded: size (M,)
|
293 |
+
out: (M, D)
|
294 |
+
"""
|
295 |
+
assert embed_dim % 2 == 0
|
296 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
297 |
+
omega /= embed_dim / 2.
|
298 |
+
omega = 1. / 10000 ** omega # (D/2,)
|
299 |
+
|
300 |
+
pos = pos.reshape(-1) # (M,)
|
301 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
302 |
+
|
303 |
+
emb_sin = np.sin(out) # (M, D/2)
|
304 |
+
emb_cos = np.cos(out) # (M, D/2)
|
305 |
+
|
306 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
307 |
+
return emb
|
308 |
+
|
309 |
+
|
310 |
+
#################################################################################
|
311 |
+
# PixArt Configs #
|
312 |
+
#################################################################################
|
313 |
+
@MODELS.register_module()
|
314 |
+
def PixArt_XL_2(**kwargs):
|
315 |
+
return PixArt(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
|
diffusion/model/nets/PixArtMS.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# References:
|
8 |
+
# GLIDE: https://github.com/openai/glide-text2im
|
9 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
10 |
+
# --------------------------------------------------------
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from timm.models.layers import DropPath
|
14 |
+
from timm.models.vision_transformer import Mlp
|
15 |
+
|
16 |
+
from diffusion.model.builder import MODELS
|
17 |
+
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
|
18 |
+
from diffusion.model.nets.PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, SizeEmbedder
|
19 |
+
from diffusion.model.nets.PixArt import PixArt, get_2d_sincos_pos_embed
|
20 |
+
|
21 |
+
|
22 |
+
class PatchEmbed(nn.Module):
|
23 |
+
""" 2D Image to Patch Embedding
|
24 |
+
"""
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
patch_size=16,
|
28 |
+
in_chans=3,
|
29 |
+
embed_dim=768,
|
30 |
+
norm_layer=None,
|
31 |
+
flatten=True,
|
32 |
+
bias=True,
|
33 |
+
):
|
34 |
+
super().__init__()
|
35 |
+
patch_size = to_2tuple(patch_size)
|
36 |
+
self.patch_size = patch_size
|
37 |
+
self.flatten = flatten
|
38 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
39 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
x = self.proj(x)
|
43 |
+
if self.flatten:
|
44 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
45 |
+
x = self.norm(x)
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
class PixArtMSBlock(nn.Module):
|
50 |
+
"""
|
51 |
+
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
|
55 |
+
sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
|
56 |
+
super().__init__()
|
57 |
+
self.hidden_size = hidden_size
|
58 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
59 |
+
self.attn = AttentionKVCompress(
|
60 |
+
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
61 |
+
qk_norm=qk_norm, **block_kwargs
|
62 |
+
)
|
63 |
+
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
|
64 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
65 |
+
# to be compatible with lower version pytorch
|
66 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
67 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
|
68 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
69 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
70 |
+
|
71 |
+
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
|
72 |
+
B, N, C = x.shape
|
73 |
+
|
74 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
75 |
+
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
|
76 |
+
x = x + self.cross_attn(x, y, mask)
|
77 |
+
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
78 |
+
|
79 |
+
return x
|
80 |
+
|
81 |
+
|
82 |
+
#############################################################################
|
83 |
+
# Core PixArt Model #
|
84 |
+
#################################################################################
|
85 |
+
@MODELS.register_module()
|
86 |
+
class PixArtMS(PixArt):
|
87 |
+
"""
|
88 |
+
Diffusion model with a Transformer backbone.
|
89 |
+
"""
|
90 |
+
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
input_size=32,
|
94 |
+
patch_size=2,
|
95 |
+
in_channels=4,
|
96 |
+
hidden_size=1152,
|
97 |
+
depth=28,
|
98 |
+
num_heads=16,
|
99 |
+
mlp_ratio=4.0,
|
100 |
+
class_dropout_prob=0.1,
|
101 |
+
learn_sigma=True,
|
102 |
+
pred_sigma=True,
|
103 |
+
drop_path: float = 0.,
|
104 |
+
caption_channels=4096,
|
105 |
+
pe_interpolation=1.,
|
106 |
+
config=None,
|
107 |
+
model_max_length=120,
|
108 |
+
micro_condition=False,
|
109 |
+
qk_norm=False,
|
110 |
+
kv_compress_config=None,
|
111 |
+
**kwargs,
|
112 |
+
):
|
113 |
+
super().__init__(
|
114 |
+
input_size=input_size,
|
115 |
+
patch_size=patch_size,
|
116 |
+
in_channels=in_channels,
|
117 |
+
hidden_size=hidden_size,
|
118 |
+
depth=depth,
|
119 |
+
num_heads=num_heads,
|
120 |
+
mlp_ratio=mlp_ratio,
|
121 |
+
class_dropout_prob=class_dropout_prob,
|
122 |
+
learn_sigma=learn_sigma,
|
123 |
+
pred_sigma=pred_sigma,
|
124 |
+
drop_path=drop_path,
|
125 |
+
pe_interpolation=pe_interpolation,
|
126 |
+
config=config,
|
127 |
+
model_max_length=model_max_length,
|
128 |
+
qk_norm=qk_norm,
|
129 |
+
kv_compress_config=kv_compress_config,
|
130 |
+
**kwargs,
|
131 |
+
)
|
132 |
+
self.h = self.w = 0
|
133 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
134 |
+
self.t_block = nn.Sequential(
|
135 |
+
nn.SiLU(),
|
136 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
137 |
+
)
|
138 |
+
self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True)
|
139 |
+
self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=model_max_length)
|
140 |
+
self.micro_conditioning = micro_condition
|
141 |
+
if self.micro_conditioning:
|
142 |
+
self.csize_embedder = SizeEmbedder(hidden_size//3) # c_size embed
|
143 |
+
self.ar_embedder = SizeEmbedder(hidden_size//3) # aspect ratio embed
|
144 |
+
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
145 |
+
if kv_compress_config is None:
|
146 |
+
kv_compress_config = {
|
147 |
+
'sampling': None,
|
148 |
+
'scale_factor': 1,
|
149 |
+
'kv_compress_layer': [],
|
150 |
+
}
|
151 |
+
self.blocks = nn.ModuleList([
|
152 |
+
PixArtMSBlock(
|
153 |
+
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
154 |
+
input_size=(input_size // patch_size, input_size // patch_size),
|
155 |
+
sampling=kv_compress_config['sampling'],
|
156 |
+
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
|
157 |
+
qk_norm=qk_norm,
|
158 |
+
)
|
159 |
+
for i in range(depth)
|
160 |
+
])
|
161 |
+
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
|
162 |
+
|
163 |
+
self.initialize()
|
164 |
+
|
165 |
+
def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
|
166 |
+
"""
|
167 |
+
Forward pass of PixArt.
|
168 |
+
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
169 |
+
t: (N,) tensor of diffusion timesteps
|
170 |
+
y: (N, 1, 120, C) tensor of class labels
|
171 |
+
"""
|
172 |
+
bs = x.shape[0]
|
173 |
+
x = x.to(self.dtype)
|
174 |
+
timestep = timestep.to(self.dtype)
|
175 |
+
y = y.to(self.dtype)
|
176 |
+
self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
|
177 |
+
pos_embed = torch.from_numpy(
|
178 |
+
get_2d_sincos_pos_embed(
|
179 |
+
self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation,
|
180 |
+
base_size=self.base_size
|
181 |
+
)
|
182 |
+
).unsqueeze(0).to(x.device).to(self.dtype)
|
183 |
+
|
184 |
+
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
185 |
+
t = self.t_embedder(timestep) # (N, D)
|
186 |
+
|
187 |
+
if self.micro_conditioning:
|
188 |
+
c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
|
189 |
+
csize = self.csize_embedder(c_size, bs) # (N, D)
|
190 |
+
ar = self.ar_embedder(ar, bs) # (N, D)
|
191 |
+
t = t + torch.cat([csize, ar], dim=1)
|
192 |
+
|
193 |
+
t0 = self.t_block(t)
|
194 |
+
y = self.y_embedder(y, self.training) # (N, D)
|
195 |
+
|
196 |
+
if mask is not None:
|
197 |
+
if mask.shape[0] != y.shape[0]:
|
198 |
+
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
199 |
+
mask = mask.squeeze(1).squeeze(1)
|
200 |
+
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
201 |
+
y_lens = mask.sum(dim=1).tolist()
|
202 |
+
else:
|
203 |
+
y_lens = [y.shape[2]] * y.shape[0]
|
204 |
+
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
205 |
+
for block in self.blocks:
|
206 |
+
x = auto_grad_checkpoint(block, x, y, t0, y_lens, (self.h, self.w), **kwargs) # (N, T, D) #support grad checkpoint
|
207 |
+
|
208 |
+
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
209 |
+
x = self.unpatchify(x) # (N, out_channels, H, W)
|
210 |
+
|
211 |
+
return x
|
212 |
+
|
213 |
+
def forward_with_dpmsolver(self, x, timestep, y, data_info, **kwargs):
|
214 |
+
"""
|
215 |
+
dpm solver donnot need variance prediction
|
216 |
+
"""
|
217 |
+
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
218 |
+
model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs)
|
219 |
+
return model_out.chunk(2, dim=1)[0]
|
220 |
+
|
221 |
+
def forward_with_cfg(self, x, timestep, y, cfg_scale, data_info, mask=None, **kwargs):
|
222 |
+
"""
|
223 |
+
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
|
224 |
+
"""
|
225 |
+
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
226 |
+
half = x[: len(x) // 2]
|
227 |
+
combined = torch.cat([half, half], dim=0)
|
228 |
+
model_out = self.forward(combined, timestep, y, mask, data_info=data_info, **kwargs)
|
229 |
+
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
|
230 |
+
eps, rest = model_out[:, :3], model_out[:, 3:]
|
231 |
+
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
232 |
+
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
233 |
+
eps = torch.cat([half_eps, half_eps], dim=0)
|
234 |
+
return torch.cat([eps, rest], dim=1)
|
235 |
+
|
236 |
+
def unpatchify(self, x):
|
237 |
+
"""
|
238 |
+
x: (N, T, patch_size**2 * C)
|
239 |
+
imgs: (N, H, W, C)
|
240 |
+
"""
|
241 |
+
c = self.out_channels
|
242 |
+
p = self.x_embedder.patch_size[0]
|
243 |
+
assert self.h * self.w == x.shape[1]
|
244 |
+
|
245 |
+
x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
|
246 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
247 |
+
imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
|
248 |
+
return imgs
|
249 |
+
|
250 |
+
def initialize(self):
|
251 |
+
# Initialize transformer layers:
|
252 |
+
def _basic_init(module):
|
253 |
+
if isinstance(module, nn.Linear):
|
254 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
255 |
+
if module.bias is not None:
|
256 |
+
nn.init.constant_(module.bias, 0)
|
257 |
+
|
258 |
+
self.apply(_basic_init)
|
259 |
+
|
260 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
261 |
+
w = self.x_embedder.proj.weight.data
|
262 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
263 |
+
|
264 |
+
# Initialize timestep embedding MLP:
|
265 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
266 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
267 |
+
nn.init.normal_(self.t_block[1].weight, std=0.02)
|
268 |
+
if self.micro_conditioning:
|
269 |
+
nn.init.normal_(self.csize_embedder.mlp[0].weight, std=0.02)
|
270 |
+
nn.init.normal_(self.csize_embedder.mlp[2].weight, std=0.02)
|
271 |
+
nn.init.normal_(self.ar_embedder.mlp[0].weight, std=0.02)
|
272 |
+
nn.init.normal_(self.ar_embedder.mlp[2].weight, std=0.02)
|
273 |
+
|
274 |
+
# Initialize caption embedding MLP:
|
275 |
+
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
|
276 |
+
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
|
277 |
+
|
278 |
+
# Zero-out adaLN modulation layers in PixArt blocks:
|
279 |
+
for block in self.blocks:
|
280 |
+
nn.init.constant_(block.cross_attn.proj.weight, 0)
|
281 |
+
nn.init.constant_(block.cross_attn.proj.bias, 0)
|
282 |
+
|
283 |
+
# Zero-out output layers:
|
284 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
285 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
286 |
+
|
287 |
+
|
288 |
+
#################################################################################
|
289 |
+
# PixArt Configs #
|
290 |
+
#################################################################################
|
291 |
+
@MODELS.register_module()
|
292 |
+
def PixArtMS_XL_2(**kwargs):
|
293 |
+
return PixArtMS(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
|