wpeebles commited on
Commit
2ea65a3
0 Parent(s):
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +33 -0
  2. .gitignore +2 -0
  3. LICENSE.txt +400 -0
  4. README.md +13 -0
  5. app.py +136 -0
  6. diffusion/__init__.py +46 -0
  7. diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  8. diffusion/__pycache__/diffusion_utils.cpython-39.pyc +0 -0
  9. diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc +0 -0
  10. diffusion/__pycache__/respace.cpython-39.pyc +0 -0
  11. diffusion/diffusion_utils.py +88 -0
  12. diffusion/gaussian_diffusion.py +873 -0
  13. diffusion/respace.py +129 -0
  14. diffusion/timestep_sampler.py +150 -0
  15. download.py +47 -0
  16. gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/captions.json +1 -0
  17. gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpe9n9k0az54l_f3gw.png +0 -0
  18. gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpjtpzxdb26dcckrfs.png +0 -0
  19. gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpmoaadwkx_iwdl820.png +0 -0
  20. gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpurgtzs1y9ie5h2s2.png +0 -0
  21. gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/captions.json +1 -0
  22. gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpff9tyvtyjf15zmm7.png +0 -0
  23. gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpfz82ezx8is9yjfdi.png +0 -0
  24. gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmppl9m3bza1t5k15ta.png +0 -0
  25. gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpz96a09w08a8kcl84.png +0 -0
  26. gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/captions.json +1 -0
  27. gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp63ria_p2997pqmtd.png +0 -0
  28. gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp8pyeu6a8l_en8xim.png +0 -0
  29. gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpipmoifyv87sepx5o.png +0 -0
  30. gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpkhiwf5_97pu17qfx.png +0 -0
  31. gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/captions.json +1 -0
  32. gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp1j0z6hrnm0vw0rai.png +0 -0
  33. gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp3oqwhs41q0ygkzgo.png +0 -0
  34. gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpgnbz7k1qi372rdol.png +0 -0
  35. gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpv4smnu57thdhxlo5.png +0 -0
  36. gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/captions.json +1 -0
  37. gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp2xnh42_y5utj14_w.png +0 -0
  38. gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp7th5inqczdoynzrl.png +0 -0
  39. gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpamh42w2dmr65v467.png +0 -0
  40. gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpnleysn66i6zycwi5.png +0 -0
  41. gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/captions.json +1 -0
  42. gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmp_haqfqy2_0t9ntvn.png +0 -0
  43. gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpe7h1fv6ojayij55x.png +0 -0
  44. gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpvza140f8v_pikrxf.png +0 -0
  45. gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpwy2tstw8dlphar0w.png +0 -0
  46. gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/captions.json +1 -0
  47. gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp2u8kljaep52iold0.png +0 -0
  48. gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp3wbfuj7t3s4jrwxg.png +0 -0
  49. gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp5a1jyli246vbchhw.png +0 -0
  50. gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmpbbapqj_eo3ty6ub2.png +0 -0
.gitattributes ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.npy filter=lfs diff=lfs merge=lfs -text
14
+ *.npz filter=lfs diff=lfs merge=lfs -text
15
+ *.onnx filter=lfs diff=lfs merge=lfs -text
16
+ *.ot filter=lfs diff=lfs merge=lfs -text
17
+ *.parquet filter=lfs diff=lfs merge=lfs -text
18
+ *.pb filter=lfs diff=lfs merge=lfs -text
19
+ *.pickle filter=lfs diff=lfs merge=lfs -text
20
+ *.pkl filter=lfs diff=lfs merge=lfs -text
21
+ *.pt filter=lfs diff=lfs merge=lfs -text
22
+ *.pth filter=lfs diff=lfs merge=lfs -text
23
+ *.rar filter=lfs diff=lfs merge=lfs -text
24
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
25
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
27
+ *.tflite filter=lfs diff=lfs merge=lfs -text
28
+ *.tgz filter=lfs diff=lfs merge=lfs -text
29
+ *.wasm filter=lfs diff=lfs merge=lfs -text
30
+ *.xz filter=lfs diff=lfs merge=lfs -text
31
+ *.zip filter=lfs diff=lfs merge=lfs -text
32
+ *.zst filter=lfs diff=lfs merge=lfs -text
33
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
1
+ .idea
2
+ pretrained_models
LICENSE.txt ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Attribution-NonCommercial 4.0 International
3
+
4
+ =======================================================================
5
+
6
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
7
+ does not provide legal services or legal advice. Distribution of
8
+ Creative Commons public licenses does not create a lawyer-client or
9
+ other relationship. Creative Commons makes its licenses and related
10
+ information available on an "as-is" basis. Creative Commons gives no
11
+ warranties regarding its licenses, any material licensed under their
12
+ terms and conditions, or any related information. Creative Commons
13
+ disclaims all liability for damages resulting from their use to the
14
+ fullest extent possible.
15
+
16
+ Using Creative Commons Public Licenses
17
+
18
+ Creative Commons public licenses provide a standard set of terms and
19
+ conditions that creators and other rights holders may use to share
20
+ original works of authorship and other material subject to copyright
21
+ and certain other rights specified in the public license below. The
22
+ following considerations are for informational purposes only, are not
23
+ exhaustive, and do not form part of our licenses.
24
+
25
+ Considerations for licensors: Our public licenses are
26
+ intended for use by those authorized to give the public
27
+ permission to use material in ways otherwise restricted by
28
+ copyright and certain other rights. Our licenses are
29
+ irrevocable. Licensors should read and understand the terms
30
+ and conditions of the license they choose before applying it.
31
+ Licensors should also secure all rights necessary before
32
+ applying our licenses so that the public can reuse the
33
+ material as expected. Licensors should clearly mark any
34
+ material not subject to the license. This includes other CC-
35
+ licensed material, or material used under an exception or
36
+ limitation to copyright. More considerations for licensors:
37
+ wiki.creativecommons.org/Considerations_for_licensors
38
+
39
+ Considerations for the public: By using one of our public
40
+ licenses, a licensor grants the public permission to use the
41
+ licensed material under specified terms and conditions. If
42
+ the licensor's permission is not necessary for any reason--for
43
+ example, because of any applicable exception or limitation to
44
+ copyright--then that use is not regulated by the license. Our
45
+ licenses grant only permissions under copyright and certain
46
+ other rights that a licensor has authority to grant. Use of
47
+ the licensed material may still be restricted for other
48
+ reasons, including because others have copyright or other
49
+ rights in the material. A licensor may make special requests,
50
+ such as asking that all changes be marked or described.
51
+ Although not required by our licenses, you are encouraged to
52
+ respect those requests where reasonable. More_considerations
53
+ for the public:
54
+ wiki.creativecommons.org/Considerations_for_licensees
55
+
56
+ =======================================================================
57
+
58
+ Creative Commons Attribution-NonCommercial 4.0 International Public
59
+ License
60
+
61
+ By exercising the Licensed Rights (defined below), You accept and agree
62
+ to be bound by the terms and conditions of this Creative Commons
63
+ Attribution-NonCommercial 4.0 International Public License ("Public
64
+ License"). To the extent this Public License may be interpreted as a
65
+ contract, You are granted the Licensed Rights in consideration of Your
66
+ acceptance of these terms and conditions, and the Licensor grants You
67
+ such rights in consideration of benefits the Licensor receives from
68
+ making the Licensed Material available under these terms and
69
+ conditions.
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. Copyright and Similar Rights means copyright and/or similar rights
88
+ closely related to copyright including, without limitation,
89
+ performance, broadcast, sound recording, and Sui Generis Database
90
+ Rights, without regard to how the rights are labeled or
91
+ categorized. For purposes of this Public License, the rights
92
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
93
+ Rights.
94
+ d. Effective Technological Measures means those measures that, in the
95
+ absence of proper authority, may not be circumvented under laws
96
+ fulfilling obligations under Article 11 of the WIPO Copyright
97
+ Treaty adopted on December 20, 1996, and/or similar international
98
+ agreements.
99
+
100
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
101
+ any other exception or limitation to Copyright and Similar Rights
102
+ that applies to Your use of the Licensed Material.
103
+
104
+ f. Licensed Material means the artistic or literary work, database,
105
+ or other material to which the Licensor applied this Public
106
+ License.
107
+
108
+ g. Licensed Rights means the rights granted to You subject to the
109
+ terms and conditions of this Public License, which are limited to
110
+ all Copyright and Similar Rights that apply to Your use of the
111
+ Licensed Material and that the Licensor has authority to license.
112
+
113
+ h. Licensor means the individual(s) or entity(ies) granting rights
114
+ under this Public License.
115
+
116
+ i. NonCommercial means not primarily intended for or directed towards
117
+ commercial advantage or monetary compensation. For purposes of
118
+ this Public License, the exchange of the Licensed Material for
119
+ other material subject to Copyright and Similar Rights by digital
120
+ file-sharing or similar means is NonCommercial provided there is
121
+ no payment of monetary compensation in connection with the
122
+ exchange.
123
+
124
+ j. Share means to provide material to the public by any means or
125
+ process that requires permission under the Licensed Rights, such
126
+ as reproduction, public display, public performance, distribution,
127
+ dissemination, communication, or importation, and to make material
128
+ available to the public including in ways that members of the
129
+ public may access the material from a place and at a time
130
+ individually chosen by them.
131
+
132
+ k. Sui Generis Database Rights means rights other than copyright
133
+ resulting from Directive 96/9/EC of the European Parliament and of
134
+ the Council of 11 March 1996 on the legal protection of databases,
135
+ as amended and/or succeeded, as well as other essentially
136
+ equivalent rights anywhere in the world.
137
+
138
+ l. You means the individual or entity exercising the Licensed Rights
139
+ under this Public License. Your has a corresponding meaning.
140
+
141
+ Section 2 -- Scope.
142
+
143
+ a. License grant.
144
+
145
+ 1. Subject to the terms and conditions of this Public License,
146
+ the Licensor hereby grants You a worldwide, royalty-free,
147
+ non-sublicensable, non-exclusive, irrevocable license to
148
+ exercise the Licensed Rights in the Licensed Material to:
149
+
150
+ a. reproduce and Share the Licensed Material, in whole or
151
+ in part, for NonCommercial purposes only; and
152
+
153
+ b. produce, reproduce, and Share Adapted Material for
154
+ NonCommercial purposes only.
155
+
156
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
157
+ Exceptions and Limitations apply to Your use, this Public
158
+ License does not apply, and You do not need to comply with
159
+ its terms and conditions.
160
+
161
+ 3. Term. The term of this Public License is specified in Section
162
+ 6(a).
163
+
164
+ 4. Media and formats; technical modifications allowed. The
165
+ Licensor authorizes You to exercise the Licensed Rights in
166
+ all media and formats whether now known or hereafter created,
167
+ and to make technical modifications necessary to do so. The
168
+ Licensor waives and/or agrees not to assert any right or
169
+ authority to forbid You from making technical modifications
170
+ necessary to exercise the Licensed Rights, including
171
+ technical modifications necessary to circumvent Effective
172
+ Technological Measures. For purposes of this Public License,
173
+ simply making modifications authorized by this Section 2(a)
174
+ (4) never produces Adapted Material.
175
+
176
+ 5. Downstream recipients.
177
+
178
+ a. Offer from the Licensor -- Licensed Material. Every
179
+ recipient of the Licensed Material automatically
180
+ receives an offer from the Licensor to exercise the
181
+ Licensed Rights under the terms and conditions of this
182
+ Public License.
183
+
184
+ b. No downstream restrictions. You may not offer or impose
185
+ any additional or different terms or conditions on, or
186
+ apply any Effective Technological Measures to, the
187
+ Licensed Material if doing so restricts exercise of the
188
+ Licensed Rights by any recipient of the Licensed
189
+ Material.
190
+
191
+ 6. No endorsement. Nothing in this Public License constitutes or
192
+ may be construed as permission to assert or imply that You
193
+ are, or that Your use of the Licensed Material is, connected
194
+ with, or sponsored, endorsed, or granted official status by,
195
+ the Licensor or others designated to receive attribution as
196
+ provided in Section 3(a)(1)(A)(i).
197
+
198
+ b. Other rights.
199
+
200
+ 1. Moral rights, such as the right of integrity, are not
201
+ licensed under this Public License, nor are publicity,
202
+ privacy, and/or other similar personality rights; however, to
203
+ the extent possible, the Licensor waives and/or agrees not to
204
+ assert any such rights held by the Licensor to the limited
205
+ extent necessary to allow You to exercise the Licensed
206
+ Rights, but not otherwise.
207
+
208
+ 2. Patent and trademark rights are not licensed under this
209
+ Public License.
210
+
211
+ 3. To the extent possible, the Licensor waives any right to
212
+ collect royalties from You for the exercise of the Licensed
213
+ Rights, whether directly or through a collecting society
214
+ under any voluntary or waivable statutory or compulsory
215
+ licensing scheme. In all other cases the Licensor expressly
216
+ reserves any right to collect such royalties, including when
217
+ the Licensed Material is used other than for NonCommercial
218
+ purposes.
219
+
220
+ Section 3 -- License Conditions.
221
+
222
+ Your exercise of the Licensed Rights is expressly made subject to the
223
+ following conditions.
224
+
225
+ a. Attribution.
226
+
227
+ 1. If You Share the Licensed Material (including in modified
228
+ form), You must:
229
+
230
+ a. retain the following if it is supplied by the Licensor
231
+ with the Licensed Material:
232
+
233
+ i. identification of the creator(s) of the Licensed
234
+ Material and any others designated to receive
235
+ attribution, in any reasonable manner requested by
236
+ the Licensor (including by pseudonym if
237
+ designated);
238
+
239
+ ii. a copyright notice;
240
+
241
+ iii. a notice that refers to this Public License;
242
+
243
+ iv. a notice that refers to the disclaimer of
244
+ warranties;
245
+
246
+ v. a URI or hyperlink to the Licensed Material to the
247
+ extent reasonably practicable;
248
+
249
+ b. indicate if You modified the Licensed Material and
250
+ retain an indication of any previous modifications; and
251
+
252
+ c. indicate the Licensed Material is licensed under this
253
+ Public License, and include the text of, or the URI or
254
+ hyperlink to, this Public License.
255
+
256
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
257
+ reasonable manner based on the medium, means, and context in
258
+ which You Share the Licensed Material. For example, it may be
259
+ reasonable to satisfy the conditions by providing a URI or
260
+ hyperlink to a resource that includes the required
261
+ information.
262
+
263
+ 3. If requested by the Licensor, You must remove any of the
264
+ information required by Section 3(a)(1)(A) to the extent
265
+ reasonably practicable.
266
+
267
+ 4. If You Share Adapted Material You produce, the Adapter's
268
+ License You apply must not prevent recipients of the Adapted
269
+ Material from complying with this Public License.
270
+
271
+ Section 4 -- Sui Generis Database Rights.
272
+
273
+ Where the Licensed Rights include Sui Generis Database Rights that
274
+ apply to Your use of the Licensed Material:
275
+
276
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
277
+ to extract, reuse, reproduce, and Share all or a substantial
278
+ portion of the contents of the database for NonCommercial purposes
279
+ only;
280
+
281
+ b. if You include all or a substantial portion of the database
282
+ contents in a database in which You have Sui Generis Database
283
+ Rights, then the database in which You have Sui Generis Database
284
+ Rights (but not its individual contents) is Adapted Material; and
285
+
286
+ c. You must comply with the conditions in Section 3(a) if You Share
287
+ all or a substantial portion of the contents of the database.
288
+
289
+ For the avoidance of doubt, this Section 4 supplements and does not
290
+ replace Your obligations under this Public License where the Licensed
291
+ Rights include other Copyright and Similar Rights.
292
+
293
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
294
+
295
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
296
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
297
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
298
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
299
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
300
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
301
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
302
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
303
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
304
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
305
+
306
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
307
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
308
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
309
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
310
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
311
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
312
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
313
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
314
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
315
+
316
+ c. The disclaimer of warranties and limitation of liability provided
317
+ above shall be interpreted in a manner that, to the extent
318
+ possible, most closely approximates an absolute disclaimer and
319
+ waiver of all liability.
320
+
321
+ Section 6 -- Term and Termination.
322
+
323
+ a. This Public License applies for the term of the Copyright and
324
+ Similar Rights licensed here. However, if You fail to comply with
325
+ this Public License, then Your rights under this Public License
326
+ terminate automatically.
327
+
328
+ b. Where Your right to use the Licensed Material has terminated under
329
+ Section 6(a), it reinstates:
330
+
331
+ 1. automatically as of the date the violation is cured, provided
332
+ it is cured within 30 days of Your discovery of the
333
+ violation; or
334
+
335
+ 2. upon express reinstatement by the Licensor.
336
+
337
+ For the avoidance of doubt, this Section 6(b) does not affect any
338
+ right the Licensor may have to seek remedies for Your violations
339
+ of this Public License.
340
+
341
+ c. For the avoidance of doubt, the Licensor may also offer the
342
+ Licensed Material under separate terms or conditions or stop
343
+ distributing the Licensed Material at any time; however, doing so
344
+ will not terminate this Public License.
345
+
346
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347
+ License.
348
+
349
+ Section 7 -- Other Terms and Conditions.
350
+
351
+ a. The Licensor shall not be bound by any additional or different
352
+ terms or conditions communicated by You unless expressly agreed.
353
+
354
+ b. Any arrangements, understandings, or agreements regarding the
355
+ Licensed Material not stated herein are separate from and
356
+ independent of the terms and conditions of this Public License.
357
+
358
+ Section 8 -- Interpretation.
359
+
360
+ a. For the avoidance of doubt, this Public License does not, and
361
+ shall not be interpreted to, reduce, limit, restrict, or impose
362
+ conditions on any use of the Licensed Material that could lawfully
363
+ be made without permission under this Public License.
364
+
365
+ b. To the extent possible, if any provision of this Public License is
366
+ deemed unenforceable, it shall be automatically reformed to the
367
+ minimum extent necessary to make it enforceable. If the provision
368
+ cannot be reformed, it shall be severed from this Public License
369
+ without affecting the enforceability of the remaining terms and
370
+ conditions.
371
+
372
+ c. No term or condition of this Public License will be waived and no
373
+ failure to comply consented to unless expressly agreed to by the
374
+ Licensor.
375
+
376
+ d. Nothing in this Public License constitutes or may be interpreted
377
+ as a limitation upon, or waiver of, any privileges and immunities
378
+ that apply to the Licensor or You, including from the legal
379
+ processes of any jurisdiction or authority.
380
+
381
+ =======================================================================
382
+
383
+ Creative Commons is not a party to its public
384
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
385
+ its public licenses to material it publishes and in those instances
386
+ will be considered the “Licensor.” The text of the Creative Commons
387
+ public licenses is dedicated to the public domain under the CC0 Public
388
+ Domain Dedication. Except for the limited purpose of indicating that
389
+ material is shared under a Creative Commons public license or as
390
+ otherwise permitted by the Creative Commons policies published at
391
+ creativecommons.org/policies, Creative Commons does not authorize the
392
+ use of the trademark "Creative Commons" or any other trademark or logo
393
+ of Creative Commons without its prior written consent including,
394
+ without limitation, in connection with any unauthorized modifications
395
+ to any of its public licenses or any other arrangements,
396
+ understandings, or agreements concerning use of licensed material. For
397
+ the avoidance of doubt, this paragraph does not form part of the
398
+ public licenses.
399
+
400
+ Creative Commons may be contacted at creativecommons.org.
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Diffusion Transformers (DiT)
3
+ emoji: 🚀
4
+ colorFrom: yellow
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.6
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-nc-4.0
11
+ ---
12
+
13
+ The code and model weights are licensed under CC-BY-NC. See LICENSE.txt for details.
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.utils import make_grid
3
+ import math
4
+ from PIL import Image
5
+ from diffusion import create_diffusion
6
+ from diffusers.models import AutoencoderKL
7
+ import gradio as gr
8
+ from imagenet_class_data import IMAGENET_1K_CLASSES
9
+ from download import find_model
10
+ from models import DiT_XL_2
11
+
12
+
13
+ def load_model(image_size=256):
14
+ assert image_size in [256, 512]
15
+ latent_size = image_size // 8
16
+ model = DiT_XL_2(input_size=latent_size).to(device)
17
+ state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
18
+ model.load_state_dict(state_dict)
19
+ model.eval()
20
+ return model
21
+
22
+
23
+ torch.set_grad_enabled(False)
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+ model = load_model(image_size=256)
26
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
27
+ current_image_size = 256
28
+ current_vae_model = "stabilityai/sd-vae-ft-mse"
29
+
30
+
31
+ def generate(image_size, vae_model, class_label, cfg_scale, num_sampling_steps, n, seed):
32
+ image_size = int(image_size.split("x")[0])
33
+ global current_image_size
34
+ if image_size != current_image_size:
35
+ global model
36
+ del model
37
+ # if device == "cuda":
38
+ # torch.cuda.empty_cache()
39
+ model = load_model(image_size=image_size)
40
+ current_image_size = image_size
41
+
42
+ global current_vae_model
43
+ if vae_model != current_vae_model:
44
+ global vae
45
+ if device == "cuda":
46
+ vae.to("cpu")
47
+ del vae
48
+ vae = AutoencoderKL.from_pretrained(vae_model).to(device)
49
+
50
+ # Seed PyTorch:
51
+ torch.manual_seed(seed)
52
+
53
+ # Setup diffusion
54
+ diffusion = create_diffusion(str(num_sampling_steps))
55
+
56
+ # Create sampling noise:
57
+ latent_size = image_size // 8
58
+ z = torch.randn(n, 4, latent_size, latent_size, device=device)
59
+ y = torch.tensor([class_label] * n, device=device)
60
+
61
+ # Setup classifier-free guidance:
62
+ z = torch.cat([z, z], 0)
63
+ y_null = torch.tensor([1000] * n, device=device)
64
+ y = torch.cat([y, y_null], 0)
65
+ model_kwargs = dict(y=y, cfg_scale=cfg_scale)
66
+
67
+ # Sample images:
68
+ samples = diffusion.p_sample_loop(
69
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
70
+ )
71
+ samples, _ = samples.chunk(2, dim=0) # Remove null class samples
72
+ samples = vae.decode(samples / 0.18215).sample
73
+
74
+ # Convert to PIL.Image format:
75
+ samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
76
+ samples = [Image.fromarray(sample) for sample in samples]
77
+ return samples
78
+
79
+
80
+ description = '''This is a demo of our DiT image generation models. DiTs are a new class of diffusion models with
81
+ transformer backbones. They are class-conditional models trained on ImageNet-1K, and they outperform prior DDPMs.'''
82
+
83
+ project_links = '''
84
+ <p style="text-align: center">
85
+ <a href="https://www.wpeebles.com/DiT.html">Project Page</a> &#183;
86
+ <a href="http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb">Colab</a> &#183;
87
+ <a href="http://arxiv.org/abs/2212.09748">Paper</a> &#183;
88
+ <a href="https://github.com/facebookresearch/DiT">GitHub</a></p>'''
89
+
90
+ examples = [
91
+ ["512x512", "stabilityai/sd-vae-ft-mse", "golden retriever", 4.0, 200, 4, 1000],
92
+ ["512x512", "stabilityai/sd-vae-ft-mse", "macaw", 4.0, 200, 4, 1],
93
+ ["512x512", "stabilityai/sd-vae-ft-mse", "balloon", 4.0, 200, 4, 1],
94
+ ["512x512", "stabilityai/sd-vae-ft-mse", "cliff, drop, drop-off", 4.0, 200, 4, 7],
95
+ ["512x512", "stabilityai/sd-vae-ft-mse", "Pembroke, Pembroke Welsh corgi", 4.0, 200, 4, 0],
96
+ ["256x256", "stabilityai/sd-vae-ft-mse", "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", 4.0, 200,
97
+ 4, 1],
98
+ ["256x256", "stabilityai/sd-vae-ft-mse", "teddy, teddy bear", 4.0, 200, 4, 3],
99
+ ["256x256", "stabilityai/sd-vae-ft-mse", "cheeseburger", 4.0, 200, 4, 2],
100
+
101
+ ]
102
+
103
+ with gr.Blocks() as demo:
104
+ gr.Markdown("<h1 style='text-align: center'>Scalable Diffusion Models with Transformers (DiT)</h1>")
105
+ gr.Markdown(project_links)
106
+ gr.Markdown(description)
107
+
108
+ with gr.Tabs():
109
+ with gr.TabItem('Generate'):
110
+ with gr.Row():
111
+ with gr.Column():
112
+ with gr.Row():
113
+ image_size = gr.inputs.Radio(choices=["256x256", "512x512"], default="256x256", label='DiT Model Resolution')
114
+ vae_model = gr.inputs.Radio(choices=["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"],
115
+ default="stabilityai/sd-vae-ft-mse", label='VAE Decoder')
116
+ with gr.Row():
117
+ i1k_class = gr.inputs.Dropdown(
118
+ list(IMAGENET_1K_CLASSES.values()),
119
+ default='golden retriever',
120
+ type="index", label='ImageNet-1K Class'
121
+ )
122
+ cfg_scale = gr.inputs.Slider(minimum=1, maximum=25, step=0.1, default=4.0, label='Classifier-free Guidance Scale')
123
+ steps = gr.inputs.Slider(minimum=4, maximum=1000, step=1, default=75, label='Sampling Steps')
124
+ n = gr.inputs.Slider(minimum=1, maximum=16, step=1, default=1, label='Number of Samples')
125
+ seed = gr.inputs.Number(default=0, label='Seed')
126
+ button = gr.Button("Generate", variant="primary")
127
+ with gr.Column():
128
+ output = gr.Gallery(label='Generated Images').style(grid=[2], height="auto")
129
+ button.click(generate, inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, n, seed], outputs=[output])
130
+ with gr.Row():
131
+ ex = gr.Examples(examples=examples, fn=generate,
132
+ inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, n, seed],
133
+ outputs=[output],
134
+ cache_examples=True)
135
+
136
+ demo.launch()
diffusion/__init__.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ learn_sigma=True,
17
+ rescale_learned_sigmas=False,
18
+ diffusion_steps=1000
19
+ ):
20
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21
+ if use_kl:
22
+ loss_type = gd.LossType.RESCALED_KL
23
+ elif rescale_learned_sigmas:
24
+ loss_type = gd.LossType.RESCALED_MSE
25
+ else:
26
+ loss_type = gd.LossType.MSE
27
+ if timestep_respacing is None or timestep_respacing == "":
28
+ timestep_respacing = [diffusion_steps]
29
+ return SpacedDiffusion(
30
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31
+ betas=betas,
32
+ model_mean_type=(
33
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34
+ ),
35
+ model_var_type=(
36
+ (
37
+ gd.ModelVarType.FIXED_LARGE
38
+ if not sigma_small
39
+ else gd.ModelVarType.FIXED_SMALL
40
+ )
41
+ if not learn_sigma
42
+ else gd.ModelVarType.LEARNED_RANGE
43
+ ),
44
+ loss_type=loss_type
45
+ # rescale_timesteps=rescale_timesteps,
46
+ )
diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (982 Bytes). View file
diffusion/__pycache__/diffusion_utils.cpython-39.pyc ADDED
Binary file (2.83 kB). View file
diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc ADDED
Binary file (24.3 kB). View file
diffusion/__pycache__/respace.cpython-39.pyc ADDED
Binary file (5.06 kB). View file
diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a continuous Gaussian distribution.
50
+ :param x: the targets
51
+ :param means: the Gaussian mean Tensor.
52
+ :param log_scales: the Gaussian log stddev Tensor.
53
+ :return: a tensor like x of log probabilities (in nats).
54
+ """
55
+ centered_x = x - means
56
+ inv_stdv = th.exp(-log_scales)
57
+ normalized_x = centered_x * inv_stdv
58
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59
+ return log_probs
60
+
61
+
62
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63
+ """
64
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
65
+ given image.
66
+ :param x: the target images. It is assumed that this was uint8 values,
67
+ rescaled to the range [-1, 1].
68
+ :param means: the Gaussian mean Tensor.
69
+ :param log_scales: the Gaussian log stddev Tensor.
70
+ :return: a tensor like x of log probabilities (in nats).
71
+ """
72
+ assert x.shape == means.shape == log_scales.shape
73
+ centered_x = x - means
74
+ inv_stdv = th.exp(-log_scales)
75
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76
+ cdf_plus = approx_standard_normal_cdf(plus_in)
77
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78
+ cdf_min = approx_standard_normal_cdf(min_in)
79
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81
+ cdf_delta = cdf_plus - cdf_min
82
+ log_probs = th.where(
83
+ x < -0.999,
84
+ log_cdf_plus,
85
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86
+ )
87
+ assert log_probs.shape == x.shape
88
+ return log_probs
diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,873 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al, extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ num_diffusion_timesteps=num_diffusion_timesteps,
115
+ )
116
+ elif schedule_name == "squaredcos_cap_v2":
117
+ return betas_for_alpha_bar(
118
+ num_diffusion_timesteps,
119
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120
+ )
121
+ else:
122
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
123
+
124
+
125
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
126
+ """
127
+ Create a beta schedule that discretizes the given alpha_t_bar function,
128
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
129
+ :param num_diffusion_timesteps: the number of betas to produce.
130
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
131
+ produces the cumulative product of (1-beta) up to that
132
+ part of the diffusion process.
133
+ :param max_beta: the maximum beta to use; use values lower than 1 to
134
+ prevent singularities.
135
+ """
136
+ betas = []
137
+ for i in range(num_diffusion_timesteps):
138
+ t1 = i / num_diffusion_timesteps
139
+ t2 = (i + 1) / num_diffusion_timesteps
140
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
141
+ return np.array(betas)
142
+
143
+
144
+ class GaussianDiffusion:
145
+ """
146
+ Utilities for training and sampling diffusion models.
147
+ Original ported from this codebase:
148
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
149
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
150
+ starting at T and going to 1.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ *,
156
+ betas,
157
+ model_mean_type,
158
+ model_var_type,
159
+ loss_type
160
+ ):
161
+
162
+ self.model_mean_type = model_mean_type
163
+ self.model_var_type = model_var_type
164
+ self.loss_type = loss_type
165
+
166
+ # Use float64 for accuracy.
167
+ betas = np.array(betas, dtype=np.float64)
168
+ self.betas = betas
169
+ assert len(betas.shape) == 1, "betas must be 1-D"
170
+ assert (betas > 0).all() and (betas <= 1).all()
171
+
172
+ self.num_timesteps = int(betas.shape[0])
173
+
174
+ alphas = 1.0 - betas
175
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
176
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
177
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
178
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
179
+
180
+ # calculations for diffusion q(x_t | x_{t-1}) and others
181
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
182
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
183
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
184
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
185
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
186
+
187
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
188
+ self.posterior_variance = (
189
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
190
+ )
191
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
192
+ self.posterior_log_variance_clipped = np.log(
193
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
194
+ ) if len(self.posterior_variance) > 1 else np.array([])
195
+
196
+ self.posterior_mean_coef1 = (
197
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
198
+ )
199
+ self.posterior_mean_coef2 = (
200
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
201
+ )
202
+
203
+ def q_mean_variance(self, x_start, t):
204
+ """
205
+ Get the distribution q(x_t | x_0).
206
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209
+ """
210
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
211
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213
+ return mean, variance, log_variance
214
+
215
+ def q_sample(self, x_start, t, noise=None):
216
+ """
217
+ Diffuse the data for a given number of diffusion steps.
218
+ In other words, sample from q(x_t | x_0).
219
+ :param x_start: the initial data batch.
220
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
221
+ :param noise: if specified, the split-out normal noise.
222
+ :return: A noisy version of x_start.
223
+ """
224
+ if noise is None:
225
+ noise = th.randn_like(x_start)
226
+ assert noise.shape == x_start.shape
227
+ return (
228
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
229
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
230
+ )
231
+
232
+ def q_posterior_mean_variance(self, x_start, x_t, t):
233
+ """
234
+ Compute the mean and variance of the diffusion posterior:
235
+ q(x_{t-1} | x_t, x_0)
236
+ """
237
+ assert x_start.shape == x_t.shape
238
+ posterior_mean = (
239
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
240
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
241
+ )
242
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
243
+ posterior_log_variance_clipped = _extract_into_tensor(
244
+ self.posterior_log_variance_clipped, t, x_t.shape
245
+ )
246
+ assert (
247
+ posterior_mean.shape[0]
248
+ == posterior_variance.shape[0]
249
+ == posterior_log_variance_clipped.shape[0]
250
+ == x_start.shape[0]
251
+ )
252
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
253
+
254
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
255
+ """
256
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
257
+ the initial x, x_0.
258
+ :param model: the model, which takes a signal and a batch of timesteps
259
+ as input.
260
+ :param x: the [N x C x ...] tensor at time t.
261
+ :param t: a 1-D Tensor of timesteps.
262
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
263
+ :param denoised_fn: if not None, a function which applies to the
264
+ x_start prediction before it is used to sample. Applies before
265
+ clip_denoised.
266
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
267
+ pass to the model. This can be used for conditioning.
268
+ :return: a dict with the following keys:
269
+ - 'mean': the model mean output.
270
+ - 'variance': the model variance output.
271
+ - 'log_variance': the log of 'variance'.
272
+ - 'pred_xstart': the prediction for x_0.
273
+ """
274
+ if model_kwargs is None:
275
+ model_kwargs = {}
276
+
277
+ B, C = x.shape[:2]
278
+ assert t.shape == (B,)
279
+ model_output = model(x, t, **model_kwargs)
280
+ if isinstance(model_output, tuple):
281
+ model_output, extra = model_output
282
+ else:
283
+ extra = None
284
+
285
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
286
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
287
+ model_output, model_var_values = th.split(model_output, C, dim=1)
288
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
289
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
290
+ # The model_var_values is [-1, 1] for [min_var, max_var].
291
+ frac = (model_var_values + 1) / 2
292
+ model_log_variance = frac * max_log + (1 - frac) * min_log
293
+ model_variance = th.exp(model_log_variance)
294
+ else:
295
+ model_variance, model_log_variance = {
296
+ # for fixedlarge, we set the initial (log-)variance like so
297
+ # to get a better decoder log likelihood.
298
+ ModelVarType.FIXED_LARGE: (
299
+ np.append(self.posterior_variance[1], self.betas[1:]),
300
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
301
+ ),
302
+ ModelVarType.FIXED_SMALL: (
303
+ self.posterior_variance,
304
+ self.posterior_log_variance_clipped,
305
+ ),
306
+ }[self.model_var_type]
307
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
308
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
309
+
310
+ def process_xstart(x):
311
+ if denoised_fn is not None:
312
+ x = denoised_fn(x)
313
+ if clip_denoised:
314
+ return x.clamp(-1, 1)
315
+ return x
316
+
317