djkesu committed
Commit 3bbf2c7
1 Parent(s): d350233

added model

CITATION.cff ADDED
@@ -0,0 +1,10 @@
+ cff-version: 1.3.0
+ message: "If you use this software, please cite it as below."
+ authors:
+ - family-names: "Betker"
+   given-names: "James"
+   orcid: "https://orcid.org/my-orcid?orcid=0000-0003-3259-4862"
+ title: "TorToiSe text-to-speech"
+ version: 2.0
+ date-released: 2022-04-28
+ url: "https://github.com/neonbjb/tortoise-tts"
LICENSE ADDED
@@ -0,0 +1,663 @@
1
+ GNU AFFERO GENERAL PUBLIC LICENSE
2
+ Version 3, 19 November 2007
3
+
4
+ Copyright (c) 2023 152334H
5
+
6
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
7
+ Everyone is permitted to copy and distribute verbatim copies
8
+ of this license document, but changing it is not allowed.
9
+
10
+ Preamble
11
+
12
+ The GNU Affero General Public License is a free, copyleft license for
13
+ software and other kinds of works, specifically designed to ensure
14
+ cooperation with the community in the case of network server software.
15
+
16
+ The licenses for most software and other practical works are designed
17
+ to take away your freedom to share and change the works. By contrast,
18
+ our General Public Licenses are intended to guarantee your freedom to
19
+ share and change all versions of a program--to make sure it remains free
20
+ software for all its users.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ Developers that use our General Public Licenses protect your rights
30
+ with two steps: (1) assert copyright on the software, and (2) offer
31
+ you this License which gives you legal permission to copy, distribute
32
+ and/or modify the software.
33
+
34
+ A secondary benefit of defending all users' freedom is that
35
+ improvements made in alternate versions of the program, if they
36
+ receive widespread use, become available for other developers to
37
+ incorporate. Many developers of free software are heartened and
38
+ encouraged by the resulting cooperation. However, in the case of
39
+ software used on network servers, this result may fail to come about.
40
+ The GNU General Public License permits making a modified version and
41
+ letting the public access it on a server without ever releasing its
42
+ source code to the public.
43
+
44
+ The GNU Affero General Public License is designed specifically to
45
+ ensure that, in such cases, the modified source code becomes available
46
+ to the community. It requires the operator of a network server to
47
+ provide the source code of the modified version running there to the
48
+ users of that server. Therefore, public use of a modified version, on
49
+ a publicly accessible server, gives the public access to the source
50
+ code of the modified version.
51
+
52
+ An older license, called the Affero General Public License and
53
+ published by Affero, was designed to accomplish similar goals. This is
54
+ a different license, not a version of the Affero GPL, but Affero has
55
+ released a new version of the Affero GPL which permits relicensing under
56
+ this license.
57
+
58
+ The precise terms and conditions for copying, distribution and
59
+ modification follow.
60
+
61
+ TERMS AND CONDITIONS
62
+
63
+ 0. Definitions.
64
+
65
+ "This License" refers to version 3 of the GNU Affero General Public License.
66
+
67
+ "Copyright" also means copyright-like laws that apply to other kinds of
68
+ works, such as semiconductor masks.
69
+
70
+ "The Program" refers to any copyrightable work licensed under this
71
+ License. Each licensee is addressed as "you". "Licensees" and
72
+ "recipients" may be individuals or organizations.
73
+
74
+ To "modify" a work means to copy from or adapt all or part of the work
75
+ in a fashion requiring copyright permission, other than the making of an
76
+ exact copy. The resulting work is called a "modified version" of the
77
+ earlier work or a work "based on" the earlier work.
78
+
79
+ A "covered work" means either the unmodified Program or a work based
80
+ on the Program.
81
+
82
+ To "propagate" a work means to do anything with it that, without
83
+ permission, would make you directly or secondarily liable for
84
+ infringement under applicable copyright law, except executing it on a
85
+ computer or modifying a private copy. Propagation includes copying,
86
+ distribution (with or without modification), making available to the
87
+ public, and in some countries other activities as well.
88
+
89
+ To "convey" a work means any kind of propagation that enables other
90
+ parties to make or receive copies. Mere interaction with a user through
91
+ a computer network, with no transfer of a copy, is not conveying.
92
+
93
+ An interactive user interface displays "Appropriate Legal Notices"
94
+ to the extent that it includes a convenient and prominently visible
95
+ feature that (1) displays an appropriate copyright notice, and (2)
96
+ tells the user that there is no warranty for the work (except to the
97
+ extent that warranties are provided), that licensees may convey the
98
+ work under this License, and how to view a copy of this License. If
99
+ the interface presents a list of user commands or options, such as a
100
+ menu, a prominent item in the list meets this criterion.
101
+
102
+ 1. Source Code.
103
+
104
+ The "source code" for a work means the preferred form of the work
105
+ for making modifications to it. "Object code" means any non-source
106
+ form of a work.
107
+
108
+ A "Standard Interface" means an interface that either is an official
109
+ standard defined by a recognized standards body, or, in the case of
110
+ interfaces specified for a particular programming language, one that
111
+ is widely used among developers working in that language.
112
+
113
+ The "System Libraries" of an executable work include anything, other
114
+ than the work as a whole, that (a) is included in the normal form of
115
+ packaging a Major Component, but which is not part of that Major
116
+ Component, and (b) serves only to enable use of the work with that
117
+ Major Component, or to implement a Standard Interface for which an
118
+ implementation is available to the public in source code form. A
119
+ "Major Component", in this context, means a major essential component
120
+ (kernel, window system, and so on) of the specific operating system
121
+ (if any) on which the executable work runs, or a compiler used to
122
+ produce the work, or an object code interpreter used to run it.
123
+
124
+ The "Corresponding Source" for a work in object code form means all
125
+ the source code needed to generate, install, and (for an executable
126
+ work) run the object code and to modify the work, including scripts to
127
+ control those activities. However, it does not include the work's
128
+ System Libraries, or general-purpose tools or generally available free
129
+ programs which are used unmodified in performing those activities but
130
+ which are not part of the work. For example, Corresponding Source
131
+ includes interface definition files associated with source files for
132
+ the work, and the source code for shared libraries and dynamically
133
+ linked subprograms that the work is specifically designed to require,
134
+ such as by intimate data communication or control flow between those
135
+ subprograms and other parts of the work.
136
+
137
+ The Corresponding Source need not include anything that users
138
+ can regenerate automatically from other parts of the Corresponding
139
+ Source.
140
+
141
+ The Corresponding Source for a work in source code form is that
142
+ same work.
143
+
144
+ 2. Basic Permissions.
145
+
146
+ All rights granted under this License are granted for the term of
147
+ copyright on the Program, and are irrevocable provided the stated
148
+ conditions are met. This License explicitly affirms your unlimited
149
+ permission to run the unmodified Program. The output from running a
150
+ covered work is covered by this License only if the output, given its
151
+ content, constitutes a covered work. This License acknowledges your
152
+ rights of fair use or other equivalent, as provided by copyright law.
153
+
154
+ You may make, run and propagate covered works that you do not
155
+ convey, without conditions so long as your license otherwise remains
156
+ in force. You may convey covered works to others for the sole purpose
157
+ of having them make modifications exclusively for you, or provide you
158
+ with facilities for running those works, provided that you comply with
159
+ the terms of this License in conveying all material for which you do
160
+ not control copyright. Those thus making or running the covered works
161
+ for you must do so exclusively on your behalf, under your direction
162
+ and control, on terms that prohibit them from making any copies of
163
+ your copyrighted material outside their relationship with you.
164
+
165
+ Conveying under any other circumstances is permitted solely under
166
+ the conditions stated below. Sublicensing is not allowed; section 10
167
+ makes it unnecessary.
168
+
169
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
170
+
171
+ No covered work shall be deemed part of an effective technological
172
+ measure under any applicable law fulfilling obligations under article
173
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
174
+ similar laws prohibiting or restricting circumvention of such
175
+ measures.
176
+
177
+ When you convey a covered work, you waive any legal power to forbid
178
+ circumvention of technological measures to the extent such circumvention
179
+ is effected by exercising rights under this License with respect to
180
+ the covered work, and you disclaim any intention to limit operation or
181
+ modification of the work as a means of enforcing, against the work's
182
+ users, your or third parties' legal rights to forbid circumvention of
183
+ technological measures.
184
+
185
+ 4. Conveying Verbatim Copies.
186
+
187
+ You may convey verbatim copies of the Program's source code as you
188
+ receive it, in any medium, provided that you conspicuously and
189
+ appropriately publish on each copy an appropriate copyright notice;
190
+ keep intact all notices stating that this License and any
191
+ non-permissive terms added in accord with section 7 apply to the code;
192
+ keep intact all notices of the absence of any warranty; and give all
193
+ recipients a copy of this License along with the Program.
194
+
195
+ You may charge any price or no price for each copy that you convey,
196
+ and you may offer support or warranty protection for a fee.
197
+
198
+ 5. Conveying Modified Source Versions.
199
+
200
+ You may convey a work based on the Program, or the modifications to
201
+ produce it from the Program, in the form of source code under the
202
+ terms of section 4, provided that you also meet all of these conditions:
203
+
204
+ a) The work must carry prominent notices stating that you modified
205
+ it, and giving a relevant date.
206
+
207
+ b) The work must carry prominent notices stating that it is
208
+ released under this License and any conditions added under section
209
+ 7. This requirement modifies the requirement in section 4 to
210
+ "keep intact all notices".
211
+
212
+ c) You must license the entire work, as a whole, under this
213
+ License to anyone who comes into possession of a copy. This
214
+ License will therefore apply, along with any applicable section 7
215
+ additional terms, to the whole of the work, and all its parts,
216
+ regardless of how they are packaged. This License gives no
217
+ permission to license the work in any other way, but it does not
218
+ invalidate such permission if you have separately received it.
219
+
220
+ d) If the work has interactive user interfaces, each must display
221
+ Appropriate Legal Notices; however, if the Program has interactive
222
+ interfaces that do not display Appropriate Legal Notices, your
223
+ work need not make them do so.
224
+
225
+ A compilation of a covered work with other separate and independent
226
+ works, which are not by their nature extensions of the covered work,
227
+ and which are not combined with it such as to form a larger program,
228
+ in or on a volume of a storage or distribution medium, is called an
229
+ "aggregate" if the compilation and its resulting copyright are not
230
+ used to limit the access or legal rights of the compilation's users
231
+ beyond what the individual works permit. Inclusion of a covered work
232
+ in an aggregate does not cause this License to apply to the other
233
+ parts of the aggregate.
234
+
235
+ 6. Conveying Non-Source Forms.
236
+
237
+ You may convey a covered work in object code form under the terms
238
+ of sections 4 and 5, provided that you also convey the
239
+ machine-readable Corresponding Source under the terms of this License,
240
+ in one of these ways:
241
+
242
+ a) Convey the object code in, or embodied in, a physical product
243
+ (including a physical distribution medium), accompanied by the
244
+ Corresponding Source fixed on a durable physical medium
245
+ customarily used for software interchange.
246
+
247
+ b) Convey the object code in, or embodied in, a physical product
248
+ (including a physical distribution medium), accompanied by a
249
+ written offer, valid for at least three years and valid for as
250
+ long as you offer spare parts or customer support for that product
251
+ model, to give anyone who possesses the object code either (1) a
252
+ copy of the Corresponding Source for all the software in the
253
+ product that is covered by this License, on a durable physical
254
+ medium customarily used for software interchange, for a price no
255
+ more than your reasonable cost of physically performing this
256
+ conveying of source, or (2) access to copy the
257
+ Corresponding Source from a network server at no charge.
258
+
259
+ c) Convey individual copies of the object code with a copy of the
260
+ written offer to provide the Corresponding Source. This
261
+ alternative is allowed only occasionally and noncommercially, and
262
+ only if you received the object code with such an offer, in accord
263
+ with subsection 6b.
264
+
265
+ d) Convey the object code by offering access from a designated
266
+ place (gratis or for a charge), and offer equivalent access to the
267
+ Corresponding Source in the same way through the same place at no
268
+ further charge. You need not require recipients to copy the
269
+ Corresponding Source along with the object code. If the place to
270
+ copy the object code is a network server, the Corresponding Source
271
+ may be on a different server (operated by you or a third party)
272
+ that supports equivalent copying facilities, provided you maintain
273
+ clear directions next to the object code saying where to find the
274
+ Corresponding Source. Regardless of what server hosts the
275
+ Corresponding Source, you remain obligated to ensure that it is
276
+ available for as long as needed to satisfy these requirements.
277
+
278
+ e) Convey the object code using peer-to-peer transmission, provided
279
+ you inform other peers where the object code and Corresponding
280
+ Source of the work are being offered to the general public at no
281
+ charge under subsection 6d.
282
+
283
+ A separable portion of the object code, whose source code is excluded
284
+ from the Corresponding Source as a System Library, need not be
285
+ included in conveying the object code work.
286
+
287
+ A "User Product" is either (1) a "consumer product", which means any
288
+ tangible personal property which is normally used for personal, family,
289
+ or household purposes, or (2) anything designed or sold for incorporation
290
+ into a dwelling. In determining whether a product is a consumer product,
291
+ doubtful cases shall be resolved in favor of coverage. For a particular
292
+ product received by a particular user, "normally used" refers to a
293
+ typical or common use of that class of product, regardless of the status
294
+ of the particular user or of the way in which the particular user
295
+ actually uses, or expects or is expected to use, the product. A product
296
+ is a consumer product regardless of whether the product has substantial
297
+ commercial, industrial or non-consumer uses, unless such uses represent
298
+ the only significant mode of use of the product.
299
+
300
+ "Installation Information" for a User Product means any methods,
301
+ procedures, authorization keys, or other information required to install
302
+ and execute modified versions of a covered work in that User Product from
303
+ a modified version of its Corresponding Source. The information must
304
+ suffice to ensure that the continued functioning of the modified object
305
+ code is in no case prevented or interfered with solely because
306
+ modification has been made.
307
+
308
+ If you convey an object code work under this section in, or with, or
309
+ specifically for use in, a User Product, and the conveying occurs as
310
+ part of a transaction in which the right of possession and use of the
311
+ User Product is transferred to the recipient in perpetuity or for a
312
+ fixed term (regardless of how the transaction is characterized), the
313
+ Corresponding Source conveyed under this section must be accompanied
314
+ by the Installation Information. But this requirement does not apply
315
+ if neither you nor any third party retains the ability to install
316
+ modified object code on the User Product (for example, the work has
317
+ been installed in ROM).
318
+
319
+ The requirement to provide Installation Information does not include a
320
+ requirement to continue to provide support service, warranty, or updates
321
+ for a work that has been modified or installed by the recipient, or for
322
+ the User Product in which it has been modified or installed. Access to a
323
+ network may be denied when the modification itself materially and
324
+ adversely affects the operation of the network or violates the rules and
325
+ protocols for communication across the network.
326
+
327
+ Corresponding Source conveyed, and Installation Information provided,
328
+ in accord with this section must be in a format that is publicly
329
+ documented (and with an implementation available to the public in
330
+ source code form), and must require no special password or key for
331
+ unpacking, reading or copying.
332
+
333
+ 7. Additional Terms.
334
+
335
+ "Additional permissions" are terms that supplement the terms of this
336
+ License by making exceptions from one or more of its conditions.
337
+ Additional permissions that are applicable to the entire Program shall
338
+ be treated as though they were included in this License, to the extent
339
+ that they are valid under applicable law. If additional permissions
340
+ apply only to part of the Program, that part may be used separately
341
+ under those permissions, but the entire Program remains governed by
342
+ this License without regard to the additional permissions.
343
+
344
+ When you convey a copy of a covered work, you may at your option
345
+ remove any additional permissions from that copy, or from any part of
346
+ it. (Additional permissions may be written to require their own
347
+ removal in certain cases when you modify the work.) You may place
348
+ additional permissions on material, added by you to a covered work,
349
+ for which you have or can give appropriate copyright permission.
350
+
351
+ Notwithstanding any other provision of this License, for material you
352
+ add to a covered work, you may (if authorized by the copyright holders of
353
+ that material) supplement the terms of this License with terms:
354
+
355
+ a) Disclaiming warranty or limiting liability differently from the
356
+ terms of sections 15 and 16 of this License; or
357
+
358
+ b) Requiring preservation of specified reasonable legal notices or
359
+ author attributions in that material or in the Appropriate Legal
360
+ Notices displayed by works containing it; or
361
+
362
+ c) Prohibiting misrepresentation of the origin of that material, or
363
+ requiring that modified versions of such material be marked in
364
+ reasonable ways as different from the original version; or
365
+
366
+ d) Limiting the use for publicity purposes of names of licensors or
367
+ authors of the material; or
368
+
369
+ e) Declining to grant rights under trademark law for use of some
370
+ trade names, trademarks, or service marks; or
371
+
372
+ f) Requiring indemnification of licensors and authors of that
373
+ material by anyone who conveys the material (or modified versions of
374
+ it) with contractual assumptions of liability to the recipient, for
375
+ any liability that these contractual assumptions directly impose on
376
+ those licensors and authors.
377
+
378
+ All other non-permissive additional terms are considered "further
379
+ restrictions" within the meaning of section 10. If the Program as you
380
+ received it, or any part of it, contains a notice stating that it is
381
+ governed by this License along with a term that is a further
382
+ restriction, you may remove that term. If a license document contains
383
+ a further restriction but permits relicensing or conveying under this
384
+ License, you may add to a covered work material governed by the terms
385
+ of that license document, provided that the further restriction does
386
+ not survive such relicensing or conveying.
387
+
388
+ If you add terms to a covered work in accord with this section, you
389
+ must place, in the relevant source files, a statement of the
390
+ additional terms that apply to those files, or a notice indicating
391
+ where to find the applicable terms.
392
+
393
+ Additional terms, permissive or non-permissive, may be stated in the
394
+ form of a separately written license, or stated as exceptions;
395
+ the above requirements apply either way.
396
+
397
+ 8. Termination.
398
+
399
+ You may not propagate or modify a covered work except as expressly
400
+ provided under this License. Any attempt otherwise to propagate or
401
+ modify it is void, and will automatically terminate your rights under
402
+ this License (including any patent licenses granted under the third
403
+ paragraph of section 11).
404
+
405
+ However, if you cease all violation of this License, then your
406
+ license from a particular copyright holder is reinstated (a)
407
+ provisionally, unless and until the copyright holder explicitly and
408
+ finally terminates your license, and (b) permanently, if the copyright
409
+ holder fails to notify you of the violation by some reasonable means
410
+ prior to 60 days after the cessation.
411
+
412
+ Moreover, your license from a particular copyright holder is
413
+ reinstated permanently if the copyright holder notifies you of the
414
+ violation by some reasonable means, this is the first time you have
415
+ received notice of violation of this License (for any work) from that
416
+ copyright holder, and you cure the violation prior to 30 days after
417
+ your receipt of the notice.
418
+
419
+ Termination of your rights under this section does not terminate the
420
+ licenses of parties who have received copies or rights from you under
421
+ this License. If your rights have been terminated and not permanently
422
+ reinstated, you do not qualify to receive new licenses for the same
423
+ material under section 10.
424
+
425
+ 9. Acceptance Not Required for Having Copies.
426
+
427
+ You are not required to accept this License in order to receive or
428
+ run a copy of the Program. Ancillary propagation of a covered work
429
+ occurring solely as a consequence of using peer-to-peer transmission
430
+ to receive a copy likewise does not require acceptance. However,
431
+ nothing other than this License grants you permission to propagate or
432
+ modify any covered work. These actions infringe copyright if you do
433
+ not accept this License. Therefore, by modifying or propagating a
434
+ covered work, you indicate your acceptance of this License to do so.
435
+
436
+ 10. Automatic Licensing of Downstream Recipients.
437
+
438
+ Each time you convey a covered work, the recipient automatically
439
+ receives a license from the original licensors, to run, modify and
440
+ propagate that work, subject to this License. You are not responsible
441
+ for enforcing compliance by third parties with this License.
442
+
443
+ An "entity transaction" is a transaction transferring control of an
444
+ organization, or substantially all assets of one, or subdividing an
445
+ organization, or merging organizations. If propagation of a covered
446
+ work results from an entity transaction, each party to that
447
+ transaction who receives a copy of the work also receives whatever
448
+ licenses to the work the party's predecessor in interest had or could
449
+ give under the previous paragraph, plus a right to possession of the
450
+ Corresponding Source of the work from the predecessor in interest, if
451
+ the predecessor has it or can get it with reasonable efforts.
452
+
453
+ You may not impose any further restrictions on the exercise of the
454
+ rights granted or affirmed under this License. For example, you may
455
+ not impose a license fee, royalty, or other charge for exercise of
456
+ rights granted under this License, and you may not initiate litigation
457
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
458
+ any patent claim is infringed by making, using, selling, offering for
459
+ sale, or importing the Program or any portion of it.
460
+
461
+ 11. Patents.
462
+
463
+ A "contributor" is a copyright holder who authorizes use under this
464
+ License of the Program or a work on which the Program is based. The
465
+ work thus licensed is called the contributor's "contributor version".
466
+
467
+ A contributor's "essential patent claims" are all patent claims
468
+ owned or controlled by the contributor, whether already acquired or
469
+ hereafter acquired, that would be infringed by some manner, permitted
470
+ by this License, of making, using, or selling its contributor version,
471
+ but do not include claims that would be infringed only as a
472
+ consequence of further modification of the contributor version. For
473
+ purposes of this definition, "control" includes the right to grant
474
+ patent sublicenses in a manner consistent with the requirements of
475
+ this License.
476
+
477
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
478
+ patent license under the contributor's essential patent claims, to
479
+ make, use, sell, offer for sale, import and otherwise run, modify and
480
+ propagate the contents of its contributor version.
481
+
482
+ In the following three paragraphs, a "patent license" is any express
483
+ agreement or commitment, however denominated, not to enforce a patent
484
+ (such as an express permission to practice a patent or covenant not to
485
+ sue for patent infringement). To "grant" such a patent license to a
486
+ party means to make such an agreement or commitment not to enforce a
487
+ patent against the party.
488
+
489
+ If you convey a covered work, knowingly relying on a patent license,
490
+ and the Corresponding Source of the work is not available for anyone
491
+ to copy, free of charge and under the terms of this License, through a
492
+ publicly available network server or other readily accessible means,
493
+ then you must either (1) cause the Corresponding Source to be so
494
+ available, or (2) arrange to deprive yourself of the benefit of the
495
+ patent license for this particular work, or (3) arrange, in a manner
496
+ consistent with the requirements of this License, to extend the patent
497
+ license to downstream recipients. "Knowingly relying" means you have
498
+ actual knowledge that, but for the patent license, your conveying the
499
+ covered work in a country, or your recipient's use of the covered work
500
+ in a country, would infringe one or more identifiable patents in that
501
+ country that you have reason to believe are valid.
502
+
503
+ If, pursuant to or in connection with a single transaction or
504
+ arrangement, you convey, or propagate by procuring conveyance of, a
505
+ covered work, and grant a patent license to some of the parties
506
+ receiving the covered work authorizing them to use, propagate, modify
507
+ or convey a specific copy of the covered work, then the patent license
508
+ you grant is automatically extended to all recipients of the covered
509
+ work and works based on it.
510
+
511
+ A patent license is "discriminatory" if it does not include within
512
+ the scope of its coverage, prohibits the exercise of, or is
513
+ conditioned on the non-exercise of one or more of the rights that are
514
+ specifically granted under this License. You may not convey a covered
515
+ work if you are a party to an arrangement with a third party that is
516
+ in the business of distributing software, under which you make payment
517
+ to the third party based on the extent of your activity of conveying
518
+ the work, and under which the third party grants, to any of the
519
+ parties who would receive the covered work from you, a discriminatory
520
+ patent license (a) in connection with copies of the covered work
521
+ conveyed by you (or copies made from those copies), or (b) primarily
522
+ for and in connection with specific products or compilations that
523
+ contain the covered work, unless you entered into that arrangement,
524
+ or that patent license was granted, prior to 28 March 2007.
525
+
526
+ Nothing in this License shall be construed as excluding or limiting
527
+ any implied license or other defenses to infringement that may
528
+ otherwise be available to you under applicable patent law.
529
+
530
+ 12. No Surrender of Others' Freedom.
531
+
532
+ If conditions are imposed on you (whether by court order, agreement or
533
+ otherwise) that contradict the conditions of this License, they do not
534
+ excuse you from the conditions of this License. If you cannot convey a
535
+ covered work so as to satisfy simultaneously your obligations under this
536
+ License and any other pertinent obligations, then as a consequence you may
537
+ not convey it at all. For example, if you agree to terms that obligate you
538
+ to collect a royalty for further conveying from those to whom you convey
539
+ the Program, the only way you could satisfy both those terms and this
540
+ License would be to refrain entirely from conveying the Program.
541
+
542
+ 13. Remote Network Interaction; Use with the GNU General Public License.
543
+
544
+ Notwithstanding any other provision of this License, if you modify the
545
+ Program, your modified version must prominently offer all users
546
+ interacting with it remotely through a computer network (if your version
547
+ supports such interaction) an opportunity to receive the Corresponding
548
+ Source of your version by providing access to the Corresponding Source
549
+ from a network server at no charge, through some standard or customary
550
+ means of facilitating copying of software. This Corresponding Source
551
+ shall include the Corresponding Source for any work covered by version 3
552
+ of the GNU General Public License that is incorporated pursuant to the
553
+ following paragraph.
554
+
555
+ Notwithstanding any other provision of this License, you have
556
+ permission to link or combine any covered work with a work licensed
557
+ under version 3 of the GNU General Public License into a single
558
+ combined work, and to convey the resulting work. The terms of this
559
+ License will continue to apply to the part which is the covered work,
560
+ but the work with which it is combined will remain governed by version
561
+ 3 of the GNU General Public License.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU Affero General Public License from time to time. Such new versions
567
+ will be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU Affero General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU Affero General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU Affero General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU Affero General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU Affero General Public License for more details.
646
+
647
+ You should have received a copy of the GNU Affero General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If your software can interact with users remotely through a computer
653
+ network, you should also make sure that it provides a way for users to
654
+ get its source. For example, if your program is a web application, its
655
+ interface could display a "Source" link that leads users to an archive
656
+ of the code. There are many ways you could offer source, and different
657
+ solutions will be better for different programs; see section 13 for the
658
+ specific requirements.
659
+
660
+ You should also get your employer (if you work as a programmer) or school,
661
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
662
+ For more information on this, and how to apply and follow the GNU AGPL, see
663
+ <https://www.gnu.org/licenses/>.
app.py ADDED
@@ -0,0 +1,306 @@
1
+ # AGPL: a notification must be added stating that changes have been made to that file.
2
+
3
+ import os
4
+ import shutil
5
+ from pathlib import Path
6
+
7
+ import streamlit as st
8
+ from random import randint
9
+
10
+ from tortoise.api import MODELS_DIR
11
+ from tortoise.inference import (
12
+ infer_on_texts,
13
+ run_and_save_tts,
14
+ split_and_recombine_text,
15
+ )
16
+ from tortoise.utils.diffusion import SAMPLERS
17
+ from app_utils.filepicker import st_file_selector
18
+ from app_utils.conf import TortoiseConfig
19
+
20
+ from app_utils.funcs import (
21
+ timeit,
22
+ load_model,
23
+ list_voices,
24
+ load_voice_conditionings,
25
+ )
26
+
27
+
28
+ LATENT_MODES = [
29
+ "Tortoise original (bad)",
30
+ "average per 4.27s (broken on small files)",
31
+ "average per voice file (broken on small files)",
32
+ ]
33
+
34
+ def main():
35
+ conf = TortoiseConfig()
36
+
37
+ with st.expander("Create New Voice", expanded=True):
38
+ if "file_uploader_key" not in st.session_state:
39
+ st.session_state["file_uploader_key"] = str(randint(1000, 100000000))
40
+ st.session_state["text_input_key"] = str(randint(1000, 100000000))
41
+
42
+ uploaded_files = st.file_uploader(
43
+ "Upload Audio Samples for a New Voice",
44
+ accept_multiple_files=True,
45
+ type=["wav"],
46
+ key=st.session_state["file_uploader_key"]
47
+ )
48
+
49
+ voice_name = st.text_input(
50
+ "New Voice Name",
51
+ help="Enter a name for your new voice.",
52
+ value="",
53
+ key=st.session_state["text_input_key"]
54
+ )
55
+
56
+ create_voice_button = st.button(
57
+ "Create Voice",
58
+ disabled = ((voice_name.strip() == "") | (len(uploaded_files) == 0))
59
+ )
60
+ if create_voice_button:
61
+ st.write(st.session_state)
62
+ with st.spinner(f"Creating new voice: {voice_name}"):
63
+ new_voice_name = voice_name.strip().replace(" ", "_")
64
+
65
+ voices_dir = f'./tortoise/voices/{new_voice_name}/'
66
+ if os.path.exists(voices_dir):
67
+ shutil.rmtree(voices_dir)
68
+ os.makedirs(voices_dir)
69
+
70
+ for index, uploaded_file in enumerate(uploaded_files):
71
+ bytes_data = uploaded_file.read()
72
+ with open(f"{voices_dir}voice_sample{index}.wav", "wb") as wav_file:
73
+ wav_file.write(bytes_data)
74
+
75
+ st.session_state["text_input_key"] = str(randint(1000, 100000000))
76
+ st.session_state["file_uploader_key"] = str(randint(1000, 100000000))
77
+ st.experimental_rerun()
78
+
79
+ text = st.text_area(
80
+ "Text",
81
+ help="Text to speak.",
82
+ value="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.",
83
+ )
84
+
85
+ voices = [v for v in os.listdir("tortoise/voices") if v != "cond_latent_example"]
86
+
87
+ voice = st.selectbox(
88
+ "Voice",
89
+ voices,
90
+ help="Selects the voice to use for generation. See options in voices/ directory (and add your own!) "
91
+ "Use the & character to join two voices together. Use a comma to perform inference on multiple voices.",
92
+ index=0,
93
+ )
94
+ preset = st.selectbox(
95
+ "Preset",
96
+ (
97
+ "single_sample",
98
+ "ultra_fast",
99
+ "very_fast",
100
+ "ultra_fast_old",
101
+ "fast",
102
+ "standard",
103
+ "high_quality",
104
+ ),
105
+ help="Which voice preset to use.",
106
+ index=1,
107
+ )
108
+ with st.expander("Advanced"):
109
+ col1, col2 = st.columns(2)
110
+ with col1:
111
+ """#### Model parameters"""
112
+ candidates = st.number_input(
113
+ "Candidates",
114
+ help="How many output candidates to produce per-voice.",
115
+ value=1,
116
+ )
117
+ latent_averaging_mode = st.radio(
118
+ "Latent averaging mode",
119
+ LATENT_MODES,
120
+ help="How voice samples should be averaged together.",
121
+ index=0,
122
+ )
123
+ sampler = st.radio(
124
+ "Sampler",
125
+ #SAMPLERS,
126
+ ["dpm++2m", "p", "ddim"],
127
+ help="Diffusion sampler. Note that dpm++2m is experimental and typically requires more steps.",
128
+ index=1,
129
+ )
130
+ steps = st.number_input(
131
+ "Steps",
132
+ help="Override the steps used for diffusion (default depends on preset)",
133
+ value=10,
134
+ )
135
+ seed = st.number_input(
136
+ "Seed",
137
+ help="Random seed which can be used to reproduce results.",
138
+ value=-1,
139
+ )
140
+ if seed == -1:
141
+ seed = None
142
+ voice_fixer = st.checkbox(
143
+ "Voice fixer",
144
+ help="Use `voicefixer` to improve audio quality. This is a post-processing step which can be applied to any output.",
145
+ value=True,
146
+ )
147
+ """#### Directories"""
148
+ output_path = st.text_input(
149
+ "Output Path", help="Where to store outputs.", value="results/"
150
+ )
151
+
152
+ with col2:
153
+ """#### Optimizations"""
154
+ high_vram = not st.checkbox(
155
+ "Low VRAM",
156
+ help="Re-enable default offloading behaviour of tortoise",
157
+ value=True,
158
+ )
159
+ half = st.checkbox(
160
+ "Half-Precision",
161
+ help="Enable autocast to half precision for autoregressive model",
162
+ value=False,
163
+ )
164
+ kv_cache = st.checkbox(
165
+ "Key-Value Cache",
166
+ help="Enable kv_cache usage, leading to drastic speedups but worse memory usage",
167
+ value=True,
168
+ )
169
+ cond_free = st.checkbox(
170
+ "Conditioning Free",
171
+ help="Force conditioning free diffusion",
172
+ value=True,
173
+ )
174
+ no_cond_free = st.checkbox(
175
+ "Force Not Conditioning Free",
176
+ help="Force disable conditioning free diffusion",
177
+ value=False,
178
+ )
179
+
180
+ """#### Text Splitting"""
181
+ min_chars_to_split = st.number_input(
182
+ "Min Chars to Split",
183
+ help="Minimum number of characters to split text on",
184
+ min_value=50,
185
+ value=200,
186
+ step=1,
187
+ )
188
+
189
+ """#### Debug"""
190
+ produce_debug_state = st.checkbox(
191
+ "Produce Debug State",
192
+ help="Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.",
193
+ value=True,
194
+ )
195
+
196
+ ar_checkpoint = "."
197
+ diff_checkpoint = "."
198
+ if st.button("Update Basic Settings"):
199
+ conf.update(
200
+ EXTRA_VOICES_DIR=conf.EXTRA_VOICES_DIR,  # keep the stored value; the extra_voices_dir input is not defined in this file
201
+ LOW_VRAM=not high_vram,
202
+ AR_CHECKPOINT=ar_checkpoint,
203
+ DIFF_CHECKPOINT=diff_checkpoint,
204
+ )
205
+
206
+ ar_checkpoint = None
207
+ diff_checkpoint = None
208
+ tts = load_model(MODELS_DIR, high_vram, kv_cache, ar_checkpoint, diff_checkpoint)
209
+
210
+ if st.button("Start"):
211
+ assert latent_averaging_mode
212
+ assert preset
213
+ assert voice
214
+
215
+ def show_generation(fp, filename: str):
216
+ """
217
+ audio_buffer = BytesIO()
218
+ save_gen_with_voicefix(g, audio_buffer, squeeze=False)
219
+ torchaudio.save(audio_buffer, g, 24000, format='wav')
220
+ """
221
+ st.audio(str(fp), format="audio/wav")
222
+ st.download_button(
223
+ "Download sample",
224
+ str(fp),
225
+ file_name=filename, # this doesn't actually seem to work lol
226
+ )
227
+
228
+ with st.spinner(
229
+ f"Generating {candidates} candidates for voice {voice} (seed={seed}). You can see progress in the terminal"
230
+ ):
231
+ os.makedirs(output_path, exist_ok=True)
232
+
233
+ selected_voices = voice.split(",")
234
+ for k, selected_voice in enumerate(selected_voices):
235
+ if "&" in selected_voice:
236
+ voice_sel = selected_voice.split("&")
237
+ else:
238
+ voice_sel = [selected_voice]
239
+ voice_samples, conditioning_latents = load_voice_conditionings(
240
+ voice_sel, []
241
+ )
242
+
243
+ voice_path = Path(os.path.join(output_path, selected_voice))
244
+
245
+ with timeit(
246
+ f"Generating {candidates} candidates for voice {selected_voice} (seed={seed})"
247
+ ):
248
+ nullable_kwargs = {
249
+ k: v
250
+ for k, v in zip(
251
+ ["sampler", "diffusion_iterations", "cond_free"],
252
+ [sampler, steps, cond_free],
253
+ )
254
+ if v is not None
255
+ }
256
+
257
+ def call_tts(text: str):
258
+ return tts.tts_with_preset(
259
+ text,
260
+ k=candidates,
261
+ voice_samples=voice_samples,
262
+ conditioning_latents=conditioning_latents,
263
+ preset=preset,
264
+ use_deterministic_seed=seed,
265
+ return_deterministic_state=True,
266
+ cvvp_amount=0.0,
267
+ half=half,
268
+ latent_averaging_mode=LATENT_MODES.index(
269
+ latent_averaging_mode
270
+ ),
271
+ **nullable_kwargs,
272
+ )
273
+
274
+ if len(text) < min_chars_to_split:
275
+ filepaths = run_and_save_tts(
276
+ call_tts,
277
+ text,
278
+ voice_path,
279
+ return_deterministic_state=True,
280
+ return_filepaths=True,
281
+ voicefixer=voice_fixer,
282
+ )
283
+ for i, fp in enumerate(filepaths):
284
+ show_generation(fp, f"{selected_voice}-text-{i}.wav")
285
+ else:
286
+ desired_length = int(min_chars_to_split)
287
+ texts = split_and_recombine_text(
288
+ text, desired_length, desired_length + 100
289
+ )
290
+ filepaths = infer_on_texts(
291
+ call_tts,
292
+ texts,
293
+ voice_path,
294
+ return_deterministic_state=True,
295
+ return_filepaths=True,
296
+ lines_to_regen=set(range(len(texts))),
297
+ voicefixer=voice_fixer,
298
+ )
299
+ for i, fp in enumerate(filepaths):
300
+ show_generation(fp, f"{selected_voice}-text-{i}.wav")
301
+ if produce_debug_state:
302
+ """Debug states can be found in the output directory"""
303
+
304
+
305
+ if __name__ == "__main__":
306
+ main()
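
The app is a single-page Streamlit script, so (with the dependencies from requirements.txt installed) it is normally launched with `streamlit run app.py`. The "Create New Voice" expander above simply writes the uploaded WAVs into `tortoise/voices/<name>/`, which the Voice selectbox picks up on the next rerun. A minimal sketch of doing the same outside the UI (the function name and sample paths are illustrative; the directory layout mirrors the logic above):

```python
# Register a new voice the same way the "Create Voice" button does:
# copy WAV samples into tortoise/voices/<name>/ so the Voice selectbox can find them.
import os
import shutil


def add_voice(name: str, sample_paths: list) -> str:
    voice_dir = os.path.join("tortoise", "voices", name.strip().replace(" ", "_"))
    if os.path.exists(voice_dir):
        shutil.rmtree(voice_dir)  # app.py overwrites an existing voice of the same name
    os.makedirs(voice_dir)
    for index, src in enumerate(sample_paths):
        shutil.copy(src, os.path.join(voice_dir, f"voice_sample{index}.wav"))
    return voice_dir


# add_voice("my_speaker", ["clip1.wav", "clip2.wav"])
```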
app_utils/conf.py ADDED
@@ -0,0 +1,40 @@
1
+ import shelve
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
+ class PersistentSettings(BaseModel):
9
+ """
10
+ This pydantic model will try to initialize itself from
11
+ the database upon every instantiation
12
+
13
+ It further supplies an update function, that allows to write
14
+ back any changes into the database, under its key.
15
+ """
16
+
17
+ def __init__(self, **data: Any):
18
+ with shelve.open("config.db") as db:
19
+ super().__init__(**db.get("settings", default={}), **data)
20
+
21
+ def update(self, **data: Any) -> None:
22
+ """
23
+ Persist the pydantic-dict that represents the model
24
+ """
25
+ with shelve.open("config.db") as db:
26
+ db["settings"] = {**self.dict(), **data}
27
+
28
+
29
+ class TortoiseConfig(PersistentSettings):
30
+ EXTRA_VOICES_DIR: str = ""
31
+ AR_CHECKPOINT: str = "."
32
+ DIFF_CHECKPOINT: str = "."
33
+ LOW_VRAM: bool = True
34
+
35
+ def __init__(self, **data: Any):
36
+ super().__init__(**data)
37
+ if not Path(self.AR_CHECKPOINT).is_file():
38
+ self.AR_CHECKPOINT = "."
39
+ if not Path(self.DIFF_CHECKPOINT).is_file():
40
+ self.DIFF_CHECKPOINT = "."
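
`PersistentSettings` re-reads `config.db` on every instantiation, and `update()` writes the merged dict back under the `"settings"` key, so values survive Streamlit reruns and restarts. A minimal sketch of the intended round-trip (assuming the module layout in this commit):

```python
from app_utils.conf import TortoiseConfig

conf = TortoiseConfig()          # pulls any previously saved values out of config.db
print(conf.LOW_VRAM)             # True on a fresh install (the default above)
conf.update(LOW_VRAM=False)      # persists {**conf.dict(), "LOW_VRAM": False} to the shelf

conf = TortoiseConfig()          # a new instance now sees the persisted value
assert conf.LOW_VRAM is False
```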
app_utils/filepicker.py ADDED
@@ -0,0 +1,71 @@
1
+ # taken from https://gist.github.com/benlansdell/44000c264d1b373c77497c0ea73f0ef2
2
+ # slightly modified
3
+ """FilePicker for streamlit.
4
+
5
+ There still doesn't seem to be a good built-in way to select files to process from the server Streamlit is running on.
6
+
7
+ Here's a pretty functional solution.
8
+
9
+ Usage:
10
+
11
+ ```
12
+ import streamlit as st
13
+ from filepicker import st_file_selector
14
+
15
+ tif_file = st_file_selector(st, key = 'tif', label = 'Choose tif file')
16
+ ```
17
+ """
18
+
19
+ import os
20
+
21
+ import streamlit as st
22
+
23
+
24
+ def update_dir(key):
25
+ global i_will_regret_this, i_will_regret_this2
26
+ choice = st.session_state[key]
27
+ if os.path.isdir(os.path.join(st.session_state[key + "curr_dir"], choice)):
28
+ st.session_state[key + "index"] = 0
29
+ st.session_state[key + "curr_dir"] = os.path.normpath(
30
+ os.path.join(st.session_state[key + "curr_dir"], choice)
31
+ )
32
+ files = sorted(os.listdir(st.session_state[key + "curr_dir"]))
33
+ files.insert(0, "..")
34
+ files.insert(0, ".")
35
+ st.session_state[key + "files"] = files
36
+
37
+
38
+ def st_file_selector(
39
+ st_placeholder, path=".", label="Select a file/folder", key="selected"
40
+ ):
41
+ if key + "curr_dir" not in st.session_state:
42
+ base_path = "." if path is None or path == "" else path
43
+ base_path = (
44
+ base_path if os.path.isdir(base_path) else os.path.dirname(base_path)
45
+ )
46
+ base_path = "." if base_path is None or base_path == "" else base_path
47
+
48
+ files = sorted(os.listdir(base_path))
49
+ files.insert(0, "..")
50
+ files.insert(0, ".")
51
+ st.session_state[key + "files"] = files
52
+ st.session_state[key + "curr_dir"] = base_path
53
+ st.session_state[key + "index"] = (
54
+ st.session_state[key + "files"].index(os.path.basename(path))
55
+ if os.path.isfile(path) and path[-4:] == '.pth'
56
+ else 0
57
+ )
58
+ else:
59
+ base_path = st.session_state[key + "curr_dir"]
60
+
61
+ selected_file = st_placeholder.selectbox(
62
+ label=label,
63
+ options=st.session_state[key + "files"],
64
+ index=st.session_state[key + "index"],
65
+ key=key,
66
+ on_change=lambda: update_dir(key),
67
+ )
68
+ selected_path = os.path.normpath(os.path.join(base_path, selected_file))
69
+ st_placeholder.write(os.path.abspath(selected_path))
70
+
71
+ return selected_path
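
For context, the picker is meant to be dropped into a Streamlit page wherever a server-side file needs to be chosen; it keeps the current directory and file list in `st.session_state` under the given key and returns the selected path. A small usage sketch (the key and label are illustrative; in this app it would typically be used to pick a `.pth` checkpoint):

```python
import streamlit as st
from app_utils.filepicker import st_file_selector

ar_checkpoint = st_file_selector(
    st, path=".", label="Choose an autoregressive checkpoint (.pth)", key="ar_ckpt"
)
st.write(f"Selected: {ar_checkpoint}")
```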
app_utils/funcs.py ADDED
@@ -0,0 +1,57 @@
1
+ import gc
2
+ import os
3
+ from contextlib import contextmanager
4
+ from time import time
5
+ from typing import Optional
6
+
7
+ import streamlit as st
8
+
9
+ from tortoise.api import TextToSpeech
10
+ from tortoise.utils.audio import load_voices
11
+
12
+
13
+ @contextmanager
14
+ def timeit(desc=""):
15
+ start = time()
16
+ yield
17
+ print(f"{desc} took {time() - start:.2f} seconds")
18
+
19
+
20
+ @st.cache_resource(max_entries=1)
21
+ def load_model(
22
+ model_dir,
23
+ high_vram,
24
+ kv_cache,
25
+ ar_checkpoint,
26
+ diff_checkpoint,
27
+ ):
28
+ gc.collect()
29
+ return TextToSpeech(
30
+ models_dir=model_dir,
31
+ high_vram=high_vram,
32
+ kv_cache=kv_cache,
33
+ ar_checkpoint=ar_checkpoint,
34
+ diff_checkpoint=diff_checkpoint,
35
+ )
36
+
37
+
38
+ @st.cache_data
39
+ def list_voices(extra_voices_dir: Optional[str]):
40
+ voices = ["random"]
41
+ if extra_voices_dir and os.path.isdir(extra_voices_dir):
42
+ voices.extend(os.listdir(extra_voices_dir))
43
+ extra_voices_ls = [extra_voices_dir]
44
+ else:
45
+ extra_voices_ls = []
46
+ voices.extend(
47
+ [v for v in os.listdir("tortoise/voices") if v != "cond_latent_example"]
48
+ )
49
+ #
50
+ return voices, extra_voices_ls
51
+
52
+
53
+ @st.cache_resource(max_entries=1)
54
+ def load_voice_conditionings(voice, extra_voices_ls):
55
+ gc.collect()
56
+ voice_samples, conditioning_latents = load_voices(voice, extra_voices_ls)
57
+ return voice_samples, conditioning_latents
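
These helpers are what app.py calls on every rerun; the `st.cache_resource`/`st.cache_data` decorators ensure the heavyweight model and the conditioning latents are only built once per argument combination. A sketch of the expected call sequence (argument values and the voice name are illustrative):

```python
from tortoise.api import MODELS_DIR
from app_utils.funcs import load_model, list_voices, load_voice_conditionings

# high_vram=True, kv_cache=True, no custom AR/diffusion checkpoints
tts = load_model(MODELS_DIR, True, True, None, None)
voices, extra_dirs = list_voices(None)        # ["random", ...voices under tortoise/voices...]
samples, latents = load_voice_conditionings(["emma"], extra_dirs)
```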
requirements.txt ADDED
The diff for this file is too large to render.
 
tortoise/__init__.py ADDED
File without changes
tortoise/api.py ADDED
@@ -0,0 +1,827 @@
1
+ # ## AGPL: a notification must be added stating that changes have been made to that file.
2
+
3
+ import os
4
+ import random
5
+ from time import time
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchaudio
10
+ from tqdm import tqdm
11
+
12
+ from tortoise.models.arch_util import TorchMelSpectrogram
13
+ from tortoise.models.autoregressive import UnifiedVoice
14
+ from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead
15
+ from tortoise.models.clvp import CLVP
16
+ from tortoise.models.cvvp import CVVP
17
+ from tortoise.models.diffusion_decoder import DiffusionTts
18
+ from tortoise.models.random_latent_generator import RandomLatentConverter
19
+ from tortoise.models.vocoder import VocConf
20
+ from tortoise.utils.audio import denormalize_tacotron_mel, wav_to_univnet_mel
21
+ from tortoise.utils.diffusion import (
22
+ SpacedDiffusion,
23
+ get_named_beta_schedule,
24
+ space_timesteps,
25
+ )
26
+ from tortoise.utils.tokenizer import VoiceBpeTokenizer
27
+ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
28
+
29
+ from tortoise.models.utils import MODELS_DIR, get_model_path
30
+
31
+ from contextlib import contextmanager
32
+
33
+ def pad_or_truncate(t, length):
34
+ """
35
+ Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
36
+ """
37
+ if t.shape[-1] == length:
38
+ return t
39
+ elif t.shape[-1] < length:
40
+ return F.pad(t, (0, length - t.shape[-1]))
41
+ else:
42
+ return t[..., :length]
43
+
44
+
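
For intuition, a tiny illustrative use of `pad_or_truncate` (shapes chosen arbitrarily):

```python
import torch

t = torch.ones(1, 5)
padded = pad_or_truncate(t, 8)     # shape (1, 8); the last 3 entries are zero-padded
truncated = pad_or_truncate(t, 3)  # shape (1, 3); clipped to the requested length
```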
45
+ def load_discrete_vocoder_diffuser(
46
+ trained_diffusion_steps=4000,
47
+ desired_diffusion_steps=200,
48
+ cond_free=True,
49
+ cond_free_k=1,
50
+ sampler="ddim",
51
+ ):
52
+ """
53
+ Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
54
+ """
55
+ return SpacedDiffusion(
56
+ use_timesteps=space_timesteps(
57
+ trained_diffusion_steps, [desired_diffusion_steps]
58
+ ),
59
+ model_mean_type="epsilon",
60
+ model_var_type="learned_range",
61
+ loss_type="mse",
62
+ betas=get_named_beta_schedule("linear", trained_diffusion_steps),
63
+ conditioning_free=cond_free,
64
+ conditioning_free_k=cond_free_k,
65
+ sampler=sampler,
66
+ )
67
+
68
+
69
+ def format_conditioning(clip, cond_length=132300, device="cuda"):
70
+ """
71
+ Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
72
+ """
73
+ gap = clip.shape[-1] - cond_length
74
+ if gap < 0:
75
+ clip = F.pad(clip, pad=(0, abs(gap)))
76
+ elif gap > 0:
77
+ rand_start = random.randint(0, gap)
78
+ clip = clip[:, rand_start : rand_start + cond_length]
79
+ mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
80
+ return mel_clip.unsqueeze(0).to(device)
81
+
82
+
83
+ def fix_autoregressive_output(codes, stop_token, complain=True):
84
+ """
85
+ This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
86
+ trained on and what the autoregressive code generator creates (which has no padding or end).
87
+ This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
88
+ a different DVAE. This can be inferred by feeding an audio clip padded with lots of zeros on the end through the DVAE
89
+ and copying out the last few codes.
90
+
91
+ Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
92
+ """
93
+ # Strip off the autoregressive stop token and add padding.
94
+ stop_token_indices = (codes == stop_token).nonzero()
95
+ if len(stop_token_indices) == 0:
96
+ if complain:
97
+ print(
98
+ "No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
99
+ "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
100
+ "try breaking up your input text."
101
+ )
102
+ return codes
103
+ else:
104
+ codes[stop_token_indices] = 83
105
+ stm = stop_token_indices.min().item()
106
+ codes[stm:] = 83
107
+ if stm - 3 < codes.shape[0]:
108
+ codes[-3] = 45
109
+ codes[-2] = 45
110
+ codes[-1] = 248
111
+
112
+ return codes
113
+
114
+
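
A small illustrative run of the fixup above; the stop-token value is a placeholder (in real use it comes from `autoregressive.stop_mel_token`), while 83, 45, 45, 248 are the hard-coded values from this function:

```python
import torch

STOP = 8192  # placeholder value for illustration only
codes = torch.tensor([12, 7, 99, STOP, 0, 0, 0, 0])
fixed = fix_autoregressive_output(codes, STOP)
# -> tensor([12,  7, 99, 83, 83, 45, 45, 248])
#    everything from the first stop token onward becomes the calm token (83),
#    then the final three positions are overwritten with 45, 45, 248.
```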
115
+ def do_spectrogram_diffusion(
116
+ diffusion_model,
117
+ diffuser,
118
+ latents,
119
+ conditioning_latents,
120
+ temperature=1,
121
+ verbose=True,
122
+ ):
123
+ """
124
+ Uses the specified diffusion model to convert discrete codes into a spectrogram.
125
+ """
126
+ with torch.no_grad():
127
+ output_seq_len = (
128
+ latents.shape[1] * 4 * 24000 // 22050
129
+ ) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
130
+ output_shape = (latents.shape[0], 100, output_seq_len)
131
+ precomputed_embeddings = diffusion_model.timestep_independent(
132
+ latents, conditioning_latents, output_seq_len, False
133
+ )
134
+
135
+ noise = torch.randn(output_shape, device=latents.device) * temperature
136
+ mel = diffuser.sample_loop(
137
+ diffusion_model,
138
+ output_shape,
139
+ noise=noise,
140
+ model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
141
+ progress=verbose,
142
+ )
143
+ return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
144
+
145
+
146
+ def classify_audio_clip(clip):
147
+ """
148
+ Returns whether or not Tortoise's classifier thinks the given clip came from Tortoise.
149
+ :param clip: torch tensor containing audio waveform data (get it from load_audio)
150
+ :return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
151
+ """
152
+ classifier = AudioMiniEncoderWithClassifierHead(
153
+ 2,
154
+ spec_dim=1,
155
+ embedding_dim=512,
156
+ depth=5,
157
+ downsample_factor=4,
158
+ resnet_blocks=2,
159
+ attn_blocks=4,
160
+ num_attn_heads=4,
161
+ base_channels=32,
162
+ dropout=0,
163
+ kernel_size=5,
164
+ distribute_zero_label=False,
165
+ )
166
+ classifier.load_state_dict(
167
+ torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu"))
168
+ )
169
+ clip = clip.cpu().unsqueeze(0)
170
+ results = F.softmax(classifier(clip), dim=-1)
171
+ return results[0][0]
172
+
173
+
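
A hedged usage sketch for the classifier above; the file path is a placeholder, and `load_audio` is the helper referenced in the docstring:

```python
from tortoise.utils.audio import load_audio

clip = load_audio("some_clip.wav", 24000)   # placeholder path
clip = clip[:, :220000]                      # keep the clip short-ish (assumed, not required)
prob_tortoise = classify_audio_clip(clip)
print(prob_tortoise.item())                  # probability the clip was generated by Tortoise
```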
174
+ def pick_best_batch_size_for_gpu():
175
+ """
176
+ Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
177
+ you a good shot.
178
+ """
179
+ if torch.cuda.is_available():
180
+ _, available = torch.cuda.mem_get_info()
181
+ availableGb = available / (1024**3)
182
+ if availableGb > 14:
183
+ return 16
184
+ elif availableGb > 10:
185
+ return 8
186
+ elif availableGb > 7:
187
+ return 4
188
+ return 1
189
+
190
+
191
+ class TextToSpeech:
192
+ """
193
+ Main entry point into Tortoise.
194
+ """
195
+
196
+ def _config(self):
197
+ raise RuntimeError("This is deprecated")
198
+ return {
199
+ "high_vram": self.high_vram,
200
+ "models_dir": self.models_dir,
201
+ "kv_cache": self.autoregressive.inference_model.kv_cache,
202
+ "ar_checkpoint": self.ar_checkpoint,
203
+ }
204
+
205
+ def __init__(
206
+ self,
207
+ autoregressive_batch_size=None,
208
+ models_dir=MODELS_DIR,
209
+ enable_redaction=True,
210
+ device=None,
211
+ high_vram=False,
212
+ kv_cache=True,
213
+ ar_checkpoint=None,
214
+ clvp_checkpoint=None,
215
+ diff_checkpoint=None,
216
+ vocoder=VocConf.Univnet,
217
+ ):
218
+ """
219
+ Constructor
220
+ :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
221
+ GPU OOM errors. Larger numbers generate slightly faster.
222
+ :param models_dir: Where model weights are stored. This should only be specified if you are providing your own
223
+ models, otherwise use the defaults.
224
+ :param enable_redaction: When true, text enclosed in brackets is automatically redacted from the spoken output
225
+ (but are still rendered by the model). This can be used for prompt engineering.
226
+ Default is true.
227
+ :param device: Device to use when running the model. If omitted, the device will be automatically chosen.
228
+ :param high_vram: If true, the model will use more VRAM but will run faster.
229
+ :param kv_cache: If true, the autoregressive model will cache key value attention pairs to speed up generation.
230
+ :param ar_checkpoint: Path to a checkpoint file for the autoregressive model. If omitted, uses default
231
+ :param clvp_checkpoint: Path to a checkpoint file for the CLVP model. If omitted, uses default
232
+ :param diff_checkpoint: Path to a checkpoint file for the diffusion model. If omitted, uses default
233
+ """
234
+ self.ar_checkpoint = ar_checkpoint
235
+ self.diff_checkpoint = diff_checkpoint # TODO: check if this is even needed
236
+ self.models_dir = models_dir
237
+ self.autoregressive_batch_size = (
238
+ pick_best_batch_size_for_gpu()
239
+ if autoregressive_batch_size is None
240
+ else autoregressive_batch_size
241
+ )
242
+ self.enable_redaction = enable_redaction
243
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
244
+ if self.enable_redaction:
245
+ self.aligner = Wav2VecAlignment()
246
+
247
+ self.tokenizer = VoiceBpeTokenizer()
248
+
249
+ if os.path.exists(f"{models_dir}/autoregressive.ptt"):
250
+ # Assume this is a traced directory.
251
+ self.autoregressive = torch.jit.load(f"{models_dir}/autoregressive.ptt")
252
+ self.diffusion = torch.jit.load(f"{models_dir}/diffusion_decoder.ptt")
253
+ else:
254
+ self.autoregressive = (
255
+ UnifiedVoice(
256
+ max_mel_tokens=604,
257
+ max_text_tokens=402,
258
+ max_conditioning_inputs=2,
259
+ layers=30,
260
+ model_dim=1024,
261
+ heads=16,
262
+ number_text_tokens=255,
263
+ start_text_token=255,
264
+ checkpointing=False,
265
+ train_solo_embeddings=False,
266
+ )
267
+ .cpu()
268
+ .eval()
269
+ )
270
+ ar_path = ar_checkpoint or get_model_path("autoregressive.pth", models_dir)
271
+ self.autoregressive.load_state_dict(torch.load(ar_path))
272
+ self.autoregressive.post_init_gpt2_config(kv_cache)
273
+
274
+ diff_path = diff_checkpoint or get_model_path(
275
+ "diffusion_decoder.pth", models_dir
276
+ )
277
+ self.diffusion = (
278
+ DiffusionTts(
279
+ model_channels=1024,
280
+ num_layers=10,
281
+ in_channels=100,
282
+ out_channels=200,
283
+ in_latent_channels=1024,
284
+ in_tokens=8193,
285
+ dropout=0,
286
+ use_fp16=False,
287
+ num_heads=16,
288
+ layer_drop=0,
289
+ unconditioned_percentage=0,
290
+ )
291
+ .cpu()
292
+ .eval()
293
+ )
294
+ self.diffusion.load_state_dict(torch.load(diff_path))
295
+
296
+ self.clvp = (
297
+ CLVP(
298
+ dim_text=768,
299
+ dim_speech=768,
300
+ dim_latent=768,
301
+ num_text_tokens=256,
302
+ text_enc_depth=20,
303
+ text_seq_len=350,
304
+ text_heads=12,
305
+ num_speech_tokens=8192,
306
+ speech_enc_depth=20,
307
+ speech_heads=12,
308
+ speech_seq_len=430,
309
+ use_xformers=True,
310
+ )
311
+ .cpu()
312
+ .eval()
313
+ )
314
+ clvp_path = clvp_checkpoint or get_model_path("clvp2.pth", models_dir)
315
+ self.clvp.load_state_dict(torch.load(clvp_path))
316
+ self.cvvp = None # CVVP model is only loaded if used.
317
+
318
+ self.vocoder = vocoder.value.constructor().cpu()
319
+ self.vocoder.load_state_dict(
320
+ vocoder.value.optionally_index(
321
+ torch.load(
322
+ get_model_path(vocoder.value.model_path, models_dir),
323
+ map_location=torch.device("cpu"),
324
+ )
325
+ )
326
+ )
327
+ self.vocoder.eval(inference=True)
328
+
329
+ # Random latent generators (RLGs) are loaded lazily.
330
+ self.rlg_auto = None
331
+ self.rlg_diffusion = None
332
+
333
+ if high_vram:
334
+ self.autoregressive = self.autoregressive.to(self.device)
335
+ self.diffusion = self.diffusion.to(self.device)
336
+ self.clvp = self.clvp.to(self.device)
337
+ self.vocoder = self.vocoder.to(self.device)
338
+ self.high_vram = high_vram
339
+
340
+ @contextmanager
341
+ def temporary_cuda(self, model):
342
+ if self.high_vram:
343
+ yield model
344
+ else:
345
+ m = model.to(self.device)
346
+ yield m
347
+ m = model.cpu()
348
+
349
+ def load_cvvp(self):
350
+ """Load CVVP model."""
351
+ self.cvvp = (
352
+ CVVP(
353
+ model_dim=512,
354
+ transformer_heads=8,
355
+ dropout=0,
356
+ mel_codes=8192,
357
+ conditioning_enc_depth=8,
358
+ cond_mask_percentage=0,
359
+ speech_enc_depth=8,
360
+ speech_mask_percentage=0,
361
+ latent_multiplier=1,
362
+ )
363
+ .cpu()
364
+ .eval()
365
+ )
366
+ self.cvvp.load_state_dict(
367
+ torch.load(get_model_path("cvvp.pth", self.models_dir))
368
+ )
369
+
370
+ def get_conditioning_latents(self, voice_samples, return_mels=False, latent_averaging_mode=0, original_tortoise=False):
371
+ """
372
+ Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
373
+ These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
374
+ properties.
375
+ :param voice_samples: List of arbitrary reference clips, which should be *pairs* of torch tensors containing arbitrary kHz waveform data.
376
+ :param latent_averaging_mode: 0/1/2 for following modes:
377
+ 0 - latents will be generated as in the original tortoise, using ~4.27s from each voice sample, averaging latent across all samples
378
+ 1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks
379
+ 2 - latents will be generated using (almost) entire voice samples, averaged per voice sample
380
+ """
381
+ assert latent_averaging_mode in [0, 1, 2], "latent_averaging mode has to be one of (0, 1, 2)"
382
+ print("mode", latent_averaging_mode)
383
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
384
+
385
+ with torch.no_grad():
386
+ # Move the entire nested structure to the device
387
+ voice_samples = [
388
+ (pair[0].to(device), pair[1].to(device))
389
+ for pair in voice_samples
390
+ ]
391
+
392
+ auto_conds = []
393
+ for ls in voice_samples:
394
+ auto_conds.append(format_conditioning(ls[0], device=device)) # Use device here
395
+ auto_conds = torch.stack(auto_conds, dim=1)
396
+ with self.temporary_cuda(self.autoregressive) as ar:
397
+ auto_latent = ar.get_conditioning(auto_conds)
398
+
399
+ diffusion_conds = []
400
+
401
+ DURS_CONST = 102400
402
+ for ls in voice_samples:
403
+ # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
404
+ sample = (
405
+ torchaudio.functional.resample(ls[0], 22050, 24000)
406
+ if original_tortoise
407
+ else ls[1]
408
+ )
409
+ if latent_averaging_mode == 0:
410
+ sample = pad_or_truncate(sample, DURS_CONST)
411
+ cond_mel = wav_to_univnet_mel(
412
+ sample.to(device), # Use device here
413
+ do_normalization=False,
414
+ device=device,
415
+ )
416
+ diffusion_conds.append(cond_mel)
417
+ else:
418
+ from math import ceil
419
+
420
+ if latent_averaging_mode == 2:
421
+ temp_diffusion_conds = []
422
+ for chunk in range(ceil(sample.shape[1] / DURS_CONST)):
423
+ current_sample = sample[
424
+ :, chunk * DURS_CONST : (chunk + 1) * DURS_CONST
425
+ ]
426
+ current_sample = pad_or_truncate(current_sample, DURS_CONST)
427
+ cond_mel = wav_to_univnet_mel(
428
+ current_sample.to(device), # Use device here
429
+ do_normalization=False,
430
+ device=device,
431
+ )
432
+ if latent_averaging_mode == 1:
433
+ diffusion_conds.append(cond_mel)
434
+ elif latent_averaging_mode == 2:
435
+ temp_diffusion_conds.append(cond_mel)
436
+ if latent_averaging_mode == 2:
437
+ diffusion_conds.append(
438
+ torch.stack(temp_diffusion_conds).mean(0)
439
+ )
440
+ diffusion_conds = torch.stack(diffusion_conds, dim=1)
441
+
442
+ with self.temporary_cuda(self.diffusion) as diffusion:
443
+ diffusion_latent = diffusion.get_conditioning(diffusion_conds)
444
+
445
+ if return_mels:
446
+ return auto_latent, diffusion_latent, auto_conds, diffusion_conds
447
+ else:
448
+ return auto_latent, diffusion_latent
449
+
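
A hedged sketch of computing these conditioning latents once and caching them for later reuse via the `conditioning_latents` argument of `tts()`; the voice name and output path are placeholders:

```python
import torch
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voices

tts = TextToSpeech()
voice_samples, _ = load_voices(["tom"])   # placeholder voice
latents = tts.get_conditioning_latents(voice_samples, latent_averaging_mode=1)
torch.save(latents, "tom_latents.pth")    # later: tts.tts(text, conditioning_latents=latents)
```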
450
+ def get_random_conditioning_latents(self):
451
+ # Lazy-load the RLG models.
452
+ if self.rlg_auto is None:
453
+ self.rlg_auto = RandomLatentConverter(1024).eval()
454
+ self.rlg_auto.load_state_dict(
455
+ torch.load(
456
+ get_model_path("rlg_auto.pth", self.models_dir),
457
+ map_location=torch.device("cpu"),
458
+ )
459
+ )
460
+ self.rlg_diffusion = RandomLatentConverter(2048).eval()
461
+ self.rlg_diffusion.load_state_dict(
462
+ torch.load(
463
+ get_model_path("rlg_diffuser.pth", self.models_dir),
464
+ map_location=torch.device("cpu"),
465
+ )
466
+ )
467
+ with torch.no_grad():
468
+ return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(
469
+ torch.tensor([0.0])
470
+ )
471
+
472
+ def tts_with_preset(self, text, preset="fast", **kwargs):
473
+ """
474
+ Calls TTS with one of a set of preset generation parameters. Options:
475
+ 'single_sample': Produces speech even faster, but only produces 1 sample.
476
+ 'ultra_fast': Produces speech much faster than the original tortoise repo.
477
+ 'ultra_fast_old': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
478
+ 'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
479
+ 'standard': Very good quality. This is generally about as good as you are going to get.
480
+ 'high_quality': Use if you want the absolute best. This is not really worth the compute, though.
481
+ """
482
+ # Use generally found best tuning knobs for generation.
483
+ settings = {
484
+ "temperature": 0.2,
485
+ "length_penalty": 1.0,
486
+ "repetition_penalty": 2.0,
487
+ "top_p": 0.8,
488
+ "cond_free_k": 2.0,
489
+ "diffusion_temperature": 1.0,
490
+ }
491
+ # Presets are defined here.
492
+ presets = {
493
+ "single_sample": {
494
+ "num_autoregressive_samples": 8,
495
+ "diffusion_iterations": 10,
496
+ "sampler": "ddim",
497
+ },
498
+ "ultra_fast": {
499
+ "num_autoregressive_samples": 16,
500
+ "diffusion_iterations": 10,
501
+ "sampler": "ddim",
502
+ },
503
+ "ultra_fast_old": {
504
+ "num_autoregressive_samples": 16,
505
+ "diffusion_iterations": 30,
506
+ "cond_free": False,
507
+ },
508
+ "very_fast": {
509
+ "num_autoregressive_samples": 32,
510
+ "diffusion_iterations": 30,
511
+ "sampler": "dpm++2m",
512
+ },
513
+ "fast": {
514
+ "num_autoregressive_samples": 96,
515
+ "diffusion_iterations": 20,
516
+ "sampler": "dpm++2m",
517
+ },
518
+ "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
519
+ "standard": {
520
+ "num_autoregressive_samples": 256,
521
+ "diffusion_iterations": 200,
522
+ },
523
+ "high_quality": {
524
+ "num_autoregressive_samples": 256,
525
+ "diffusion_iterations": 400,
526
+ },
527
+ }
528
+ settings.update(presets[preset])
529
+ settings.update(kwargs) # allow overriding of preset settings with kwargs
530
+ return self.tts(text, **settings)
531
+
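
A minimal usage sketch for the presets above; the voice name, text, and output path are illustrative placeholders:

```python
import torchaudio
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voices

tts = TextToSpeech(kv_cache=True)
voice_samples, conditioning_latents = load_voices(["emma"])  # placeholder voice
gen = tts.tts_with_preset(
    "Thanks for reading this diff all the way to the end.",
    preset="ultra_fast",
    voice_samples=voice_samples,
    conditioning_latents=conditioning_latents,
)
torchaudio.save("generated.wav", gen.squeeze(0).cpu(), 24000)
```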
532
+ def tts(
533
+ self,
534
+ text,
535
+ voice_samples=None,
536
+ conditioning_latents=None,
537
+ k=1,
538
+ verbose=True,
539
+ use_deterministic_seed=None,
540
+ return_deterministic_state=False,
541
+ latent_averaging_mode=0,
542
+ # autoregressive generation parameters follow
543
+ num_autoregressive_samples=512,
544
+ temperature=0.8,
545
+ length_penalty=1,
546
+ repetition_penalty=2.0,
547
+ top_p=0.8,
548
+ max_mel_tokens=500,
549
+ # CVVP parameters follow
550
+ cvvp_amount=0.0,
551
+ # diffusion generation parameters follow
552
+ diffusion_iterations=100,
553
+ cond_free=True,
554
+ cond_free_k=2,
555
+ diffusion_temperature=1.0,
556
+ sampler="ddim",
557
+ half=True,
558
+ original_tortoise=False,
559
+ **hf_generate_kwargs,
560
+ ):
561
+ """
562
+ Produces an audio clip of the given text being spoken with the given reference voice.
563
+ :param text: Text to be spoken.
564
+ :param voice_samples: List of an arbitrary number of reference clips, which should be *tuple-pairs* of torch tensors containing arbitrary kHz waveform data.
565
+ :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which
566
+ can be provided in lieu of voice_samples. This is ignored unless voice_samples=None.
567
+ Conditioning latents can be retrieved via get_conditioning_latents().
568
+ :param k: The number of returned clips. The most likely (as determined by Tortoise's CLVP model) clips are returned.
569
+ :param latent_averaging_mode: 0/1/2 for following modes:
570
+ 0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples
571
+ 1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks
572
+ 2 - latents will be generated using (almost) entire voice samples, averaged per voice sample
573
+ :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true.
574
+ ~~AUTOREGRESSIVE KNOBS~~
575
+ :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
576
+ As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great".
577
+ :param temperature: The softmax temperature of the autoregressive model.
578
+ :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings cause the model to produce more terse outputs.
579
+ :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence
580
+ of long silences or "uhhhhhhs", etc.
581
+ :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs.
582
+ :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second.
583
+ :param typical_sampling: Turns typical sampling on or off. This sampling mode is discussed in this paper: https://arxiv.org/abs/2202.00666
584
+ I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but
585
+ could use some tuning.
586
+ :param typical_mass: The typical_mass parameter from the typical_sampling algorithm.
587
+ ~~CLVP-CVVP KNOBS~~
588
+ :param cvvp_amount: Controls the influence of the CVVP model in selecting the best output from the autoregressive model.
589
+ [0,1]. Values closer to 1 mean the CVVP model is more important, 0 disables the CVVP model.
590
+ ~~DIFFUSION KNOBS~~
591
+ :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
592
+ the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
593
+ however.
594
+ :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for
595
+ each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output
596
+ of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and
597
+ dramatically improves realism.
598
+ :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
599
+ As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
600
+ Formula is: output=cond_present_output*(cond_free_k+1)-cond_absent_output*cond_free_k
601
+ :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
602
+ are the "mean" prediction of the diffusion network and will sound bland and smeared.
603
+ ~~OTHER STUFF~~
604
+ :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer.
605
+ Extra keyword args fed to this function get forwarded directly to that API. Documentation
606
+ here: https://huggingface.co/docs/transformers/internal/generation_utils
607
+ :return: Generated audio clip(s) as a torch tensor. Shape is (1, S) if k=1, else (k, 1, S), where S is the sample length.
608
+ Sample rate is 24kHz.
609
+ """
610
+ deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
611
+
612
+ text_tokens = (
613
+ torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
614
+ )
615
+ text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
616
+ assert (
617
+ text_tokens.shape[-1] < 400
618
+ ), "Too much text provided. Break the text up into separate segments and re-try inference."
619
+
620
+ auto_conds = None
621
+ if voice_samples is not None:
622
+ (
623
+ auto_conditioning,
624
+ diffusion_conditioning,
625
+ auto_conds,
626
+ _,
627
+ ) = self.get_conditioning_latents(
628
+ voice_samples,
629
+ return_mels=True,
630
+ latent_averaging_mode=latent_averaging_mode,
631
+ original_tortoise=original_tortoise,
632
+ )
633
+ elif conditioning_latents is not None:
634
+ auto_conditioning, diffusion_conditioning = conditioning_latents
635
+ else:
636
+ (
637
+ auto_conditioning,
638
+ diffusion_conditioning,
639
+ ) = self.get_random_conditioning_latents()
640
+ auto_conditioning = auto_conditioning.to(self.device)
641
+ diffusion_conditioning = diffusion_conditioning.to(self.device)
642
+
643
+ diffuser = load_discrete_vocoder_diffuser(
644
+ desired_diffusion_steps=diffusion_iterations,
645
+ cond_free=cond_free,
646
+ cond_free_k=cond_free_k,
647
+ sampler=sampler,
648
+ )
649
+
650
+ # In the case of single_sample, shrink the batch size so it evenly divides num_autoregressive_samples.
651
+ orig_batch_size = self.autoregressive_batch_size
652
+ while num_autoregressive_samples % self.autoregressive_batch_size:
653
+ self.autoregressive_batch_size //= 2
654
+ with torch.no_grad():
655
+ samples = []
656
+ num_batches = num_autoregressive_samples // self.autoregressive_batch_size
657
+ stop_mel_token = self.autoregressive.stop_mel_token
658
+ calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
659
+ self.autoregressive = self.autoregressive.to(self.device)
660
+ if verbose:
661
+ print("Generating autoregressive samples..")
662
+ with self.temporary_cuda(
663
+ self.autoregressive
664
+ ) as autoregressive, torch.autocast(
665
+ device_type="cuda", dtype=torch.float16, enabled=half
666
+ ):
667
+ for b in tqdm(range(num_batches), disable=not verbose):
668
+ codes = autoregressive.inference_speech(
669
+ auto_conditioning,
670
+ text_tokens,
671
+ do_sample=True,
672
+ top_p=top_p,
673
+ temperature=temperature,
674
+ num_return_sequences=self.autoregressive_batch_size,
675
+ length_penalty=length_penalty,
676
+ repetition_penalty=repetition_penalty,
677
+ max_generate_length=max_mel_tokens,
678
+ **hf_generate_kwargs,
679
+ )
680
+ padding_needed = max_mel_tokens - codes.shape[1]
681
+ codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
682
+ samples.append(codes)
683
+ self.autoregressive_batch_size = (
684
+ orig_batch_size # in the case of single_sample
685
+ )
686
+
687
+ clip_results = []
688
+ with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
689
+ device_type="cuda", dtype=torch.float16, enabled=half
690
+ ):
691
+ if cvvp_amount > 0:
692
+ if self.cvvp is None:
693
+ self.load_cvvp()
694
+ self.cvvp = self.cvvp.to(self.device)
695
+ if verbose:
696
+ if self.cvvp is None:
697
+ print("Computing best candidates using CLVP")
698
+ else:
699
+ print(
700
+ f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
701
+ )
702
+ for batch in tqdm(samples, disable=not verbose):
703
+ for i in range(batch.shape[0]):
704
+ batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
705
+ if cvvp_amount != 1:
706
+ clvp_res = clvp(
707
+ text_tokens.repeat(batch.shape[0], 1),
708
+ batch,
709
+ return_loss=False,
710
+ )
711
+ if auto_conds is not None and cvvp_amount > 0:
712
+ cvvp_accumulator = 0
713
+ for cl in range(auto_conds.shape[1]):
714
+ cvvp_accumulator = cvvp_accumulator + self.cvvp(
715
+ auto_conds[:, cl].repeat(batch.shape[0], 1, 1),
716
+ batch,
717
+ return_loss=False,
718
+ )
719
+ cvvp = cvvp_accumulator / auto_conds.shape[1]
720
+ if cvvp_amount == 1:
721
+ clip_results.append(cvvp)
722
+ else:
723
+ clip_results.append(
724
+ cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount)
725
+ )
726
+ else:
727
+ clip_results.append(clvp_res)
728
+ clip_results = torch.cat(clip_results, dim=0)
729
+ samples = torch.cat(samples, dim=0)
730
+ best_results = samples[torch.topk(clip_results, k=k).indices]
731
+ if self.cvvp is not None:
732
+ self.cvvp = self.cvvp.cpu()
733
+ del samples
734
+
735
+ # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
736
+ # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
737
+ # results, but will increase memory usage.
738
+ with self.temporary_cuda(self.autoregressive) as autoregressive:
739
+ best_latents = autoregressive(
740
+ auto_conditioning.repeat(k, 1),
741
+ text_tokens.repeat(k, 1),
742
+ torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
743
+ best_results,
744
+ torch.tensor(
745
+ [
746
+ best_results.shape[-1]
747
+ * self.autoregressive.mel_length_compression
748
+ ],
749
+ device=text_tokens.device,
750
+ ),
751
+ return_latent=True,
752
+ clip_inputs=False,
753
+ )
754
+ del auto_conditioning
755
+
756
+ if verbose:
757
+ print("Transforming autoregressive outputs into audio..")
758
+ wav_candidates = []
759
+ with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
760
+ self.vocoder
761
+ ) as vocoder:
762
+ diffusion.enable_fp16 = half # hacky
763
+ for b in range(best_results.shape[0]):
764
+ codes = best_results[b].unsqueeze(0)
765
+ latents = best_latents[b].unsqueeze(0)
766
+
767
+ # Find the first occurrence of the "calm" token and trim the codes to that.
768
+ ctokens = 0
769
+ for k in range(codes.shape[-1]):
770
+ if codes[0, k] == calm_token:
771
+ ctokens += 1
772
+ else:
773
+ ctokens = 0
774
+ if (
775
+ ctokens > 8
776
+ ): # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
777
+ latents = latents[:, :k]
778
+ break
779
+
780
+ mel = do_spectrogram_diffusion(
781
+ diffusion,
782
+ diffuser,
783
+ latents,
784
+ diffusion_conditioning,
785
+ temperature=diffusion_temperature,
786
+ verbose=verbose,
787
+ )
788
+ wav = vocoder.inference(mel)
789
+ wav_candidates.append(wav.cpu())
790
+
791
+ def potentially_redact(clip, text):
792
+ if self.enable_redaction:
793
+ return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
794
+ return clip
795
+
796
+ wav_candidates = [
797
+ potentially_redact(wav_candidate, text)
798
+ for wav_candidate in wav_candidates
799
+ ]
800
+
801
+ if len(wav_candidates) > 1:
802
+ res = wav_candidates
803
+ else:
804
+ res = wav_candidates[0]
805
+
806
+ if return_deterministic_state:
807
+ return res, (
808
+ deterministic_seed,
809
+ text,
810
+ voice_samples,
811
+ conditioning_latents,
812
+ )
813
+ else:
814
+ return res
815
+
816
+ def deterministic_state(self, seed=None):
817
+ """
818
+ Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
819
+ reproduced.
820
+ """
821
+ seed = int(time()) if seed is None else seed
822
+ torch.manual_seed(seed)
823
+ random.seed(seed)
824
+ # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
825
+ # torch.use_deterministic_algorithms(True)
826
+
827
+ return seed
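
A hedged sketch of the reproducibility hooks defined above (`use_deterministic_seed` / `return_deterministic_state`); the voice and text are placeholders:

```python
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voices

tts = TextToSpeech()
voice_samples, _ = load_voices(["tom"])  # placeholder voice

gen, state = tts.tts_with_preset(
    "Same text, same seed.", preset="fast",
    voice_samples=voice_samples, return_deterministic_state=True,
)
seed, text, samples, latents = state
# Re-running with the captured seed should reproduce the clip
# (modulo the CUBLAS caveat noted in deterministic_state above).
regen = tts.tts_with_preset(text, preset="fast",
                            voice_samples=samples,
                            use_deterministic_seed=seed)
```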
tortoise/data/got.txt ADDED
@@ -0,0 +1,276 @@
1
+ Chapter One
2
+
3
+
4
+ Bran
5
+
6
+
7
+ The morning had dawned clear and cold, with a crispness that hinted at the end of summer. They set forth at daybreak to see a man beheaded, twenty in all, and Bran rode among them, nervous with excitement. This was the first time he had been deemed old enough to go with his lord father and his brothers to see the king's justice done. It was the ninth year of summer, and the seventh of Bran's life.
8
+
9
+
10
+ The man had been taken outside a small holdfast in the hills. Robb thought he was a wildling, his sword sworn to Mance Rayder, the King-beyond-the-Wall. It made Bran's skin prickle to think of it. He remembered the hearth tales Old Nan told them. The wildlings were cruel men, she said, slavers and slayers and thieves. They consorted with giants and ghouls, stole girl children in the dead of night, and drank blood from polished horns. And their women lay with the Others in the Long Night to sire terrible half-human children.
11
+
12
+
13
+ But the man they found bound hand and foot to the holdfast wall awaiting the king's justice was old and scrawny, not much taller than Robb. He had lost both ears and a finger to frostbite, and he dressed all in black, the same as a brother of the Night's Watch, except that his furs were ragged and greasy.
14
+
15
+
16
+ The breath of man and horse mingled, steaming, in the cold morning air as his lord father had the man cut down from the wall and dragged before them. Robb and Jon sat tall and still on their horses, with Bran between them on his pony, trying to seem older than seven, trying to pretend that he'd seen all this before. A faint wind blew through the holdfast gate. Over their heads flapped the banner of the Starks of Winterfell: a grey direwolf racing across an ice-white field.
17
+
18
+ Bran's father sat solemnly on his horse, long brown hair stirring in the wind. His closely trimmed beard was shot with white, making him look older than his thirty-five years. He had a grim cast to his grey eyes this day, and he seemed not at all the man who would sit before the fire in the evening and talk softly of the age of heroes and the children of the forest. He had taken off Father's face, Bran thought, and donned the face of Lord Stark of Winterfell.
19
+
20
+
21
+ There were questions asked and answers given there in the chill of morning, but afterward Bran could not recall much of what had been said. Finally his lord father gave a command, and two of his guardsmen dragged the ragged man to the ironwood stump in the center of the square. They forced his head down onto the hard black wood. Lord Eddard Stark dismounted and his ward Theon Greyjoy brought forth the sword. "Ice," that sword was called. It was as wide across as a man's hand, and taller even than Robb. The blade was Valyrian steel, spell-forged and dark as smoke. Nothing held an edge like Valyrian steel.
22
+
23
+
24
+ His father peeled off his gloves and handed them to Jory Cassel, the captain of his household guard. He took hold of Ice with both hands and said, "In the name of Robert of the House Baratheon, the First of his Name, King of the Andals and the Rhoynar and the First Men, Lord of the Seven Kingdoms and Protector of the Realm, by the word of Eddard of the House Stark, Lord of Winterfell and Warden of the North, I do sentence you to die." He lifted the greatsword high above his head.
25
+
26
+
27
+ Bran's bastard brother Jon Snow moved closer. "Keep the pony well in hand," he whispered. "And don't look away. Father will know if you do."
28
+
29
+
30
+ Bran kept his pony well in hand, and did not look away.
31
+
32
+
33
+ His father took off the man's head with a single sure stroke. Blood sprayed out across the snow, as red as surnmerwine. One of the horses reared and had to be restrained to keep from bolting. Bran could not take his eyes off the blood. The snows around the stump drank it eagerly, reddening as he watched.
34
+
35
+ The head bounced off a thick root and rolled. It came up near Greyjoy's feet. Theon was a lean, dark youth of nineteen who found everything amusing. He laughed, put his boot on the head, and kicked it away.
36
+
37
+
38
+ "Ass," Jon muttered, low enough so Greyjoy did not hear. He put a hand on Bran's shoulder, and Bran looked over at his bastard brother. "You did well," Jon told him solemnly. Jon was fourteen, an old hand at justice.
39
+
40
+
41
+ It seemed colder on the long ride back to Winterfell, though the wind had died by then and the sun was higher in the sky. Bran rode with his brothers, well ahead of the main party, his pony struggling hard to keep up with their horses.
42
+
43
+
44
+ "The deserter died bravely," Robb said. He was big and broad and growing every day, with his mother's coloring, the fair skin, red-brown hair, and blue eyes of the Tullys of Riverrun. "He had courage, at the least."
45
+
46
+
47
+ "No," Jon Snow said quietly. "It was not courage. This one was dead of fear. You could see it in his eyes, Stark." Jon's eyes were a grey so dark they seemed almost black, but there was little they did not see. He was of an age with Robb, but they did not look alike. Jon was slender where Robb was muscular, dark where Robb was fair, graceful and quick where his half brother was strong and fast.
48
+
49
+
50
+ Robb was not impressed. "The Others take his eyes," he swore. "He died well. Race you to the bridge?"
51
+
52
+
53
+ "Done," Jon said, kicking his horse forward. Robb cursed and followed, and they galloped off down the trail, Robb laughing and hooting, Jon silent and intent. The hooves of their horses kicked up showers of snow as they went.
54
+
55
+ Bran did not try to follow. His pony could not keep up. He had seen the ragged man's eyes, and he was thinking of them now. After a while, the sound of Robb's laughter receded, and the woods grew silent again.
56
+
57
+
58
+ So deep in thought was he that he never heard the rest of the party until his father moved up to ride beside him. "Are you well, Bran?" he asked, not unkindly.
59
+
60
+
61
+ "Yes, Father," Bran told him. He looked up. Wrapped in his furs and leathers, mounted on his great warhorse, his lord father loomed over him like a giant. "Robb says the man died bravely, but Jon says he was afraid."
62
+
63
+
64
+ "What do you think?" his father asked.
65
+
66
+
67
+ Bran thought about it. "Can a man still be brave if he's afraid?"
68
+
69
+
70
+ "That is the only time a man can be brave," his father told him. "Do you understand why I did it?"
71
+
72
+
73
+ "He was a wildling," Bran said. "They carry off women and sell them to the Others."
74
+
75
+
76
+ His lord father smiled. "Old Nan has been telling you stories again. In truth, the man was an oathbreaker, a deserter from the Night's Watch. No man is more dangerous. The deserter knows his life is forfeit if he is taken, so he will not flinch from any crime, no matter how vile. But you mistake me. The question was not why the man had to die, but why I must do it."
77
+
78
+
79
+ Bran had no answer for that. "King Robert has a headsman," he said, uncertainly.
80
+
81
+
82
+ "He does," his father admitted. "As did the Targaryen kings before him. Yet our way is the older way. The blood of the First Men still flows in the veins of the Starks, and we hold to the belief that the man who passes the sentence should swing the sword. If you would take a man's life, you owe it to him to look into his eyes and hear his final words. And if you cannot bear to do that, then perhaps the man does not deserve to die.
83
+
84
+
85
+ "One day, Bran, you will be Robb's bannerman, holding a keep of your own for your brother and your king, and justice will fall to you. When that day comes, you must take no pleasure in the task, but neither must you look away. A ruler who hides behind paid executioners soon forgets what death is."
86
+
87
+
88
+ That was when Jon reappeared on the crest of the hill before them. He waved and shouted down at them. "Father, Bran, come quickly, see what Robb has found!" Then he was gone again.
89
+
90
+
91
+ Jory rode up beside them. "Trouble, my lord?"
92
+
93
+
94
+ "Beyond a doubt," his lord father said. "Come, let us see what mischief my sons have rooted out now." He sent his horse into a trot. Jory and Bran and the rest came after.
95
+
96
+
97
+ They found Robb on the riverbank north of the bridge, with Jon still mounted beside him. The late summer snows had been heavy this moonturn. Robb stood knee-deep in white, his hood pulled back so the sun shone in his hair. He was cradling something in his arm, while the boys talked in hushed, excited voices.
98
+
99
+
100
+ The riders picked their way carefully through the drifts, groping for solid footing on the hidden, uneven ground . Jory Cassel and Theon Greyjoy were the first to reach the boys. Greyjoy was laughing and joking as he rode. Bran heard the breath go out of him. "Gods!" he exclaimed, struggling to keep control of his horse as he reached for his sword.
101
+
102
+
103
+ Jory's sword was already out. "Robb, get away from it!" he called as his horse reared under him.
104
+
105
+
106
+ Robb grinned and looked up from the bundle in his arms. "She can't hurt you," he said. "She's dead, Jory."
107
+
108
+
109
+ Bran was afire with curiosity by then. He would have spurred the pony faster, but his father made them dismount beside the bridge and approach on foot. Bran jumped off and ran.
110
+
111
+
112
+ By then Jon, Jory, and Theon Greyjoy had all dismounted as well. "What in the seven hells is it?" Greyjoy was saying.
113
+
114
+
115
+ "A wolf," Robb told him.
116
+
117
+
118
+ "A freak," Greyjoy said. "Look at the size of it."
119
+
120
+
121
+ Bran's heart was thumping in his chest as he pushed through a waist-high drift to his brothers' side.
122
+
123
+
124
+ Half-buried in bloodstained snow, a huge dark shape slumped in death. Ice had formed in its shaggy grey fur, and the faint smell of corruption clung to it like a woman's perfume. Bran glimpsed blind eyes crawling with maggots, a wide mouth full of yellowed teeth. But it was the size of it that made him gasp. It was bigger than his pony, twice the size of the largest hound in his father's kennel.
125
+
126
+
127
+ "It's no freak," Jon said calmly. "That's a direwolf. They grow larger than the other kind."
128
+
129
+
130
+ Theon Greyjoy said, "There's not been a direwolf sighted south of the Wall in two hundred years."
131
+
132
+
133
+ "I see one now," Jon replied.
134
+
135
+
136
+ Bran tore his eyes away from the monster. That was when he noticed the bundle in Robb's arms. He gave a cry of delight and moved closer. The pup was a tiny ball of grey-black fur, its eyes still closed. It nuzzled blindly against Robb's chest as he cradled it, searching for milk among his leathers, making a sad little whimpery sound. Bran reached out hesitantly. "Go on," Robb told him. "You can touch him."
137
+
138
+
139
+ Bran gave the pup a quick nervous stroke, then turned as Jon said, "Here you go." His half brother put a second pup into his arms. "There are five of them." Bran sat down in the snow and hugged the wolf pup to his face. Its fur was soft and warm against his cheek.
140
+
141
+
142
+ "Direwolves loose in the realm, after so many years," muttered Hullen, the master of horse. "I like it not."
143
+
144
+
145
+ "It is a sign," Jory said.
146
+
147
+
148
+ Father frowned. "This is only a dead animal, Jory," he said. Yet he seemed troubled. Snow crunched under his boots as he moved around the body. "Do we know what killed her?"
149
+
150
+
151
+ "There's something in the throat," Robb told him, proud to have found the answer before his father even asked. "There, just under the jaw."
152
+
153
+
154
+ His father knelt and groped under the beast's head with his hand. He gave a yank and held it up for all to see. A foot of shattered antler, tines snapped off, all wet with blood.
155
+
156
+
157
+ A sudden silence descended over the party. The men looked at the antler uneasily, and no one dared to speak. Even Bran could sense their fear, though he did not understand.
158
+
159
+
160
+ His father tossed the antler to the side and cleansed his hands in the snow. "I'm surprised she lived long enough to whelp," he said. His voice broke the spell.
161
+
162
+
163
+ "Maybe she didn't," Jory said. "I've heard tales . . . maybe the bitch was already dead when the pups came."
164
+
165
+
166
+ "Born with the dead," another man put in. "Worse luck."
167
+
168
+
169
+ "No matter," said Hullen. "They be dead soon enough too."
170
+
171
+
172
+ Bran gave a wordless cry of dismay.
173
+
174
+
175
+ "The sooner the better," Theon Greyjoy agreed. He drew his sword. "Give the beast here, Bran."
176
+
177
+
178
+ The little thing squirmed against him, as if it heard and understood. "No!" Bran cried out fiercely. "It's mine."
179
+
180
+
181
+ "Put away your sword, Greyjoy," Robb said. For a moment he sounded as commanding as their father, like the lord he would someday be. "We will keep these pups."
182
+
183
+
184
+ "You cannot do that, boy," said Harwin, who was Hullen's son.
185
+
186
+
187
+ "It be a mercy to kill them," Hullen said.
188
+
189
+
190
+ Bran looked to his lord father for rescue, but got only a frown, a furrowed brow. "Hullen speaks truly, son. Better a swift death than a hard one from cold and starvation."
191
+
192
+
193
+ "No!" He could feel tears welling in his eyes, and he looked away. He did not want to cry in front of his father.
194
+
195
+
196
+ Robb resisted stubbornly. "Ser Rodrik's red bitch whelped again last week," he said. "It was a small litter, only two live pups. She'll have milk enough."
197
+
198
+
199
+ "She'll rip them apart when they try to nurse."
200
+
201
+
202
+ "Lord Stark," Jon said. It was strange to hear him call Father that, so formal. Bran looked at him with desperate hope. "There are five pups," he told Father. "Three male, two female."
203
+
204
+
205
+ "What of it, Jon?"
206
+
207
+
208
+ "You have five trueborn children," Jon said. "Three sons, two daughters. The direwolf is the sigil of your House. Your children were meant to have these pups, my lord."
209
+
210
+
211
+ Bran saw his father's face change, saw the other men exchange glances. He loved Jon with all his heart at that moment. Even at seven, Bran understood what his brother had done. The count had come right only because Jon had omitted himself. He had included the girls, included even Rickon, the baby, but not the bastard who bore the surname Snow, the name that custom decreed be given to all those in the north unlucky enough to be born with no name of their own.
212
+
213
+
214
+ Their father understood as well. "You want no pup for yourself, Jon?" he asked softly.
215
+
216
+
217
+ "The direwolf graces the banners of House Stark," Jon pointed out. "I am no Stark, Father."
218
+
219
+
220
+ Their lord father regarded Jon thoughtfully. Robb rushed into the silence he left. "I will nurse him myself, Father," he promised. "I will soak a towel with warm milk, and give him suck from that."
221
+
222
+
223
+ "Me too!" Bran echoed.
224
+
225
+
226
+ The lord weighed his sons long and carefully with his eyes. "Easy to say, and harder to do. I will not have you wasting the servants' time with this. If you want these pups, you will feed them yourselves. Is that understood?"
227
+
228
+
229
+ Bran nodded eagerly. The pup squirmed in his grasp, licked at his face with a warm tongue.
230
+
231
+
232
+ "You must train them as well," their father said. "You must train them. The kennelmaster will have nothing to do with these monsters, I promise you that. And the gods help you if you neglect them, or brutalize them, or train them badly. These are not dogs to beg for treats and slink off at a kick. A direwolf will rip a man's arm off his shoulder as easily as a dog will kill a rat. Are you sure you want this?"
233
+
234
+ "Yes, Father," Bran said.
235
+
236
+
237
+ "Yes," Robb agreed.
238
+
239
+
240
+ "The pups may die anyway, despite all you do."
241
+
242
+
243
+ "They won't die," Robb said. "We won't let them die."
244
+
245
+
246
+ "Keep them, then. Jory, Desmond, gather up the other pups. It's time we were back to Winterfell."
247
+
248
+
249
+ It was not until they were mounted and on their way that Bran allowed himself to taste the sweet air of victory. By then, his pup was snuggled inside his leathers, warm against him, safe for the long ride home. Bran was wondering what to name him.
250
+
251
+
252
+ Halfway across the bridge, Jon pulled up suddenly.
253
+
254
+
255
+ "What is it, Jon?" their lord father asked.
256
+
257
+
258
+ "Can't you hear it?"
259
+
260
+
261
+ Bran could hear the wind in the trees, the clatter of their hooves on the ironwood planks, the whimpering of his hungry pup, but Jon was listening to something else.
262
+
263
+
264
+ "There," Jon said. He swung his horse around and galloped back across the bridge. They watched him dismount where the direwolf lay dead in the snow, watched him kneel. A moment later he was riding back to them, smiling.
265
+
266
+
267
+ "He must have crawled away from the others," Jon said.
268
+
269
+
270
+ "Or been driven away," their father said, looking at the sixth pup. His fur was white, where the rest of the litter was grey. His eyes were as red as the blood of the ragged man who had died that morning. Bran thought it curious that this pup alone would have opened his eyes while the others were still blind.
271
+
272
+
273
+ "An albino," Theon Greyjoy said with wry amusement. "This one will die even faster than the others."
274
+
275
+
276
+ Jon Snow gave his father's ward a long, chilling look. "I think not, Greyjoy," he said. "This one belongs to me."
tortoise/data/layman.txt ADDED
File without changes
tortoise/data/mel_norms.pth ADDED
Binary file (1.07 kB).
 
tortoise/data/riding_hood.txt ADDED
@@ -0,0 +1,54 @@
1
+ Once upon a time there lived in a certain village a little country girl, the prettiest creature who was ever seen. Her mother was excessively fond of her; and her grandmother doted on her still more. This good woman had a little red riding hood made for her. It suited the girl so extremely well that everybody called her Little Red Riding Hood.
2
+ One day her mother, having made some cakes, said to her, "Go, my dear, and see how your grandmother is doing, for I hear she has been very ill. Take her a cake, and this little pot of butter."
3
+
4
+ Little Red Riding Hood set out immediately to go to her grandmother, who lived in another village.
5
+
6
+ As she was going through the wood, she met with a wolf, who had a very great mind to eat her up, but he dared not, because of some woodcutters working nearby in the forest. He asked her where she was going. The poor child, who did not know that it was dangerous to stay and talk to a wolf, said to him, "I am going to see my grandmother and carry her a cake and a little pot of butter from my mother."
7
+
8
+ "Does she live far off?" said the wolf
9
+
10
+ "Oh I say," answered Little Red Riding Hood; "it is beyond that mill you see there, at the first house in the village."
11
+
12
+ "Well," said the wolf, "and I'll go and see her too. I'll go this way and go you that, and we shall see who will be there first."
13
+
14
+ The wolf ran as fast as he could, taking the shortest path, and the little girl took a roundabout way, entertaining herself by gathering nuts, running after butterflies, and gathering bouquets of little flowers. It was not long before the wolf arrived at the old woman's house. He knocked at the door: tap, tap.
15
+
16
+ "Who's there?"
17
+
18
+ "Your grandchild, Little Red Riding Hood," replied the wolf, counterfeiting her voice; "who has brought you a cake and a little pot of butter sent you by mother."
19
+
20
+ The good grandmother, who was in bed, because she was somewhat ill, cried out, "Pull the bobbin, and the latch will go up."
21
+
22
+ The wolf pulled the bobbin, and the door opened, and then he immediately fell upon the good woman and ate her up in a moment, for it been more than three days since he had eaten. He then shut the door and got into the grandmother's bed, expecting Little Red Riding Hood, who came some time afterwards and knocked at the door: tap, tap.
23
+
24
+ "Who's there?"
25
+
26
+ Little Red Riding Hood, hearing the big voice of the wolf, was at first afraid; but believing her grandmother had a cold and was hoarse, answered, "It is your grandchild Little Red Riding Hood, who has brought you a cake and a little pot of butter mother sends you."
27
+
28
+ The wolf cried out to her, softening his voice as much as he could, "Pull the bobbin, and the latch will go up."
29
+
30
+ Little Red Riding Hood pulled the bobbin, and the door opened.
31
+
32
+ The wolf, seeing her come in, said to her, hiding himself under the bedclothes, "Put the cake and the little pot of butter upon the stool, and come get into bed with me."
33
+
34
+ Little Red Riding Hood took off her clothes and got into bed. She was greatly amazed to see how her grandmother looked in her nightclothes, and said to her, "Grandmother, what big arms you have!"
35
+
36
+ "All the better to hug you with, my dear."
37
+
38
+ "Grandmother, what big legs you have!"
39
+
40
+ "All the better to run with, my child."
41
+
42
+ "Grandmother, what big ears you have!"
43
+
44
+ "All the better to hear with, my child."
45
+
46
+ "Grandmother, what big eyes you have!"
47
+
48
+ "All the better to see with, my child."
49
+
50
+ "Grandmother, what big teeth you have got!"
51
+
52
+ "All the better to eat you up with."
53
+
54
+ And, saying these words, this wicked wolf fell upon Little Red Riding Hood, and ate her all up.
tortoise/data/seal_copypasta.txt ADDED
@@ -0,0 +1 @@
1
+ What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al kayda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire U S armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the U S A and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little "clever" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo.
tortoise/data/tokenizer.json ADDED
@@ -0,0 +1 @@
1
+ {"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s 
u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}
tortoise/dpm_solver_pytorch.py ADDED
@@ -0,0 +1,1653 @@
1
+ import math
2
+
3
+ import torch
4
+
5
+
6
+ class NoiseScheduleVP:
7
+ def __init__(
8
+ self,
9
+ schedule="discrete",
10
+ betas=None,
11
+ alphas_cumprod=None,
12
+ continuous_beta_0=0.1,
13
+ continuous_beta_1=20.0,
14
+ dtype=torch.float32,
15
+ ):
16
+ """Create a wrapper class for the forward SDE (VP type).
17
+
18
+ ***
19
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
20
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
21
+ ***
22
+
23
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
24
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
25
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
26
+
27
+ log_alpha_t = self.marginal_log_mean_coeff(t)
28
+ sigma_t = self.marginal_std(t)
29
+ lambda_t = self.marginal_lambda(t)
30
+
31
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
32
+
33
+ t = self.inverse_lambda(lambda_t)
34
+
35
+ ===============================================================
36
+
37
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
38
+
39
+ 1. For discrete-time DPMs:
40
+
41
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
42
+ t_i = (i + 1) / N
43
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
44
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
45
+
46
+ Args:
47
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
48
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
49
+
50
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
51
+
52
+ **Important**: Please pay special attention to the arg `alphas_cumprod`:
53
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
54
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
55
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
56
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
57
+ and
58
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
59
+
60
+
61
+ 2. For continuous-time DPMs:
62
+
63
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
64
+ schedule are the default settings in DDPM and improved-DDPM:
65
+
66
+ Args:
67
+ beta_min: A `float` number. The smallest beta for the linear schedule.
68
+ beta_max: A `float` number. The largest beta for the linear schedule.
69
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
70
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
71
+ T: A `float` number. The ending time of the forward process.
72
+
73
+ ===============================================================
74
+
75
+ Args:
76
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
77
+ 'linear' or 'cosine' for continuous-time DPMs.
78
+ Returns:
79
+ A wrapper object of the forward SDE (VP type).
80
+
81
+ ===============================================================
82
+
83
+ Example:
84
+
85
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
86
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
87
+
88
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
89
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
90
+
91
+ # For continuous-time DPMs (VPSDE), linear schedule:
92
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
93
+
94
+ """
95
+
96
+ if schedule not in ["discrete", "linear", "cosine"]:
97
+ raise ValueError(
98
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
99
+ schedule
100
+ )
101
+ )
102
+
103
+ self.schedule = schedule
104
+ if schedule == "discrete":
105
+ if betas is not None:
106
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
107
+ else:
108
+ assert alphas_cumprod is not None
109
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
110
+ self.total_N = len(log_alphas)
111
+ self.T = 1.0
112
+ self.t_array = (
113
+ torch.linspace(0.0, 1.0, self.total_N + 1)[1:]
114
+ .reshape((1, -1))
115
+ .to(dtype=dtype)
116
+ )
117
+ self.log_alpha_array = log_alphas.reshape(
118
+ (
119
+ 1,
120
+ -1,
121
+ )
122
+ ).to(dtype=dtype)
123
+ else:
124
+ self.total_N = 1000
125
+ self.beta_0 = continuous_beta_0
126
+ self.beta_1 = continuous_beta_1
127
+ self.cosine_s = 0.008
128
+ self.cosine_beta_max = 999.0
129
+ self.cosine_t_max = (
130
+ math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
131
+ * 2.0
132
+ * (1.0 + self.cosine_s)
133
+ / math.pi
134
+ - self.cosine_s
135
+ )
136
+ self.cosine_log_alpha_0 = math.log(
137
+ math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
138
+ )
139
+ self.schedule = schedule
140
+ if schedule == "cosine":
141
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
142
+ # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
143
+ self.T = 0.9946
144
+ else:
145
+ self.T = 1.0
146
+
147
+ def marginal_log_mean_coeff(self, t):
148
+ """
149
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
150
+ """
151
+ if self.schedule == "discrete":
152
+ return interpolate_fn(
153
+ t.reshape((-1, 1)),
154
+ self.t_array.to(t.device),
155
+ self.log_alpha_array.to(t.device),
156
+ ).reshape((-1))
157
+ elif self.schedule == "linear":
158
+ return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
159
+ elif self.schedule == "cosine":
160
+
161
+ def log_alpha_fn(s):
162
+ return torch.log(
163
+ torch.cos(
164
+ (s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0
165
+ )
166
+ )
167
+
168
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
169
+ return log_alpha_t
170
+
171
+ def marginal_alpha(self, t):
172
+ """
173
+ Compute alpha_t of a given continuous-time label t in [0, T].
174
+ """
175
+ return torch.exp(self.marginal_log_mean_coeff(t))
176
+
177
+ def marginal_std(self, t):
178
+ """
179
+ Compute sigma_t of a given continuous-time label t in [0, T].
180
+ """
181
+ return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))
182
+
183
+ def marginal_lambda(self, t):
184
+ """
185
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
186
+ """
187
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
188
+ log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
189
+ return log_mean_coeff - log_std
190
+
191
+ def inverse_lambda(self, lamb):
192
+ """
193
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
194
+ """
195
+ if self.schedule == "linear":
196
+ tmp = (
197
+ 2.0
198
+ * (self.beta_1 - self.beta_0)
199
+ * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
200
+ )
201
+ Delta = self.beta_0**2 + tmp
202
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
203
+ elif self.schedule == "discrete":
204
+ log_alpha = -0.5 * torch.logaddexp(
205
+ torch.zeros((1,)).to(lamb.device), -2.0 * lamb
206
+ )
207
+ t = interpolate_fn(
208
+ log_alpha.reshape((-1, 1)),
209
+ torch.flip(self.log_alpha_array.to(lamb.device), [1]),
210
+ torch.flip(self.t_array.to(lamb.device), [1]),
211
+ )
212
+ return t.reshape((-1,))
213
+ else:
214
+ log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
215
+
216
+ def t_fn(log_alpha_t):
217
+ return (
218
+ torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
219
+ * 2.0
220
+ * (1.0 + self.cosine_s)
221
+ / math.pi
222
+ - self.cosine_s
223
+ )
224
+
225
+ t = t_fn(log_alpha)
226
+ return t
227
+
228
+
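# Illustrative sketch (not prescriptive): building the schedule from a discrete-time
# beta array and querying alpha_t / sigma_t / lambda_t. The linear betas below are a
# placeholder; in practice use the betas (or alphas_cumprod) of the trained DPM.
import torch

betas = torch.linspace(1e-4, 2e-2, 1000)                # placeholder DDPM-style schedule
ns = NoiseScheduleVP("discrete", betas=betas)
t = torch.tensor([0.5])
alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
lambda_t = ns.marginal_lambda(t)                        # half-logSNR at t
t_back = ns.inverse_lambda(lambda_t)                    # recovers t by piecewise-linear inversion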
229
+ def model_wrapper(
230
+ model,
231
+ noise_schedule,
232
+ model_type="noise",
233
+ model_kwargs={},
234
+ guidance_type="uncond",
235
+ condition=None,
236
+ unconditional_condition=None,
237
+ guidance_scale=1.0,
238
+ classifier_fn=None,
239
+ classifier_kwargs={},
240
+ ):
241
+ """Create a wrapper function for the noise prediction model.
242
+
243
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
244
+ first wrap the model function into a noise prediction model that accepts the continuous time as the input.
245
+
246
+ We support four types of diffusion models by setting `model_type`:
247
+
248
+ 1. "noise": noise prediction model. (Trained by predicting noise).
249
+
250
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
251
+
252
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
253
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
254
+
255
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
256
+ arXiv preprint arXiv:2202.00512 (2022).
257
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
258
+ arXiv preprint arXiv:2210.02303 (2022).
259
+
260
+ 4. "score": marginal score function. (Trained by denoising score matching).
261
+ Note that the score function and the noise prediction model follow a simple relationship:
262
+ ```
263
+ noise(x_t, t) = -sigma_t * score(x_t, t)
264
+ ```
265
+
266
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
267
+ 1. "uncond": unconditional sampling by DPMs.
268
+ The input `model` has the following format:
269
+ ``
270
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
271
+ ``
272
+
273
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
274
+ The input `model` has the following format:
275
+ ``
276
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
277
+ ``
278
+
279
+ The input `classifier_fn` has the following format:
280
+ ``
281
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
282
+ ``
283
+
284
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
285
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
286
+
287
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
288
+ The input `model` has the following format:
289
+ ``
290
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
291
+ ``
292
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
293
+
294
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
295
+ arXiv preprint arXiv:2207.12598 (2022).
296
+
297
+
298
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
299
+ or continuous-time labels (i.e. epsilon to T).
300
+
301
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise:
302
+ ``
303
+ def model_fn(x, t_continuous) -> noise:
304
+ t_input = get_model_input_time(t_continuous)
305
+ return noise_pred(model, x, t_input, **model_kwargs)
306
+ ``
307
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
308
+
309
+ ===============================================================
310
+
311
+ Args:
312
+ model: A diffusion model with the corresponding format described above.
313
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
314
+ model_type: A `str`. The parameterization type of the diffusion model.
315
+ "noise" or "x_start" or "v" or "score".
316
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
317
+ guidance_type: A `str`. The type of the guidance for sampling.
318
+ "uncond" or "classifier" or "classifier-free".
319
+ condition: A pytorch tensor. The condition for the guided sampling.
320
+ Only used for "classifier" or "classifier-free" guidance type.
321
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
322
+ Only used for "classifier-free" guidance type.
323
+ guidance_scale: A `float`. The scale for the guided sampling.
324
+ classifier_fn: A classifier function. Only used for the classifier guidance.
325
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
326
+ Returns:
327
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
328
+ """
329
+
330
+ def get_model_input_time(t_continuous):
331
+ """
332
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
333
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
334
+ For continuous-time DPMs, we just use `t_continuous`.
335
+ """
336
+ if noise_schedule.schedule == "discrete":
337
+ return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
338
+ else:
339
+ return t_continuous
340
+
341
+ def noise_pred_fn(x, t_continuous, cond=None):
342
+ t_input = get_model_input_time(t_continuous)
343
+ if cond is None:
344
+ output = model(x, t_input, **model_kwargs)
345
+ else:
346
+ output = model(x, t_input, cond, **model_kwargs)
347
+ if model_type == "noise":
348
+ return output
349
+ elif model_type == "x_start":
350
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
351
+ t_continuous
352
+ ), noise_schedule.marginal_std(t_continuous)
353
+ return (x - alpha_t * output) / sigma_t
354
+ elif model_type == "v":
355
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
356
+ t_continuous
357
+ ), noise_schedule.marginal_std(t_continuous)
358
+ return alpha_t * output + sigma_t * x
359
+ elif model_type == "score":
360
+ sigma_t = noise_schedule.marginal_std(t_continuous)
361
+ return -sigma_t * output
362
+
363
+ def cond_grad_fn(x, t_input):
364
+ """
365
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
366
+ """
367
+ with torch.enable_grad():
368
+ x_in = x.detach().requires_grad_(True)
369
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
370
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
371
+
372
+ def model_fn(x, t_continuous):
373
+ """
374
+ The noise prediction model function that is used for DPM-Solver.
375
+ """
376
+ if guidance_type == "uncond":
377
+ return noise_pred_fn(x, t_continuous)
378
+ elif guidance_type == "classifier":
379
+ assert classifier_fn is not None
380
+ t_input = get_model_input_time(t_continuous)
381
+ cond_grad = cond_grad_fn(x, t_input)
382
+ sigma_t = noise_schedule.marginal_std(t_continuous)
383
+ noise = noise_pred_fn(x, t_continuous)
384
+ return noise - guidance_scale * sigma_t * cond_grad
385
+ elif guidance_type == "classifier-free":
386
+ if guidance_scale == 1.0 or unconditional_condition is None:
387
+ return noise_pred_fn(x, t_continuous, cond=condition)
388
+ else:
389
+ x_in = torch.cat([x] * 2)
390
+ t_in = torch.cat([t_continuous] * 2)
391
+ c_in = torch.cat([unconditional_condition, condition])
392
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
393
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
394
+
395
+ assert model_type in ["noise", "x_start", "v", "score"]
396
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
397
+ return model_fn
398
+
399
+
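# Illustrative sketch: wrapping a conditional noise-prediction network for
# classifier-free guidance. `dummy_net`, `cond_emb` and `uncond_emb` are placeholders
# standing in for a real model and its condition tensors.
def dummy_net(x, t_input, cond):
    return torch.zeros_like(x)                          # stand-in epsilon predictor

cond_emb = torch.randn(4, 16)                           # placeholder conditioning
uncond_emb = torch.zeros(4, 16)
model_fn = model_wrapper(
    dummy_net,
    ns,                                                 # NoiseScheduleVP from the sketch above
    model_type="noise",
    guidance_type="classifier-free",
    condition=cond_emb,
    unconditional_condition=uncond_emb,
    guidance_scale=2.0,
)
noise_hat = model_fn(torch.randn(4, 3, 32, 32), torch.full((4,), 0.5))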
400
+ class DPM_Solver:
401
+ def __init__(
402
+ self,
403
+ model_fn,
404
+ noise_schedule,
405
+ algorithm_type="dpmsolver++",
406
+ correcting_x0_fn=None,
407
+ correcting_xt_fn=None,
408
+ thresholding_max_val=1.0,
409
+ dynamic_thresholding_ratio=0.995,
410
+ ):
411
+ """Construct a DPM-Solver.
412
+
413
+ We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
414
+
415
+ We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
416
+ can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
417
+ dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
418
+ DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
419
+ DPMs (such as stable-diffusion).
420
+
421
+ To support advanced algorithms in image-to-image applications, we also support corrector functions for
422
+ both x0 and xt.
423
+
424
+ Args:
425
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
426
+ ``
427
+ def model_fn(x, t_continuous):
428
+ return noise
429
+ ``
430
+ The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
431
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
432
+ algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
433
+ correcting_x0_fn: A `str` or a function with the following format:
434
+ ```
435
+ def correcting_x0_fn(x0, t):
436
+ x0_new = ...
437
+ return x0_new
438
+ ```
439
+ This function corrects the outputs of the data prediction model at each sampling step, e.g.,
440
+ ```
441
+ x0_pred = data_pred_model(xt, t)
442
+ if correcting_x0_fn is not None:
443
+ x0_pred = correcting_x0_fn(x0_pred, t)
444
+ xt_1 = update(x0_pred, xt, t)
445
+ ```
446
+ If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
447
+ correcting_xt_fn: A function with the following format:
448
+ ```
449
+ def correcting_xt_fn(xt, t, step):
450
+ x_new = ...
451
+ return x_new
452
+ ```
453
+ This function corrects the intermediate samples xt at each sampling step, e.g.,
454
+ ```
455
+ xt = ...
456
+ xt = correcting_xt_fn(xt, t, step)
457
+ ```
458
+ thresholding_max_val: A `float`. The max value for thresholding.
459
+ Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
460
+ dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
461
+ Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
462
+
463
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
464
+ Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
465
+ with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
466
+ """
467
+ self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
468
+ self.noise_schedule = noise_schedule
469
+ assert algorithm_type in ["dpmsolver", "dpmsolver++"]
470
+ self.algorithm_type = algorithm_type
471
+ if correcting_x0_fn == "dynamic_thresholding":
472
+ self.correcting_x0_fn = self.dynamic_thresholding_fn
473
+ else:
474
+ self.correcting_x0_fn = correcting_x0_fn
475
+ self.correcting_xt_fn = correcting_xt_fn
476
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
477
+ self.thresholding_max_val = thresholding_max_val
478
+
479
+ def dynamic_thresholding_fn(self, x0, t):
480
+ """
481
+ The dynamic thresholding method.
482
+ """
483
+ dims = x0.dim()
484
+ p = self.dynamic_thresholding_ratio
485
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
486
+ s = expand_dims(
487
+ torch.maximum(
488
+ s, self.thresholding_max_val * torch.ones_like(s).to(s.device)
489
+ ),
490
+ dims,
491
+ )
492
+ x0 = torch.clamp(x0, -s, s) / s
493
+ return x0
494
+
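# Illustrative sketch: a solver configured with Imagen-style dynamic thresholding, as
# described in the constructor docstring (pixel-space models with values in [-1, 1];
# not recommended for latent-space DPMs). Uses the placeholder model_fn / ns from above.
solver = DPM_Solver(
    model_fn,
    ns,
    algorithm_type="dpmsolver++",
    correcting_x0_fn="dynamic_thresholding",
    dynamic_thresholding_ratio=0.995,
    thresholding_max_val=1.0,
)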
495
+ def noise_prediction_fn(self, x, t):
496
+ """
497
+ Return the noise prediction model.
498
+ """
499
+ return self.model(x, t)
500
+
501
+ def data_prediction_fn(self, x, t):
502
+ """
503
+ Return the data prediction model (with corrector).
504
+ """
505
+ noise = self.noise_prediction_fn(x, t)
506
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
507
+ t
508
+ ), self.noise_schedule.marginal_std(t)
509
+ x0 = (x - sigma_t * noise) / alpha_t
510
+ if self.correcting_x0_fn is not None:
511
+ x0 = self.correcting_x0_fn(x0, t)
512
+ return x0
513
+
514
+ def model_fn(self, x, t):
515
+ """
516
+ Convert the model to the noise prediction model or the data prediction model.
517
+ """
518
+ if self.algorithm_type == "dpmsolver++":
519
+ return self.data_prediction_fn(x, t)
520
+ else:
521
+ return self.noise_prediction_fn(x, t)
522
+
523
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
524
+ """Compute the intermediate time steps for sampling.
525
+
526
+ Args:
527
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
528
+ - 'logSNR': uniform logSNR for the time steps.
529
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
530
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
531
+ t_T: A `float`. The starting time of the sampling (default is T).
532
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
533
+ N: An `int`. The total number of time step intervals (the returned tensor has N + 1 points).
534
+ device: A torch device.
535
+ Returns:
536
+ A pytorch tensor of the time steps, with the shape (N + 1,).
537
+ """
538
+ if skip_type == "logSNR":
539
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
540
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
541
+ logSNR_steps = torch.linspace(
542
+ lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
543
+ ).to(device)
544
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
545
+ elif skip_type == "time_uniform":
546
+ return torch.linspace(t_T, t_0, N + 1).to(device)
547
+ elif skip_type == "time_quadratic":
548
+ t_order = 2
549
+ t = (
550
+ torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
551
+ .pow(t_order)
552
+ .to(device)
553
+ )
554
+ return t
555
+ else:
556
+ raise ValueError(
557
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
558
+ skip_type
559
+ )
560
+ )
561
+
562
+ def get_orders_and_timesteps_for_singlestep_solver(
563
+ self, steps, order, skip_type, t_T, t_0, device
564
+ ):
565
+ """
566
+ Get the order of each step for sampling by the singlestep DPM-Solver.
567
+
568
+ We combine DPM-Solver-1, 2 and 3 to use up all the function evaluations; this combination is named "DPM-Solver-fast".
569
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
570
+ - If order == 1:
571
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
572
+ - If order == 2:
573
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
574
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
575
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
576
+ - If order == 3:
577
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
578
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
579
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
580
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
581
+
582
+ ============================================
583
+ Args:
584
+ order: An `int`. The max order for the solver (2 or 3).
585
+ steps: An `int`. The total number of function evaluations (NFE).
586
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
587
+ - 'logSNR': uniform logSNR for the time steps.
588
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
589
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
590
+ t_T: A `float`. The starting time of the sampling (default is T).
591
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
592
+ device: A torch device.
593
+ Returns:
594
+ orders: A list of the solver order of each step.
595
+ """
596
+ if order == 3:
597
+ K = steps // 3 + 1
598
+ if steps % 3 == 0:
599
+ orders = [3,] * (
600
+ K - 2
601
+ ) + [2, 1]
602
+ elif steps % 3 == 1:
603
+ orders = [3,] * (
604
+ K - 1
605
+ ) + [1]
606
+ else:
607
+ orders = [3,] * (
608
+ K - 1
609
+ ) + [2]
610
+ elif order == 2:
611
+ if steps % 2 == 0:
612
+ K = steps // 2
613
+ orders = [
614
+ 2,
615
+ ] * K
616
+ else:
617
+ K = steps // 2 + 1
618
+ orders = [2,] * (
619
+ K - 1
620
+ ) + [1]
621
+ elif order == 1:
622
+ K = 1
623
+ orders = [
624
+ 1,
625
+ ] * steps
626
+ else:
627
+ raise ValueError("'order' must be '1' or '2' or '3'.")
628
+ if skip_type == "logSNR":
629
+ # To reproduce the results in DPM-Solver paper
630
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
631
+ else:
632
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
633
+ torch.cumsum(
634
+ torch.tensor(
635
+ [
636
+ 0,
637
+ ]
638
+ + orders
639
+ ),
640
+ 0,
641
+ ).to(device)
642
+ ]
643
+ return timesteps_outer, orders
644
+
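# Worked example of the decomposition above: steps = 20, order = 3 gives
# K = 20 // 3 + 1 = 7 and, since 20 % 3 == 2, orders == [3, 3, 3, 3, 3, 3, 2]
# (6 * 3 + 2 = 20 NFE); `timesteps` then holds the K + 1 = 8 outer step boundaries.
timesteps, orders = solver.get_orders_and_timesteps_for_singlestep_solver(
    steps=20, order=3, skip_type="time_uniform", t_T=1.0, t_0=1e-3, device="cpu"
)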
645
+ def denoise_to_zero_fn(self, x, s):
646
+ """
647
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
648
+ """
649
+ return self.data_prediction_fn(x, s)
650
+
651
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
652
+ """
653
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
654
+
655
+ Args:
656
+ x: A pytorch tensor. The initial value at time `s`.
657
+ s: A pytorch tensor. The starting time, with the shape (1,).
658
+ t: A pytorch tensor. The ending time, with the shape (1,).
659
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
660
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
661
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
662
+ Returns:
663
+ x_t: A pytorch tensor. The approximated solution at time `t`.
664
+ """
665
+ ns = self.noise_schedule
666
+ dims = x.dim()
667
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
668
+ h = lambda_t - lambda_s
669
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(
670
+ s
671
+ ), ns.marginal_log_mean_coeff(t)
672
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
673
+ alpha_t = torch.exp(log_alpha_t)
674
+
675
+ if self.algorithm_type == "dpmsolver++":
676
+ phi_1 = torch.expm1(-h)
677
+ if model_s is None:
678
+ model_s = self.model_fn(x, s)
679
+ x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
680
+ if return_intermediate:
681
+ return x_t, {"model_s": model_s}
682
+ else:
683
+ return x_t
684
+ else:
685
+ phi_1 = torch.expm1(h)
686
+ if model_s is None:
687
+ model_s = self.model_fn(x, s)
688
+ x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
689
+ if return_intermediate:
690
+ return x_t, {"model_s": model_s}
691
+ else:
692
+ return x_t
693
+
694
+ def singlestep_dpm_solver_second_update(
695
+ self,
696
+ x,
697
+ s,
698
+ t,
699
+ r1=0.5,
700
+ model_s=None,
701
+ return_intermediate=False,
702
+ solver_type="dpmsolver",
703
+ ):
704
+ """
705
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
706
+
707
+ Args:
708
+ x: A pytorch tensor. The initial value at time `s`.
709
+ s: A pytorch tensor. The starting time, with the shape (1,).
710
+ t: A pytorch tensor. The ending time, with the shape (1,).
711
+ r1: A `float`. The hyperparameter of the second-order solver.
712
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
713
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
714
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
715
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
716
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
717
+ Returns:
718
+ x_t: A pytorch tensor. The approximated solution at time `t`.
719
+ """
720
+ if solver_type not in ["dpmsolver", "taylor"]:
721
+ raise ValueError(
722
+ "'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
723
+ solver_type
724
+ )
725
+ )
726
+ if r1 is None:
727
+ r1 = 0.5
728
+ ns = self.noise_schedule
729
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
730
+ h = lambda_t - lambda_s
731
+ lambda_s1 = lambda_s + r1 * h
732
+ s1 = ns.inverse_lambda(lambda_s1)
733
+ log_alpha_s, log_alpha_s1, log_alpha_t = (
734
+ ns.marginal_log_mean_coeff(s),
735
+ ns.marginal_log_mean_coeff(s1),
736
+ ns.marginal_log_mean_coeff(t),
737
+ )
738
+ sigma_s, sigma_s1, sigma_t = (
739
+ ns.marginal_std(s),
740
+ ns.marginal_std(s1),
741
+ ns.marginal_std(t),
742
+ )
743
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
744
+
745
+ if self.algorithm_type == "dpmsolver++":
746
+ phi_11 = torch.expm1(-r1 * h)
747
+ phi_1 = torch.expm1(-h)
748
+
749
+ if model_s is None:
750
+ model_s = self.model_fn(x, s)
751
+ x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
752
+ model_s1 = self.model_fn(x_s1, s1)
753
+ if solver_type == "dpmsolver":
754
+ x_t = (
755
+ (sigma_t / sigma_s) * x
756
+ - (alpha_t * phi_1) * model_s
757
+ - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
758
+ )
759
+ elif solver_type == "taylor":
760
+ x_t = (
761
+ (sigma_t / sigma_s) * x
762
+ - (alpha_t * phi_1) * model_s
763
+ + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s)
764
+ )
765
+ else:
766
+ phi_11 = torch.expm1(r1 * h)
767
+ phi_1 = torch.expm1(h)
768
+
769
+ if model_s is None:
770
+ model_s = self.model_fn(x, s)
771
+ x_s1 = (
772
+ torch.exp(log_alpha_s1 - log_alpha_s) * x
773
+ - (sigma_s1 * phi_11) * model_s
774
+ )
775
+ model_s1 = self.model_fn(x_s1, s1)
776
+ if solver_type == "dpmsolver":
777
+ x_t = (
778
+ torch.exp(log_alpha_t - log_alpha_s) * x
779
+ - (sigma_t * phi_1) * model_s
780
+ - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
781
+ )
782
+ elif solver_type == "taylor":
783
+ x_t = (
784
+ torch.exp(log_alpha_t - log_alpha_s) * x
785
+ - (sigma_t * phi_1) * model_s
786
+ - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s)
787
+ )
788
+ if return_intermediate:
789
+ return x_t, {"model_s": model_s, "model_s1": model_s1}
790
+ else:
791
+ return x_t
792
+
793
+ def singlestep_dpm_solver_third_update(
794
+ self,
795
+ x,
796
+ s,
797
+ t,
798
+ r1=1.0 / 3.0,
799
+ r2=2.0 / 3.0,
800
+ model_s=None,
801
+ model_s1=None,
802
+ return_intermediate=False,
803
+ solver_type="dpmsolver",
804
+ ):
805
+ """
806
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
807
+
808
+ Args:
809
+ x: A pytorch tensor. The initial value at time `s`.
810
+ s: A pytorch tensor. The starting time, with the shape (1,).
811
+ t: A pytorch tensor. The ending time, with the shape (1,).
812
+ r1: A `float`. The hyperparameter of the third-order solver.
813
+ r2: A `float`. The hyperparameter of the third-order solver.
814
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
815
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
816
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
817
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
818
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
819
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
820
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
821
+ Returns:
822
+ x_t: A pytorch tensor. The approximated solution at time `t`.
823
+ """
824
+ if solver_type not in ["dpmsolver", "taylor"]:
825
+ raise ValueError(
826
+ "'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
827
+ solver_type
828
+ )
829
+ )
830
+ if r1 is None:
831
+ r1 = 1.0 / 3.0
832
+ if r2 is None:
833
+ r2 = 2.0 / 3.0
834
+ ns = self.noise_schedule
835
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
836
+ h = lambda_t - lambda_s
837
+ lambda_s1 = lambda_s + r1 * h
838
+ lambda_s2 = lambda_s + r2 * h
839
+ s1 = ns.inverse_lambda(lambda_s1)
840
+ s2 = ns.inverse_lambda(lambda_s2)
841
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
842
+ ns.marginal_log_mean_coeff(s),
843
+ ns.marginal_log_mean_coeff(s1),
844
+ ns.marginal_log_mean_coeff(s2),
845
+ ns.marginal_log_mean_coeff(t),
846
+ )
847
+ sigma_s, sigma_s1, sigma_s2, sigma_t = (
848
+ ns.marginal_std(s),
849
+ ns.marginal_std(s1),
850
+ ns.marginal_std(s2),
851
+ ns.marginal_std(t),
852
+ )
853
+ alpha_s1, alpha_s2, alpha_t = (
854
+ torch.exp(log_alpha_s1),
855
+ torch.exp(log_alpha_s2),
856
+ torch.exp(log_alpha_t),
857
+ )
858
+
859
+ if self.algorithm_type == "dpmsolver++":
860
+ phi_11 = torch.expm1(-r1 * h)
861
+ phi_12 = torch.expm1(-r2 * h)
862
+ phi_1 = torch.expm1(-h)
863
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
864
+ phi_2 = phi_1 / h + 1.0
865
+ phi_3 = phi_2 / h - 0.5
866
+
867
+ if model_s is None:
868
+ model_s = self.model_fn(x, s)
869
+ if model_s1 is None:
870
+ x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
871
+ model_s1 = self.model_fn(x_s1, s1)
872
+ x_s2 = (
873
+ (sigma_s2 / sigma_s) * x
874
+ - (alpha_s2 * phi_12) * model_s
875
+ + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
876
+ )
877
+ model_s2 = self.model_fn(x_s2, s2)
878
+ if solver_type == "dpmsolver":
879
+ x_t = (
880
+ (sigma_t / sigma_s) * x
881
+ - (alpha_t * phi_1) * model_s
882
+ + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
883
+ )
884
+ elif solver_type == "taylor":
885
+ D1_0 = (1.0 / r1) * (model_s1 - model_s)
886
+ D1_1 = (1.0 / r2) * (model_s2 - model_s)
887
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
888
+ D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
889
+ x_t = (
890
+ (sigma_t / sigma_s) * x
891
+ - (alpha_t * phi_1) * model_s
892
+ + (alpha_t * phi_2) * D1
893
+ - (alpha_t * phi_3) * D2
894
+ )
895
+ else:
896
+ phi_11 = torch.expm1(r1 * h)
897
+ phi_12 = torch.expm1(r2 * h)
898
+ phi_1 = torch.expm1(h)
899
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
900
+ phi_2 = phi_1 / h - 1.0
901
+ phi_3 = phi_2 / h - 0.5
902
+
903
+ if model_s is None:
904
+ model_s = self.model_fn(x, s)
905
+ if model_s1 is None:
906
+ x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (
907
+ sigma_s1 * phi_11
908
+ ) * model_s
909
+ model_s1 = self.model_fn(x_s1, s1)
910
+ x_s2 = (
911
+ (torch.exp(log_alpha_s2 - log_alpha_s)) * x
912
+ - (sigma_s2 * phi_12) * model_s
913
+ - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
914
+ )
915
+ model_s2 = self.model_fn(x_s2, s2)
916
+ if solver_type == "dpmsolver":
917
+ x_t = (
918
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
919
+ - (sigma_t * phi_1) * model_s
920
+ - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
921
+ )
922
+ elif solver_type == "taylor":
923
+ D1_0 = (1.0 / r1) * (model_s1 - model_s)
924
+ D1_1 = (1.0 / r2) * (model_s2 - model_s)
925
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
926
+ D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
927
+ x_t = (
928
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
929
+ - (sigma_t * phi_1) * model_s
930
+ - (sigma_t * phi_2) * D1
931
+ - (sigma_t * phi_3) * D2
932
+ )
933
+
934
+ if return_intermediate:
935
+ return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
936
+ else:
937
+ return x_t
938
+
939
+ def multistep_dpm_solver_second_update(
940
+ self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"
941
+ ):
942
+ """
943
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
944
+
945
+ Args:
946
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
947
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
948
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
949
+ t: A pytorch tensor. The ending time, with the shape (1,).
950
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
951
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
952
+ Returns:
953
+ x_t: A pytorch tensor. The approximated solution at time `t`.
954
+ """
955
+ if solver_type not in ["dpmsolver", "taylor"]:
956
+ raise ValueError(
957
+ "'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
958
+ solver_type
959
+ )
960
+ )
961
+ ns = self.noise_schedule
962
+ model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
963
+ t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
964
+ lambda_prev_1, lambda_prev_0, lambda_t = (
965
+ ns.marginal_lambda(t_prev_1),
966
+ ns.marginal_lambda(t_prev_0),
967
+ ns.marginal_lambda(t),
968
+ )
969
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
970
+ t_prev_0
971
+ ), ns.marginal_log_mean_coeff(t)
972
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
973
+ alpha_t = torch.exp(log_alpha_t)
974
+
975
+ h_0 = lambda_prev_0 - lambda_prev_1
976
+ h = lambda_t - lambda_prev_0
977
+ r0 = h_0 / h
978
+ D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
979
+ if self.algorithm_type == "dpmsolver++":
980
+ phi_1 = torch.expm1(-h)
981
+ if solver_type == "dpmsolver":
982
+ x_t = (
983
+ (sigma_t / sigma_prev_0) * x
984
+ - (alpha_t * phi_1) * model_prev_0
985
+ - 0.5 * (alpha_t * phi_1) * D1_0
986
+ )
987
+ elif solver_type == "taylor":
988
+ x_t = (
989
+ (sigma_t / sigma_prev_0) * x
990
+ - (alpha_t * phi_1) * model_prev_0
991
+ + (alpha_t * (phi_1 / h + 1.0)) * D1_0
992
+ )
993
+ else:
994
+ phi_1 = torch.expm1(h)
995
+ if solver_type == "dpmsolver":
996
+ x_t = (
997
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
998
+ - (sigma_t * phi_1) * model_prev_0
999
+ - 0.5 * (sigma_t * phi_1) * D1_0
1000
+ )
1001
+ elif solver_type == "taylor":
1002
+ x_t = (
1003
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
1004
+ - (sigma_t * phi_1) * model_prev_0
1005
+ - (sigma_t * (phi_1 / h - 1.0)) * D1_0
1006
+ )
1007
+ return x_t
1008
+
1009
+ def multistep_dpm_solver_third_update(
1010
+ self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"
1011
+ ):
1012
+ """
1013
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
1014
+
1015
+ Args:
1016
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
1017
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
1018
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
1019
+ t: A pytorch tensor. The ending time, with the shape (1,).
1020
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
1021
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
1022
+ Returns:
1023
+ x_t: A pytorch tensor. The approximated solution at time `t`.
1024
+ """
1025
+ ns = self.noise_schedule
1026
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
1027
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
1028
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
1029
+ ns.marginal_lambda(t_prev_2),
1030
+ ns.marginal_lambda(t_prev_1),
1031
+ ns.marginal_lambda(t_prev_0),
1032
+ ns.marginal_lambda(t),
1033
+ )
1034
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
1035
+ t_prev_0
1036
+ ), ns.marginal_log_mean_coeff(t)
1037
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
1038
+ alpha_t = torch.exp(log_alpha_t)
1039
+
1040
+ h_1 = lambda_prev_1 - lambda_prev_2
1041
+ h_0 = lambda_prev_0 - lambda_prev_1
1042
+ h = lambda_t - lambda_prev_0
1043
+ r0, r1 = h_0 / h, h_1 / h
1044
+ D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
1045
+ D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
1046
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
1047
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
1048
+ if self.algorithm_type == "dpmsolver++":
1049
+ phi_1 = torch.expm1(-h)
1050
+ phi_2 = phi_1 / h + 1.0
1051
+ phi_3 = phi_2 / h - 0.5
1052
+ x_t = (
1053
+ (sigma_t / sigma_prev_0) * x
1054
+ - (alpha_t * phi_1) * model_prev_0
1055
+ + (alpha_t * phi_2) * D1
1056
+ - (alpha_t * phi_3) * D2
1057
+ )
1058
+ else:
1059
+ phi_1 = torch.expm1(h)
1060
+ phi_2 = phi_1 / h - 1.0
1061
+ phi_3 = phi_2 / h - 0.5
1062
+ x_t = (
1063
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
1064
+ - (sigma_t * phi_1) * model_prev_0
1065
+ - (sigma_t * phi_2) * D1
1066
+ - (sigma_t * phi_3) * D2
1067
+ )
1068
+ return x_t
1069
+
1070
+ def singlestep_dpm_solver_update(
1071
+ self,
1072
+ x,
1073
+ s,
1074
+ t,
1075
+ order,
1076
+ return_intermediate=False,
1077
+ solver_type="dpmsolver",
1078
+ r1=None,
1079
+ r2=None,
1080
+ ):
1081
+ """
1082
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
1083
+
1084
+ Args:
1085
+ x: A pytorch tensor. The initial value at time `s`.
1086
+ s: A pytorch tensor. The starting time, with the shape (1,).
1087
+ t: A pytorch tensor. The ending time, with the shape (1,).
1088
+ order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
1089
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
1090
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
1091
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
1092
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
1093
+ r2: A `float`. The hyperparameter of the third-order solver.
1094
+ Returns:
1095
+ x_t: A pytorch tensor. The approximated solution at time `t`.
1096
+ """
1097
+ if order == 1:
1098
+ return self.dpm_solver_first_update(
1099
+ x, s, t, return_intermediate=return_intermediate
1100
+ )
1101
+ elif order == 2:
1102
+ return self.singlestep_dpm_solver_second_update(
1103
+ x,
1104
+ s,
1105
+ t,
1106
+ return_intermediate=return_intermediate,
1107
+ solver_type=solver_type,
1108
+ r1=r1,
1109
+ )
1110
+ elif order == 3:
1111
+ return self.singlestep_dpm_solver_third_update(
1112
+ x,
1113
+ s,
1114
+ t,
1115
+ return_intermediate=return_intermediate,
1116
+ solver_type=solver_type,
1117
+ r1=r1,
1118
+ r2=r2,
1119
+ )
1120
+ else:
1121
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
1122
+
1123
+ def multistep_dpm_solver_update(
1124
+ self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"
1125
+ ):
1126
+ """
1127
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
1128
+
1129
+ Args:
1130
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
1131
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
1132
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
1133
+ t: A pytorch tensor. The ending time, with the shape (1,).
1134
+ order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
1135
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
1136
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
1137
+ Returns:
1138
+ x_t: A pytorch tensor. The approximated solution at time `t`.
1139
+ """
1140
+ if order == 1:
1141
+ return self.dpm_solver_first_update(
1142
+ x, t_prev_list[-1], t, model_s=model_prev_list[-1]
1143
+ )
1144
+ elif order == 2:
1145
+ return self.multistep_dpm_solver_second_update(
1146
+ x, model_prev_list, t_prev_list, t, solver_type=solver_type
1147
+ )
1148
+ elif order == 3:
1149
+ return self.multistep_dpm_solver_third_update(
1150
+ x, model_prev_list, t_prev_list, t, solver_type=solver_type
1151
+ )
1152
+ else:
1153
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
1154
+
1155
+ def dpm_solver_adaptive(
1156
+ self,
1157
+ x,
1158
+ order,
1159
+ t_T,
1160
+ t_0,
1161
+ h_init=0.05,
1162
+ atol=0.0078,
1163
+ rtol=0.05,
1164
+ theta=0.9,
1165
+ t_err=1e-5,
1166
+ solver_type="dpmsolver",
1167
+ ):
1168
+ """
1169
+ The adaptive step size solver based on singlestep DPM-Solver.
1170
+
1171
+ Args:
1172
+ x: A pytorch tensor. The initial value at time `t_T`.
1173
+ order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
1174
+ t_T: A `float`. The starting time of the sampling (default is T).
1175
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
1176
+ h_init: A `float`. The initial step size (for logSNR).
1177
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
1178
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
1179
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
1180
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
1181
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
1182
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
1183
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
1184
+ Returns:
1185
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
1186
+
1187
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
1188
+ """
1189
+ ns = self.noise_schedule
1190
+ s = t_T * torch.ones((1,)).to(x)
1191
+ lambda_s = ns.marginal_lambda(s)
1192
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
1193
+ h = h_init * torch.ones_like(s).to(x)
1194
+ x_prev = x
1195
+ nfe = 0
1196
+ if order == 2:
1197
+ r1 = 0.5
1198
+
1199
+ def lower_update(x, s, t):
1200
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=True)
1201
+
1202
+ def higher_update(x, s, t, **kwargs):
1203
+ return self.singlestep_dpm_solver_second_update(
1204
+ x, s, t, r1=r1, solver_type=solver_type, **kwargs
1205
+ )
1206
+
1207
+ elif order == 3:
1208
+ r1, r2 = 1.0 / 3.0, 2.0 / 3.0
1209
+
1210
+ def lower_update(x, s, t):
1211
+ return self.singlestep_dpm_solver_second_update(
1212
+ x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
1213
+ )
1214
+
1215
+ def higher_update(x, s, t, **kwargs):
1216
+ return self.singlestep_dpm_solver_third_update(
1217
+ x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
1218
+ )
1219
+
1220
+ else:
1221
+ raise ValueError(
1222
+ "For adaptive step size solver, order must be 2 or 3, got {}".format(
1223
+ order
1224
+ )
1225
+ )
1226
+ while torch.abs((s - t_0)).mean() > t_err:
1227
+ t = ns.inverse_lambda(lambda_s + h)
1228
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
1229
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
1230
+ delta = torch.max(
1231
+ torch.ones_like(x).to(x) * atol,
1232
+ rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)),
1233
+ )
1234
+
1235
+ def norm_fn(v):
1236
+ return torch.sqrt(
1237
+ torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)
1238
+ )
1239
+
1240
+ E = norm_fn((x_higher - x_lower) / delta).max()
1241
+ if torch.all(E <= 1.0):
1242
+ x = x_higher
1243
+ s = t
1244
+ x_prev = x_lower
1245
+ lambda_s = ns.marginal_lambda(s)
1246
+ h = torch.min(
1247
+ theta * h * torch.float_power(E, -1.0 / order).float(),
1248
+ lambda_0 - lambda_s,
1249
+ )
1250
+ nfe += order
1251
+ print("adaptive solver nfe", nfe)
1252
+ return x
1253
+
1254
+ def add_noise(self, x, t, noise=None):
1255
+ """
1256
+ Compute the noised input xt = alpha_t * x + sigma_t * noise.
1257
+
1258
+ Args:
1259
+ x: A `torch.Tensor` with shape `(batch_size, *shape)`.
1260
+ t: A `torch.Tensor` with shape `(t_size,)`.
1261
+ Returns:
1262
+ xt with shape `(t_size, batch_size, *shape)`.
1263
+ """
1264
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
1265
+ t
1266
+ ), self.noise_schedule.marginal_std(t)
1267
+ if noise is None:
1268
+ noise = torch.randn((t.shape[0], *x.shape), device=x.device)
1269
+ x = x.reshape((-1, *x.shape))
1270
+ xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
1271
+ if t.shape[0] == 1:
1272
+ return xt.squeeze(0)
1273
+ else:
1274
+ return xt
1275
+
1276
+ def inverse(
1277
+ self,
1278
+ x,
1279
+ steps=20,
1280
+ t_start=None,
1281
+ t_end=None,
1282
+ order=2,
1283
+ skip_type="time_uniform",
1284
+ method="multistep",
1285
+ lower_order_final=True,
1286
+ denoise_to_zero=False,
1287
+ solver_type="dpmsolver",
1288
+ atol=0.0078,
1289
+ rtol=0.05,
1290
+ return_intermediate=False,
1291
+ ):
1292
+ """
1293
+ Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
1294
+ For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total number of time steps used during training.
1295
+ """
1296
+ t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
1297
+ t_T = self.noise_schedule.T if t_end is None else t_end
1298
+ assert (
1299
+ t_0 > 0 and t_T > 0
1300
+ ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1301
+ return self.sample(
1302
+ x,
1303
+ steps=steps,
1304
+ t_start=t_0,
1305
+ t_end=t_T,
1306
+ order=order,
1307
+ skip_type=skip_type,
1308
+ method=method,
1309
+ lower_order_final=lower_order_final,
1310
+ denoise_to_zero=denoise_to_zero,
1311
+ solver_type=solver_type,
1312
+ atol=atol,
1313
+ rtol=rtol,
1314
+ return_intermediate=return_intermediate,
1315
+ )
1316
+
1317
+ def sample(
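# End-to-end sketch, tying the placeholder pieces above together: multistep sampling
# (the configuration recommended below for guided sampling) with order 2 and 20 NFE.
x_T = torch.randn(4, 3, 32, 32)                         # x at t_start = T ~ N(0, I)
x_0 = solver.sample(
    x_T,
    steps=20,
    order=2,
    skip_type="time_uniform",
    method="multistep",
)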
1318
+ self,
1319
+ x,
1320
+ steps=20,
1321
+ t_start=None,
1322
+ t_end=None,
1323
+ order=2,
1324
+ skip_type="time_uniform",
1325
+ method="multistep",
1326
+ lower_order_final=True,
1327
+ denoise_to_zero=False,
1328
+ solver_type="dpmsolver",
1329
+ atol=0.0078,
1330
+ rtol=0.05,
1331
+ return_intermediate=False,
1332
+ ):
1333
+ """
1334
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1335
+
1336
+ =====================================================
1337
+
1338
+ We support the following algorithms for both noise prediction model and data prediction model:
1339
+ - 'singlestep':
1340
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1341
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1342
+ The total number of function evaluations (NFE) == `steps`.
1343
+ Given a fixed NFE == `steps`, the sampling procedure is:
1344
+ - If `order` == 1:
1345
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1346
+ - If `order` == 2:
1347
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1348
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1349
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1350
+ - If `order` == 3:
1351
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1352
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1353
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1354
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1355
+ - 'multistep':
1356
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1357
+ We initialize the first `order` values by lower order multistep solvers.
1358
+ Given a fixed NFE == `steps`, the sampling procedure is:
1359
+ Denote K = steps.
1360
+ - If `order` == 1:
1361
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
1362
+ - If `order` == 2:
1363
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
1364
+ - If `order` == 3:
1365
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
1366
+ - 'singlestep_fixed':
1367
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1368
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1369
+ - 'adaptive':
1370
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1371
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1372
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
1373
+ (NFE) and the sample quality.
1374
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1375
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1376
+
1377
+ =====================================================
1378
+
1379
+ Some advice on choosing the algorithm:
1380
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1381
+ Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
1382
+ e.g., DPM-Solver:
1383
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
1384
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1385
+ skip_type='time_uniform', method='singlestep')
1386
+ e.g., DPM-Solver++:
1387
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1388
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1389
+ skip_type='time_uniform', method='singlestep')
1390
+ - For **guided sampling with large guidance scale** by DPMs:
1391
+ Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
1392
+ e.g.
1393
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1394
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1395
+ skip_type='time_uniform', method='multistep')
1396
+
1397
+ We support three types of `skip_type`:
1398
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images.**
1399
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images.**
1400
+ - 'time_quadratic': quadratic time for the time steps.
1401
+
1402
+ =====================================================
1403
+ Args:
1404
+ x: A pytorch tensor. The initial value at time `t_start`
1405
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1406
+ steps: An `int`. The total number of function evaluations (NFE).
1407
+ t_start: A `float`. The starting time of the sampling.
1408
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
1409
+ t_end: A `float`. The ending time of the sampling.
1410
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1411
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1412
+ For discrete-time DPMs:
1413
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1414
+ For continuous-time DPMs:
1415
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1416
+ order: An `int`. The order of DPM-Solver.
1417
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1418
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1419
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1420
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1421
+
1422
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1423
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
1424
+ sampling diffusion models via diffusion SDEs on low-resolution images
1425
+ (such as CIFAR-10). However, we observed that it does not matter for
1426
+ high-resolution images. As it needs an additional NFE, we do not recommend
1427
+ it for high-resolution images.
1428
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1429
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1430
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1431
+ (especially for steps <= 10), so we recommend setting it to `True`.
1432
+ solver_type: A `str`. The Taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
1433
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1434
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1435
+ return_intermediate: A `bool`. Whether to save the xt at each step.
1436
+ When set to `True`, the method returns a tuple (x0, intermediates); when set to `False`, it returns only x0.
1437
+ Returns:
1438
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1439
+
1440
+ """
1441
+ t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
1442
+ t_T = self.noise_schedule.T if t_start is None else t_start
1443
+ assert (
1444
+ t_0 > 0 and t_T > 0
1445
+ ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1446
+ if return_intermediate:
1447
+ assert method in [
1448
+ "multistep",
1449
+ "singlestep",
1450
+ "singlestep_fixed",
1451
+ ], "Cannot use adaptive solver when saving intermediate values"
1452
+ if self.correcting_xt_fn is not None:
1453
+ assert method in [
1454
+ "multistep",
1455
+ "singlestep",
1456
+ "singlestep_fixed",
1457
+ ], "Cannot use adaptive solver when correcting_xt_fn is not None"
1458
+ device = x.device
1459
+ intermediates = []
1460
+ with torch.no_grad():
1461
+ if method == "adaptive":
1462
+ x = self.dpm_solver_adaptive(
1463
+ x,
1464
+ order=order,
1465
+ t_T=t_T,
1466
+ t_0=t_0,
1467
+ atol=atol,
1468
+ rtol=rtol,
1469
+ solver_type=solver_type,
1470
+ )
1471
+ elif method == "multistep":
1472
+ assert steps >= order
1473
+ timesteps = self.get_time_steps(
1474
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
1475
+ )
1476
+ assert timesteps.shape[0] - 1 == steps
1477
+ # Init the initial values.
1478
+ step = 0
1479
+ t = timesteps[step]
1480
+ t_prev_list = [t]
1481
+ model_prev_list = [self.model_fn(x, t)]
1482
+ if self.correcting_xt_fn is not None:
1483
+ x = self.correcting_xt_fn(x, t, step)
1484
+ if return_intermediate:
1485
+ intermediates.append(x)
1486
+ # Init the first `order` values by lower order multistep DPM-Solver.
1487
+ for step in range(1, order):
1488
+ t = timesteps[step]
1489
+ x = self.multistep_dpm_solver_update(
1490
+ x,
1491
+ model_prev_list,
1492
+ t_prev_list,
1493
+ t,
1494
+ step,
1495
+ solver_type=solver_type,
1496
+ )
1497
+ if self.correcting_xt_fn is not None:
1498
+ x = self.correcting_xt_fn(x, t, step)
1499
+ if return_intermediate:
1500
+ intermediates.append(x)
1501
+ t_prev_list.append(t)
1502
+ model_prev_list.append(self.model_fn(x, t))
1503
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1504
+ for step in range(order, steps + 1):
1505
+ t = timesteps[step]
1506
+ # We only use lower order for steps < 10
1507
+ if lower_order_final and steps < 10:
1508
+ step_order = min(order, steps + 1 - step)
1509
+ else:
1510
+ step_order = order
1511
+ x = self.multistep_dpm_solver_update(
1512
+ x,
1513
+ model_prev_list,
1514
+ t_prev_list,
1515
+ t,
1516
+ step_order,
1517
+ solver_type=solver_type,
1518
+ )
1519
+ if self.correcting_xt_fn is not None:
1520
+ x = self.correcting_xt_fn(x, t, step)
1521
+ if return_intermediate:
1522
+ intermediates.append(x)
1523
+ for i in range(order - 1):
1524
+ t_prev_list[i] = t_prev_list[i + 1]
1525
+ model_prev_list[i] = model_prev_list[i + 1]
1526
+ t_prev_list[-1] = t
1527
+ # We do not need to evaluate the final model value.
1528
+ if step < steps:
1529
+ model_prev_list[-1] = self.model_fn(x, t)
1530
+ elif method in ["singlestep", "singlestep_fixed"]:
1531
+ if method == "singlestep":
1532
+ (
1533
+ timesteps_outer,
1534
+ orders,
1535
+ ) = self.get_orders_and_timesteps_for_singlestep_solver(
1536
+ steps=steps,
1537
+ order=order,
1538
+ skip_type=skip_type,
1539
+ t_T=t_T,
1540
+ t_0=t_0,
1541
+ device=device,
1542
+ )
1543
+ elif method == "singlestep_fixed":
1544
+ K = steps // order
1545
+ orders = [
1546
+ order,
1547
+ ] * K
1548
+ timesteps_outer = self.get_time_steps(
1549
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device
1550
+ )
1551
+ for step, order in enumerate(orders):
1552
+ s, t = timesteps_outer[step], timesteps_outer[step + 1]
1553
+ timesteps_inner = self.get_time_steps(
1554
+ skip_type=skip_type,
1555
+ t_T=s.item(),
1556
+ t_0=t.item(),
1557
+ N=order,
1558
+ device=device,
1559
+ )
1560
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1561
+ h = lambda_inner[-1] - lambda_inner[0]
1562
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1563
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1564
+ x = self.singlestep_dpm_solver_update(
1565
+ x, s, t, order, solver_type=solver_type, r1=r1, r2=r2
1566
+ )
1567
+ if self.correcting_xt_fn is not None:
1568
+ x = self.correcting_xt_fn(x, t, step)
1569
+ if return_intermediate:
1570
+ intermediates.append(x)
1571
+ else:
1572
+ raise ValueError("Got wrong method {}".format(method))
1573
+ if denoise_to_zero:
1574
+ t = torch.ones((1,)).to(device) * t_0
1575
+ x = self.denoise_to_zero_fn(x, t)
1576
+ if self.correcting_xt_fn is not None:
1577
+ x = self.correcting_xt_fn(x, t, step + 1)
1578
+ if return_intermediate:
1579
+ intermediates.append(x)
1580
+ if return_intermediate:
1581
+ return x, intermediates
1582
+ else:
1583
+ return x
1584
+
1585
+
1586
+ #############################################################
1587
+ # other utility functions
1588
+ #############################################################
1589
+
1590
+
1591
+ def interpolate_fn(x, xp, yp):
1592
+ """
1593
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1594
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1595
+ The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
1596
+
1597
+ Args:
1598
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1599
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1600
+ yp: PyTorch tensor with shape [C, K].
1601
+ Returns:
1602
+ The function values f(x), with shape [N, C].
1603
+ """
1604
+ N, K = x.shape[0], xp.shape[1]
1605
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1606
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1607
+ x_idx = torch.argmin(x_indices, dim=2)
1608
+ cand_start_idx = x_idx - 1
1609
+ start_idx = torch.where(
1610
+ torch.eq(x_idx, 0),
1611
+ torch.tensor(1, device=x.device),
1612
+ torch.where(
1613
+ torch.eq(x_idx, K),
1614
+ torch.tensor(K - 2, device=x.device),
1615
+ cand_start_idx,
1616
+ ),
1617
+ )
1618
+ end_idx = torch.where(
1619
+ torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
1620
+ )
1621
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1622
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1623
+ start_idx2 = torch.where(
1624
+ torch.eq(x_idx, 0),
1625
+ torch.tensor(0, device=x.device),
1626
+ torch.where(
1627
+ torch.eq(x_idx, K),
1628
+ torch.tensor(K - 2, device=x.device),
1629
+ cand_start_idx,
1630
+ ),
1631
+ )
1632
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1633
+ start_y = torch.gather(
1634
+ y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
1635
+ ).squeeze(2)
1636
+ end_y = torch.gather(
1637
+ y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
1638
+ ).squeeze(2)
1639
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1640
+ return cand
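# A quick worked example of interpolate_fn (an illustrative sketch): with
# keypoints y = 2x on [0, 1] (xp = torch.linspace(0, 1, 5).unsqueeze(0),
# yp = 2 * xp, both shaped [1, 5]) and queries x = torch.tensor([[0.3], [1.5]]):
#   - x = 0.3 falls between keypoints 0.25 and 0.5 and interpolates to y = 0.6;
#   - x = 1.5 lies beyond the last keypoint, so the outermost segment is
#     extrapolated, giving y = 3.0.
# interpolate_fn(x, xp, yp) therefore returns tensor([[0.6000], [3.0000]]).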
1641
+
1642
+
1643
+ def expand_dims(v, dims):
1644
+ """
1645
+ Expand the tensor `v` to `dims` dimensions.
1646
+
1647
+ Args:
1648
+ `v`: a PyTorch tensor with shape [N].
1649
+ `dims`: an `int`. The target total number of dimensions.
1650
+ Returns:
1651
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1652
+ """
1653
+ return v[(...,) + (None,) * (dims - 1)]
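A minimal usage sketch of the sampler defined above, following the docstring's own recommendation for guided sampling (multistep method, order 2, algorithm_type="dpmsolver++"). It assumes the `NoiseScheduleVP`, `model_wrapper`, and `DPM_Solver` helpers defined earlier in this file; `diffusion_model` and `betas` are placeholders for a trained noise-prediction network and its discrete beta schedule.

import torch

# Placeholder inputs: a trained epsilon-prediction model and its beta schedule.
# diffusion_model(x, t) -> predicted noise; betas has shape [N].
noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas)

# Wrap the raw network into the continuous-time model function DPM-Solver expects.
model_fn = model_wrapper(diffusion_model, noise_schedule, model_type="noise")

dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")

# Start from pure Gaussian noise at t = T and integrate down to t = 1/N.
x_T = torch.randn(4, 3, 64, 64)
x_0 = dpm_solver.sample(
    x_T,
    steps=20,
    order=2,
    skip_type="time_uniform",
    method="multistep",
)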
tortoise/get_conditioning_latents.py ADDED
@@ -0,0 +1,51 @@
1
+ # +
2
+ import argparse
3
+ import os
4
+
5
+ import torch
6
+ from api import TextToSpeech
7
+
8
+ from tortoise.utils.audio import get_voices, load_required_audio
9
+
10
+ """
11
+ Dumps the conditioning latents for the specified voice to disk. These are expressive latents which can be used for
12
+ other ML models, or can be augmented manually and fed back into Tortoise to affect vocal qualities.
13
+ """
14
+ if __name__ == "__main__":
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ "--voice",
18
+ type=str,
19
+ help="Selects the voice to convert to conditioning latents",
20
+ default="pat2",
21
+ )
22
+ parser.add_argument(
23
+ "--output_path",
24
+ type=str,
25
+ help="Where to store outputs.",
26
+ default="../results/conditioning_latents",
27
+ )
28
+ parser.add_argument(
29
+ "--latent_averaging_mode",
30
+ type=int,
31
+ help="How to average voice latents, 0 for standard, 1 for per-sample, 2 for per-minichunk",
32
+ default=0,
33
+ )
34
+
35
+ args = parser.parse_args()
36
+ os.makedirs(args.output_path, exist_ok=True)
37
+
38
+ tts = TextToSpeech()
39
+ voices = get_voices()
40
+ print(list(voices.keys()))
41
+ selected_voices = args.voice.split(",")
42
+ for voice in selected_voices:
43
+ cond_paths = voices[voice]
44
+ conds = []
45
+ for cond_path in cond_paths:
46
+ c = load_required_audio(cond_path)
47
+ conds.append(c)
48
+ conditioning_latents = tts.get_conditioning_latents(
49
+ conds, latent_averaging_mode=args.latent_averaging_mode
50
+ )
51
+ torch.save(conditioning_latents, os.path.join(args.output_path, f"{voice}.pth"))
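The docstring above notes that these dumped latents can be fed back into Tortoise; below is a minimal round-trip sketch. It assumes `TextToSpeech.tts_with_preset` forwards a `conditioning_latents` keyword to the synthesis call, as in the upstream API; the voice name and paths are placeholders.

import torch
import torchaudio
from api import TextToSpeech

tts = TextToSpeech()

# Load latents previously dumped by this script (placeholder path).
conditioning_latents = torch.load("../results/conditioning_latents/pat2.pth")

# Reuse the latents directly instead of re-encoding reference clips.
gen = tts.tts_with_preset(
    "Hello from saved conditioning latents.",
    conditioning_latents=conditioning_latents,
    preset="fast",
)
torchaudio.save("latent_roundtrip.wav", gen.squeeze(0).cpu(), 24000)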
tortoise/inference.py ADDED
@@ -0,0 +1,179 @@
1
+ import os
2
+ import sys
3
+ from random import randint
4
+ from typing import List, Optional, Set, Union
5
+
6
+ from tortoise.utils.audio import get_voices, load_audio, load_voices
7
+ from tortoise.utils.text import split_and_recombine_text
8
+
9
+
10
+ def get_all_voices(extra_voice_dirs_str: str = ""):
11
+ extra_voice_dirs = extra_voice_dirs_str.split(",") if extra_voice_dirs_str else []
12
+ return sorted(get_voices(extra_voice_dirs)), extra_voice_dirs
13
+
14
+
15
+ def parse_voice_str(voice_str: str, all_voices: List[str]):
16
+ selected_voices = all_voices if voice_str == "all" else voice_str.split(",")
17
+ selected_voices = [v.split("&") if "&" in v else [v] for v in selected_voices]
18
+ for voices in selected_voices:
19
+ for v in voices:
20
+ if v != "random" and v not in all_voices:
21
+ raise ValueError(
22
+ f"voice {v} not available, use --list-voices to see available voices."
23
+ )
24
+
25
+ return selected_voices
26
+
27
+
28
+ def voice_loader(selected_voices: list, extra_voice_dirs: List[str]):
29
+ for voices in selected_voices:
30
+ yield voices, *load_voices(voices, extra_voice_dirs)
31
+
32
+
33
+ def parse_multiarg_text(text: List[str]):
34
+ return (" ".join(text) if text else "".join(line for line in sys.stdin)).strip()
35
+
36
+
37
+ def split_text(text: str, text_split: str):
38
+ if text_split:
39
+ desired_length, max_length = map(int, text_split.split(","))
40
+ if desired_length > max_length:
41
+ raise ValueError(
42
+ f"--text-split: desired_length ({desired_length}) must be <= max_length ({max_length})"
43
+ )
44
+ texts = split_and_recombine_text(text, desired_length, max_length)
45
+ else:
46
+ texts = split_and_recombine_text(text)
47
+ #
48
+ if not texts:
49
+ raise ValueError("no text provided")
50
+ return texts
51
+
52
+
53
+ def validate_output_dir(output_dir: str, selected_voices: list, candidates: int):
54
+ if output_dir:
55
+ os.makedirs(output_dir, exist_ok=True)
56
+ else:
57
+ if len(selected_voices) > 1:
58
+ raise ValueError('cannot have multiple voices without --output-dir')
59
+ if candidates > 1:
60
+ raise ValueError('cannot have multiple candidates without --output-dir')
61
+ return output_dir
62
+
63
+
64
+ def check_pydub(play: bool):
65
+ if play:
66
+ try:
67
+ import pydub
68
+ import pydub.playback
69
+
70
+ return pydub
71
+ except ImportError:
72
+ raise RuntimeError(
73
+ '--play requires pydub to be installed, which can be done with "pip install pydub"'
74
+ )
75
+
76
+
77
+ def get_seed(seed: Optional[int]):
78
+ return randint(0, 2**32 - 1) if seed is None else seed
79
+
80
+
81
+ from pathlib import Path
82
+ from typing import Any, Callable
83
+
84
+ import torch
85
+ import torchaudio
86
+
87
+
88
+ def run_and_save_tts(
89
+ call_tts,
90
+ text,
91
+ output_dir: Path,
92
+ return_deterministic_state,
93
+ return_filepaths=False,
94
+ voicefixer=True,
95
+ ):
96
+ output_dir.mkdir(exist_ok=True)
97
+ if return_deterministic_state:
98
+ gen, dbg = call_tts(text)
99
+ torch.save(dbg, output_dir / "dbg.pt")
100
+ else:
101
+ gen = call_tts(text)
102
+ #
103
+ if not isinstance(gen, list):
104
+ gen = [gen]
105
+ gen = [g.squeeze(0).cpu() for g in gen]
106
+ fps = []
107
+ for i, g in enumerate(gen):
108
+ fps.append(output_dir / f"{i}.wav")
109
+ save_gen_with_voicefix(g, fps[-1], squeeze=False, voicefixer=voicefixer)
110
+ # torchaudio.save(output_dir/f'{i}.wav', g, 24000)
111
+ return fps if return_filepaths else gen
112
+
113
+
114
+ def infer_on_texts(
115
+ call_tts: Callable[[str], Any],
116
+ texts: List[str],
117
+ output_dir: Union[str, Path],
118
+ return_deterministic_state: bool,
119
+ lines_to_regen: Set[int],
120
+ logger=print,
121
+ return_filepaths=False,
122
+ voicefixer=True,
123
+ ):
124
+ audio_chunks = []
125
+ base_p = Path(output_dir)
126
+ base_p.mkdir(exist_ok=True)
127
+
128
+ for text_idx, text in enumerate(texts):
129
+ line_p = base_p / f"{text_idx}"
130
+ line_p.mkdir(exist_ok=True)
131
+ #
132
+ if text_idx not in lines_to_regen:
133
+ files = list(line_p.glob("*.wav"))
134
+ if files:
135
+ logger(f"loading existing audio fragments for [{text_idx}]")
136
+ audio_chunks.append([load_audio(str(f), 24000) for f in files])
137
+ continue
138
+ else:
139
+ logger(f"no existing audio fragment for [{text_idx}]")
140
+ #
141
+ logger(f"generating audio for text {text_idx}: {text}")
142
+ audio_chunks.append(
143
+ run_and_save_tts(
144
+ call_tts,
145
+ text,
146
+ line_p,
147
+ return_deterministic_state,
148
+ voicefixer=voicefixer,
149
+ )
150
+ )
151
+
152
+ fnames = []
153
+ results = []
154
+ for i in range(len(audio_chunks[0])):
155
+ resultant = torch.cat([c[i] for c in audio_chunks], dim=-1)
156
+ fnames.append(base_p / f"combined-{i}.wav")
157
+ save_gen_with_voicefix(
158
+ resultant, fnames[-1], squeeze=False, voicefixer=False
159
+ ) # do not run fix on combined!!
160
+ results.append(resultant)
161
+ # torchaudio.save(base_p/'combined.wav', resultant, 24000)
162
+ return fnames if return_filepaths else results
163
+
164
+
165
+ from voicefixer import VoiceFixer
166
+
167
+ vfixer = VoiceFixer()
168
+
169
+
170
+ def save_gen_with_voicefix(g, fpath, squeeze=True, voicefixer=True):
171
+ torchaudio.save(fpath, g.squeeze(0).cpu() if squeeze else g, 24000, format="wav")
172
+ if voicefixer:
173
+ vfixer.restore(
174
+ input=fpath,
175
+ output=fpath,
176
+ cuda=True,
177
+ mode=0,
178
+ # your_vocoder_func = convert_mel_to_wav # TODO test if integration with unvinet improves things
179
+ )
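To show how the helpers above fit together, here is a brief sketch that splits a long passage with `split_text` and renders each chunk through `infer_on_texts`. The `my_call_tts` function is a placeholder for a configured TextToSpeech call; here it returns one second of silence at 24 kHz so the flow can be exercised end to end.

import torch
from pathlib import Path

long_text = "First sentence of a long passage. Second sentence. Third sentence."
texts = split_text(long_text, "100,200")  # desired_length=100, max_length=200


def my_call_tts(text):
    # Placeholder for e.g. tts.tts_with_preset(text, voice_samples=..., preset="fast").
    # Shape (1, channels, samples); the leading dim is squeezed before saving.
    return torch.zeros(1, 1, 24000)


combined = infer_on_texts(
    my_call_tts,
    texts,
    output_dir=Path("example_out"),
    return_deterministic_state=False,
    lines_to_regen=set(range(len(texts))),  # regenerate every chunk
    voicefixer=False,  # skip VoiceFixer for this sketch
)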
tortoise/models/__init__.py ADDED
File without changes
tortoise/models/arch_util.py ADDED
@@ -0,0 +1,425 @@
1
+ import functools
2
+ import math
3
+ import os
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+
10
+ from tortoise.models.xtransformers import (
11
+ ContinuousTransformerWrapper,
12
+ RelativePositionBias,
13
+ )
14
+
15
+
16
+ def zero_module(module):
17
+ """
18
+ Zero out the parameters of a module and return it.
19
+ """
20
+ for p in module.parameters():
21
+ p.detach().zero_()
22
+ return module
23
+
24
+
25
+ class GroupNorm32(nn.GroupNorm):
26
+ def forward(self, x):
27
+ return super().forward(x.float()).type(x.dtype)
28
+
29
+
30
+ def normalization(channels):
31
+ """
32
+ Make a standard normalization layer.
33
+
34
+ :param channels: number of input channels.
35
+ :return: an nn.Module for normalization.
36
+ """
37
+ groups = 32
38
+ if channels <= 16:
39
+ groups = 8
40
+ elif channels <= 64:
41
+ groups = 16
42
+ while channels % groups != 0:
43
+ groups = int(groups / 2)
44
+ assert groups > 2
45
+ return GroupNorm32(groups, channels)
46
+
47
+
48
+ class QKVAttentionLegacy(nn.Module):
49
+ """
50
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
51
+ """
52
+
53
+ def __init__(self, n_heads):
54
+ super().__init__()
55
+ self.n_heads = n_heads
56
+
57
+ def forward(self, qkv, mask=None, rel_pos=None):
58
+ """
59
+ Apply QKV attention.
60
+
61
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
62
+ :return: an [N x (H * C) x T] tensor after attention.
63
+ """
64
+ bs, width, length = qkv.shape
65
+ assert width % (3 * self.n_heads) == 0
66
+ ch = width // (3 * self.n_heads)
67
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
68
+ scale = 1 / math.sqrt(math.sqrt(ch))
69
+ weight = torch.einsum(
70
+ "bct,bcs->bts", q * scale, k * scale
71
+ ) # More stable with f16 than dividing afterwards
72
+ if rel_pos is not None:
73
+ weight = rel_pos(
74
+ weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])
75
+ ).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
76
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
77
+ if mask is not None:
78
+ # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
79
+ mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
80
+ weight = weight * mask
81
+ a = torch.einsum("bts,bcs->bct", weight, v)
82
+
83
+ return a.reshape(bs, -1, length)
84
+
85
+
86
+ class AttentionBlock(nn.Module):
87
+ """
88
+ An attention block that allows spatial positions to attend to each other.
89
+
90
+ Originally ported from here, but adapted to the N-d case.
91
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ channels,
97
+ num_heads=1,
98
+ num_head_channels=-1,
99
+ do_checkpoint=True,
100
+ relative_pos_embeddings=False,
101
+ ):
102
+ super().__init__()
103
+ self.channels = channels
104
+ self.do_checkpoint = do_checkpoint
105
+ if num_head_channels == -1:
106
+ self.num_heads = num_heads
107
+ else:
108
+ assert (
109
+ channels % num_head_channels == 0
110
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
111
+ self.num_heads = channels // num_head_channels
112
+ self.norm = normalization(channels)
113
+ self.qkv = nn.Conv1d(channels, channels * 3, 1)
114
+ # split heads before split qkv
115
+ self.attention = QKVAttentionLegacy(self.num_heads)
116
+
117
+ self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
118
+ if relative_pos_embeddings:
119
+ self.relative_pos_embeddings = RelativePositionBias(
120
+ scale=(channels // self.num_heads) ** 0.5,
121
+ causal=False,
122
+ heads=num_heads,
123
+ num_buckets=32,
124
+ max_distance=64,
125
+ )
126
+ else:
127
+ self.relative_pos_embeddings = None
128
+
129
+ def forward(self, x, mask=None):
130
+ b, c, *spatial = x.shape
131
+ x = x.reshape(b, c, -1)
132
+ qkv = self.qkv(self.norm(x))
133
+ h = self.attention(qkv, mask, self.relative_pos_embeddings)
134
+ h = self.proj_out(h)
135
+ return (x + h).reshape(b, c, *spatial)
136
+
137
+
138
+ class Upsample(nn.Module):
139
+ """
140
+ An upsampling layer with an optional convolution.
141
+
142
+ :param channels: channels in the inputs and outputs.
143
+ :param use_conv: a bool determining if a convolution is applied.
144
+ """
145
+
146
+ def __init__(self, channels, use_conv, out_channels=None, factor=4):
147
+ super().__init__()
148
+ self.channels = channels
149
+ self.out_channels = out_channels or channels
150
+ self.use_conv = use_conv
151
+ self.factor = factor
152
+ if use_conv:
153
+ ksize = 5
154
+ pad = 2
155
+ self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)
156
+
157
+ def forward(self, x):
158
+ assert x.shape[1] == self.channels
159
+ x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
160
+ if self.use_conv:
161
+ x = self.conv(x)
162
+ return x
163
+
164
+
165
+ class Downsample(nn.Module):
166
+ """
167
+ A downsampling layer with an optional convolution.
168
+
169
+ :param channels: channels in the inputs and outputs.
170
+ :param use_conv: a bool determining if a convolution is applied.
171
+ """
172
+
173
+ def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
174
+ super().__init__()
175
+ self.channels = channels
176
+ self.out_channels = out_channels or channels
177
+ self.use_conv = use_conv
178
+
179
+ stride = factor
180
+ if use_conv:
181
+ self.op = nn.Conv1d(
182
+ self.channels, self.out_channels, ksize, stride=stride, padding=pad
183
+ )
184
+ else:
185
+ assert self.channels == self.out_channels
186
+ self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
187
+
188
+ def forward(self, x):
189
+ assert x.shape[1] == self.channels
190
+ return self.op(x)
191
+
192
+
193
+ class ResBlock(nn.Module):
194
+ def __init__(
195
+ self,
196
+ channels,
197
+ dropout,
198
+ out_channels=None,
199
+ use_conv=False,
200
+ use_scale_shift_norm=False,
201
+ up=False,
202
+ down=False,
203
+ kernel_size=3,
204
+ ):
205
+ super().__init__()
206
+ self.channels = channels
207
+ self.dropout = dropout
208
+ self.out_channels = out_channels or channels
209
+ self.use_conv = use_conv
210
+ self.use_scale_shift_norm = use_scale_shift_norm
211
+ padding = 1 if kernel_size == 3 else 2
212
+
213
+ self.in_layers = nn.Sequential(
214
+ normalization(channels),
215
+ nn.SiLU(),
216
+ nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
217
+ )
218
+
219
+ self.updown = up or down
220
+
221
+ if up:
222
+ self.h_upd = Upsample(channels, False)
223
+ self.x_upd = Upsample(channels, False)
224
+ elif down:
225
+ self.h_upd = Downsample(channels, False)
226
+ self.x_upd = Downsample(channels, False)
227
+ else:
228
+ self.h_upd = self.x_upd = nn.Identity()
229
+
230
+ self.out_layers = nn.Sequential(
231
+ normalization(self.out_channels),
232
+ nn.SiLU(),
233
+ nn.Dropout(p=dropout),
234
+ zero_module(
235
+ nn.Conv1d(
236
+ self.out_channels, self.out_channels, kernel_size, padding=padding
237
+ )
238
+ ),
239
+ )
240
+
241
+ if self.out_channels == channels:
242
+ self.skip_connection = nn.Identity()
243
+ elif use_conv:
244
+ self.skip_connection = nn.Conv1d(
245
+ channels, self.out_channels, kernel_size, padding=padding
246
+ )
247
+ else:
248
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
249
+
250
+ def forward(self, x):
251
+ if self.updown:
252
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
253
+ h = in_rest(x)
254
+ h = self.h_upd(h)
255
+ x = self.x_upd(x)
256
+ h = in_conv(h)
257
+ else:
258
+ h = self.in_layers(x)
259
+ h = self.out_layers(h)
260
+ return self.skip_connection(x) + h
261
+
262
+
263
+ class AudioMiniEncoder(nn.Module):
264
+ def __init__(
265
+ self,
266
+ spec_dim,
267
+ embedding_dim,
268
+ base_channels=128,
269
+ depth=2,
270
+ resnet_blocks=2,
271
+ attn_blocks=4,
272
+ num_attn_heads=4,
273
+ dropout=0,
274
+ downsample_factor=2,
275
+ kernel_size=3,
276
+ ):
277
+ super().__init__()
278
+ self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
279
+ ch = base_channels
280
+ res = []
281
+ for l in range(depth):
282
+ for r in range(resnet_blocks):
283
+ res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
284
+ res.append(
285
+ Downsample(
286
+ ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor
287
+ )
288
+ )
289
+ ch *= 2
290
+ self.res = nn.Sequential(*res)
291
+ self.final = nn.Sequential(
292
+ normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)
293
+ )
294
+ attn = []
295
+ for a in range(attn_blocks):
296
+ attn.append(
297
+ AttentionBlock(
298
+ embedding_dim,
299
+ num_attn_heads,
300
+ )
301
+ )
302
+ self.attn = nn.Sequential(*attn)
303
+ self.dim = embedding_dim
304
+
305
+ def forward(self, x):
306
+ h = self.init(x)
307
+ h = self.res(h)
308
+ h = self.final(h)
309
+ h = self.attn(h)
310
+ return h[:, :, 0]
311
+
312
+
313
+ DEFAULT_MEL_NORM_FILE = os.path.join(
314
+ os.path.dirname(os.path.realpath(__file__)), "../data/mel_norms.pth"
315
+ )
316
+
317
+
318
+ class TorchMelSpectrogram(nn.Module):
319
+ def __init__(
320
+ self,
321
+ filter_length=1024,
322
+ hop_length=256,
323
+ win_length=1024,
324
+ n_mel_channels=80,
325
+ mel_fmin=0,
326
+ mel_fmax=8000,
327
+ sampling_rate=22050,
328
+ normalize=False,
329
+ mel_norm_file=DEFAULT_MEL_NORM_FILE,
330
+ ):
331
+ super().__init__()
332
+ # These are the default tacotron values for the MEL spectrogram.
333
+ self.filter_length = filter_length
334
+ self.hop_length = hop_length
335
+ self.win_length = win_length
336
+ self.n_mel_channels = n_mel_channels
337
+ self.mel_fmin = mel_fmin
338
+ self.mel_fmax = mel_fmax
339
+ self.sampling_rate = sampling_rate
340
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
341
+ n_fft=self.filter_length,
342
+ hop_length=self.hop_length,
343
+ win_length=self.win_length,
344
+ power=2,
345
+ normalized=normalize,
346
+ sample_rate=self.sampling_rate,
347
+ f_min=self.mel_fmin,
348
+ f_max=self.mel_fmax,
349
+ n_mels=self.n_mel_channels,
350
+ norm="slaney",
351
+ )
352
+ self.mel_norm_file = mel_norm_file
353
+ if self.mel_norm_file is not None:
354
+ self.mel_norms = torch.load(self.mel_norm_file)
355
+ else:
356
+ self.mel_norms = None
357
+
358
+ def forward(self, inp):
359
+ if (
360
+ len(inp.shape) == 3
361
+ ): # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
362
+ inp = inp.squeeze(1)
363
+ assert len(inp.shape) == 2
364
+ self.mel_stft = self.mel_stft.to(inp.device)
365
+ mel = self.mel_stft(inp)
366
+ # Perform dynamic range compression
367
+ mel = torch.log(torch.clamp(mel, min=1e-5))
368
+ if self.mel_norms is not None:
369
+ self.mel_norms = self.mel_norms.to(mel.device)
370
+ mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
371
+ return mel
372
+
373
+
374
+ class CheckpointedLayer(nn.Module):
375
+ """
376
+ Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
377
+ checkpoint for all other args.
378
+ """
379
+
380
+ def __init__(self, wrap):
381
+ super().__init__()
382
+ self.wrap = wrap
383
+
384
+ def forward(self, x, *args, **kwargs):
385
+ for k, v in kwargs.items():
386
+ assert not (
387
+ isinstance(v, torch.Tensor) and v.requires_grad
388
+ ) # This would screw up checkpointing.
389
+ partial = functools.partial(self.wrap, **kwargs)
390
+ return partial(x, *args)
391
+
392
+
393
+ class CheckpointedXTransformerEncoder(nn.Module):
394
+ """
395
+ Wraps a ContinuousTransformerWrapper, applies CheckpointedLayer to each layer, and permutes from the channels-mid
396
+ layout to the channels-last layout that XTransformer expects.
397
+ """
398
+
399
+ def __init__(
400
+ self,
401
+ needs_permute=True,
402
+ exit_permute=True,
403
+ checkpoint=True,
404
+ **xtransformer_kwargs,
405
+ ):
406
+ super().__init__()
407
+ self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
408
+ self.needs_permute = needs_permute
409
+ self.exit_permute = exit_permute
410
+
411
+ if not checkpoint:
412
+ return
413
+ for i in range(len(self.transformer.attn_layers.layers)):
414
+ n, b, r = self.transformer.attn_layers.layers[i]
415
+ self.transformer.attn_layers.layers[i] = nn.ModuleList(
416
+ [n, CheckpointedLayer(b), r]
417
+ )
418
+
419
+ def forward(self, x, **kwargs):
420
+ if self.needs_permute:
421
+ x = x.permute(0, 2, 1)
422
+ h = self.transformer(x, **kwargs)
423
+ if self.exit_permute:
424
+ h = h.permute(0, 2, 1)
425
+ return h
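A small shape-check sketch for two of the modules above: `TorchMelSpectrogram` turns a batch of waveforms into 80-bin mel spectrograms (passing `mel_norm_file=None` here so no normalization file is required), and `AudioMiniEncoder` collapses each spectrogram into a fixed-size embedding. The shapes below are illustrative.

import torch

# Batch of 2 mono waveforms, ~1 s at 22.05 kHz.
wav = torch.randn(2, 1, 22050)

mel_extractor = TorchMelSpectrogram(mel_norm_file=None)  # skip the mel_norms file
mel = mel_extractor(wav)  # -> (2, 80, frames)

encoder = AudioMiniEncoder(spec_dim=80, embedding_dim=256)
emb = encoder(mel)  # -> (2, 256): the first position after the attention stack
print(mel.shape, emb.shape)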
tortoise/models/autoregressive.py ADDED
@@ -0,0 +1,810 @@
1
+ # AGPL: a notification must be added stating that changes have been made to that file.
2
+ import functools
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
8
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
9
+
10
+ from tortoise.models.arch_util import AttentionBlock
11
+ from tortoise.utils.typical_sampling import TypicalLogitsWarper
12
+
13
+
14
+ def null_position_embeddings(range, dim):
15
+ return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
16
+
17
+
18
+ def _p(t):
19
+ return t and (len(t), len(t[0]), t[0][0].shape) # kv_cache debug
20
+
21
+
22
+ class ResBlock(nn.Module):
23
+ """
24
+ Basic residual convolutional block that uses GroupNorm.
25
+ """
26
+
27
+ def __init__(self, chan):
28
+ super().__init__()
29
+ self.net = nn.Sequential(
30
+ nn.Conv1d(chan, chan, kernel_size=3, padding=1),
31
+ nn.GroupNorm(chan // 8, chan),
32
+ nn.ReLU(),
33
+ nn.Conv1d(chan, chan, kernel_size=3, padding=1),
34
+ nn.GroupNorm(chan // 8, chan),
35
+ )
36
+
37
+ def forward(self, x):
38
+ return F.relu(self.net(x) + x)
39
+
40
+
41
+ class GPT2InferenceModel(GPT2PreTrainedModel):
42
+ def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache):
43
+ super().__init__(config)
44
+ self.transformer = gpt
45
+ self.text_pos_embedding = text_pos_emb
46
+ self.embeddings = embeddings
47
+ self.lm_head = nn.Sequential(norm, linear)
48
+ self.kv_cache = kv_cache
49
+
50
+ def store_mel_emb(self, mel_emb):
51
+ self.cached_mel_emb = mel_emb
52
+
53
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
54
+ token_type_ids = kwargs.get("token_type_ids", None) # usually None
55
+ if not self.kv_cache:
56
+ past_key_values = None
57
+ # only last token for inputs_ids if past is defined in kwargs
58
+ if past_key_values:
59
+ input_ids = input_ids[:, -1].unsqueeze(-1)
60
+ if token_type_ids is not None:
61
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
62
+
63
+ attention_mask = kwargs.get("attention_mask", None)
64
+ position_ids = kwargs.get("position_ids", None)
65
+
66
+ if attention_mask is not None and position_ids is None:
67
+ # create position_ids on the fly for batch generation
68
+ position_ids = attention_mask.long().cumsum(-1) - 1
69
+ position_ids.masked_fill_(attention_mask == 0, 1)
70
+ if past_key_values:
71
+ position_ids = position_ids[:, -1].unsqueeze(-1)
72
+ else:
73
+ position_ids = None
74
+ return {
75
+ "input_ids": input_ids,
76
+ "past_key_values": past_key_values,
77
+ "use_cache": kwargs.get("use_cache"),
78
+ "position_ids": position_ids,
79
+ "attention_mask": attention_mask,
80
+ "token_type_ids": token_type_ids,
81
+ }
82
+
83
+ def forward(
84
+ self,
85
+ input_ids=None,
86
+ past_key_values=None,
87
+ attention_mask=None,
88
+ token_type_ids=None,
89
+ position_ids=None,
90
+ head_mask=None,
91
+ inputs_embeds=None,
92
+ encoder_hidden_states=None,
93
+ encoder_attention_mask=None,
94
+ labels=None,
95
+ use_cache=None,
96
+ output_attentions=None,
97
+ output_hidden_states=None,
98
+ return_dict=None,
99
+ ):
100
+ assert self.cached_mel_emb is not None
101
+ assert inputs_embeds is None # Not supported by this inference model.
102
+ assert labels is None # Training not supported by this inference model.
103
+ return_dict = (
104
+ return_dict if return_dict is not None else self.config.use_return_dict
105
+ )
106
+
107
+ # Create embedding
108
+ mel_len = self.cached_mel_emb.shape[1]
109
+ if input_ids.shape[1] != 1:
110
+ text_inputs = input_ids[:, mel_len:]
111
+ text_emb = self.embeddings(text_inputs)
112
+ text_emb = text_emb + self.text_pos_embedding(text_emb)
113
+ if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
114
+ mel_emb = self.cached_mel_emb.repeat_interleave(
115
+ text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
116
+ )
117
+ else: # this outcome only occurs once per loop in most cases
118
+ mel_emb = self.cached_mel_emb
119
+ emb = torch.cat([mel_emb, text_emb], dim=1)
120
+ else:
121
+ emb = self.embeddings(input_ids)
122
+ emb = emb + self.text_pos_embedding.get_fixed_embedding(
123
+ attention_mask.shape[1] - mel_len, attention_mask.device
124
+ )
125
+
126
+ transformer_outputs = self.transformer(
127
+ inputs_embeds=emb,
128
+ past_key_values=past_key_values,
129
+ attention_mask=attention_mask,
130
+ token_type_ids=token_type_ids,
131
+ position_ids=position_ids,
132
+ head_mask=head_mask,
133
+ encoder_hidden_states=encoder_hidden_states,
134
+ encoder_attention_mask=encoder_attention_mask,
135
+ use_cache=use_cache,
136
+ output_attentions=output_attentions,
137
+ output_hidden_states=output_hidden_states,
138
+ return_dict=return_dict,
139
+ )
140
+ hidden_states = transformer_outputs[0]
141
+ lm_logits = self.lm_head(hidden_states)
142
+
143
+ if not return_dict:
144
+ return (lm_logits,) + transformer_outputs[1:]
145
+
146
+ return CausalLMOutputWithCrossAttentions(
147
+ loss=None,
148
+ logits=lm_logits,
149
+ past_key_values=transformer_outputs.past_key_values,
150
+ hidden_states=transformer_outputs.hidden_states,
151
+ attentions=transformer_outputs.attentions,
152
+ cross_attentions=transformer_outputs.cross_attentions,
153
+ )
154
+
155
+ @staticmethod
156
+ def _reorder_cache(past, beam_idx):
157
+ """
158
+ This function is used to re-order the :obj:`past_key_values` cache if
159
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
160
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
161
+ """
162
+ return tuple(
163
+ tuple(
164
+ past_state.index_select(0, beam_idx.to(past_state.device))
165
+ for past_state in layer_past
166
+ )
167
+ for layer_past in past
168
+ )
169
+
170
+
171
+ class ConditioningEncoder(nn.Module):
172
+ def __init__(
173
+ self,
174
+ spec_dim,
175
+ embedding_dim,
176
+ attn_blocks=6,
177
+ num_attn_heads=4,
178
+ do_checkpointing=False,
179
+ mean=False,
180
+ ):
181
+ super().__init__()
182
+ attn = []
183
+ self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
184
+ for a in range(attn_blocks):
185
+ attn.append(AttentionBlock(embedding_dim, num_attn_heads))
186
+ self.attn = nn.Sequential(*attn)
187
+ self.dim = embedding_dim
188
+ self.do_checkpointing = do_checkpointing
189
+ self.mean = mean
190
+
191
+ def forward(self, x):
192
+ h = self.init(x)
193
+ h = self.attn(h)
194
+ if self.mean:
195
+ return h.mean(dim=2)
196
+ else:
197
+ return h[:, :, 0]
198
+
199
+
200
+ class LearnedPositionEmbeddings(nn.Module):
201
+ def __init__(self, seq_len, model_dim, init=0.02):
202
+ super().__init__()
203
+ self.emb = nn.Embedding(seq_len, model_dim)
204
+ # Initializing this way is standard for GPT-2
205
+ self.emb.weight.data.normal_(mean=0.0, std=init)
206
+
207
+ def forward(self, x):
208
+ sl = x.shape[1]
209
+ return self.emb(torch.arange(0, sl, device=x.device))
210
+
211
+ def get_fixed_embedding(self, ind, dev):
212
+ return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]
213
+
214
+
215
+ def build_hf_gpt_transformer(
216
+ layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing
217
+ ):
218
+ """
219
+ GPT-2 implemented by the HuggingFace library.
220
+ """
221
+ from transformers import GPT2Config, GPT2Model
222
+
223
+ gpt_config = GPT2Config(
224
+ vocab_size=256, # Unused.
225
+ n_positions=max_mel_seq_len + max_text_seq_len,
226
+ n_ctx=max_mel_seq_len + max_text_seq_len,
227
+ n_embd=model_dim,
228
+ n_layer=layers,
229
+ n_head=heads,
230
+ gradient_checkpointing=checkpointing,
231
+ use_cache=not checkpointing,
232
+ )
233
+ gpt = GPT2Model(gpt_config)
234
+ # Override the built in positional embeddings
235
+ del (
236
+ gpt.wpe
237
+ ) # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024)
238
+ gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
239
+ # Built-in token embeddings are unused.
240
+ del gpt.wte
241
+ return (
242
+ gpt,
243
+ LearnedPositionEmbeddings(max_mel_seq_len, model_dim),
244
+ LearnedPositionEmbeddings(max_text_seq_len, model_dim),
245
+ None,
246
+ None,
247
+ )
248
+
249
+
250
+ class MelEncoder(nn.Module):
251
+ def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
252
+ super().__init__()
253
+ self.channels = channels
254
+ self.encoder = nn.Sequential(
255
+ nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
256
+ nn.Sequential(
257
+ *[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]
258
+ ),
259
+ nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
260
+ nn.GroupNorm(channels // 16, channels // 2),
261
+ nn.ReLU(),
262
+ nn.Sequential(
263
+ *[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]
264
+ ),
265
+ nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
266
+ nn.GroupNorm(channels // 8, channels),
267
+ nn.ReLU(),
268
+ nn.Sequential(
269
+ *[ResBlock(channels) for _ in range(resblocks_per_reduction)]
270
+ ),
271
+ )
272
+ self.reduction = 4
273
+
274
+ def forward(self, x):
275
+ for e in self.encoder:
276
+ x = e(x)
277
+ return x.permute(0, 2, 1)
278
+
279
+
280
+ class UnifiedVoice(nn.Module):
281
+ def __init__(
282
+ self,
283
+ layers=8,
284
+ model_dim=512,
285
+ heads=8,
286
+ max_text_tokens=120,
287
+ max_mel_tokens=250,
288
+ max_conditioning_inputs=1,
289
+ mel_length_compression=1024,
290
+ number_text_tokens=256,
291
+ start_text_token=None,
292
+ number_mel_codes=8194,
293
+ start_mel_token=8192,
294
+ stop_mel_token=8193,
295
+ train_solo_embeddings=False,
296
+ use_mel_codes_as_input=True,
297
+ checkpointing=True,
298
+ types=1,
299
+ ):
300
+ """
301
+ Args:
302
+ layers: Number of layers in transformer stack.
303
+ model_dim: Operating dimensions of the transformer
304
+ heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
305
+ max_text_tokens: Maximum number of text tokens that will be encountered by model.
306
+ max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
307
+ max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
308
+ mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
309
+ number_text_tokens:
310
+ start_text_token:
311
+ stop_text_token:
312
+ number_mel_codes:
313
+ start_mel_token:
314
+ stop_mel_token:
315
+ train_solo_embeddings:
316
+ use_mel_codes_as_input:
317
+ checkpointing:
318
+ """
319
+ super().__init__()
320
+
321
+ self.number_text_tokens = number_text_tokens
322
+ self.start_text_token = (
323
+ number_text_tokens * types if start_text_token is None else start_text_token
324
+ )
325
+ self.stop_text_token = 0
326
+ self.number_mel_codes = number_mel_codes
327
+ self.start_mel_token = start_mel_token
328
+ self.stop_mel_token = stop_mel_token
329
+ self.layers = layers
330
+ self.heads = heads
331
+ self.max_mel_tokens = max_mel_tokens
332
+ self.max_text_tokens = max_text_tokens
333
+ self.model_dim = model_dim
334
+ self.max_conditioning_inputs = max_conditioning_inputs
335
+ self.mel_length_compression = mel_length_compression
336
+ self.conditioning_encoder = ConditioningEncoder(
337
+ 80, model_dim, num_attn_heads=heads
338
+ )
339
+ self.text_embedding = nn.Embedding(
340
+ self.number_text_tokens * types + 1, model_dim
341
+ )
342
+ if use_mel_codes_as_input:
343
+ self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
344
+ else:
345
+ self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
346
+ (
347
+ self.gpt,
348
+ self.mel_pos_embedding,
349
+ self.text_pos_embedding,
350
+ self.mel_layer_pos_embedding,
351
+ self.text_layer_pos_embedding,
352
+ ) = build_hf_gpt_transformer(
353
+ layers,
354
+ model_dim,
355
+ heads,
356
+ self.max_mel_tokens + 2 + self.max_conditioning_inputs,
357
+ self.max_text_tokens + 2,
358
+ checkpointing,
359
+ )
360
+ if train_solo_embeddings:
361
+ self.mel_solo_embedding = nn.Parameter(
362
+ torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
363
+ )
364
+ self.text_solo_embedding = nn.Parameter(
365
+ torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
366
+ )
367
+ else:
368
+ self.mel_solo_embedding = 0
369
+ self.text_solo_embedding = 0
370
+
371
+ self.final_norm = nn.LayerNorm(model_dim)
372
+ self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
373
+ self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
374
+
375
+ # Initialize the embeddings per the GPT-2 scheme
376
+ embeddings = [self.text_embedding]
377
+ if use_mel_codes_as_input:
378
+ embeddings.append(self.mel_embedding)
379
+ for module in embeddings:
380
+ module.weight.data.normal_(mean=0.0, std=0.02)
381
+
382
+ def post_init_gpt2_config(self, kv_cache=True):
383
+ seq_length = self.max_mel_tokens + self.max_text_tokens + 2
384
+ gpt_config = GPT2Config(
385
+ vocab_size=self.max_mel_tokens,
386
+ n_positions=seq_length,
387
+ n_ctx=seq_length,
388
+ n_embd=self.model_dim,
389
+ n_layer=self.layers,
390
+ n_head=self.heads,
391
+ gradient_checkpointing=False,
392
+ use_cache=True,
393
+ )
394
+ self.inference_model = GPT2InferenceModel(
395
+ gpt_config,
396
+ self.gpt,
397
+ self.mel_pos_embedding,
398
+ self.mel_embedding,
399
+ self.final_norm,
400
+ self.mel_head,
401
+ kv_cache=kv_cache,
402
+ )
403
+ # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
404
+ self.gpt.wte = self.mel_embedding
405
+
406
+ def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
407
+ inp = F.pad(input, (1, 0), value=start_token)
408
+ tar = F.pad(input, (0, 1), value=stop_token)
409
+ return inp, tar
410
+
411
+ def set_mel_padding(self, mel_input_tokens, wav_lengths):
412
+ """
413
+ Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
414
+ that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
415
+ preformatting to create a working TTS model.
416
+ """
417
+ # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
418
+ mel_lengths = torch.div(
419
+ wav_lengths, self.mel_length_compression, rounding_mode="trunc"
420
+ )
421
+ for b in range(len(mel_lengths)):
422
+ actual_end = (
423
+ mel_lengths[b] + 1
424
+ ) # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
425
+ if actual_end < mel_input_tokens.shape[-1]:
426
+ mel_input_tokens[b, actual_end:] = self.stop_mel_token
427
+ return mel_input_tokens
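# A small worked example of set_mel_padding (illustrative numbers): with
# mel_length_compression = 1024 and wav_lengths = [4096, 8192], mel_lengths
# becomes [4, 8], so actual_end is [5, 9]. For a mel_input_tokens tensor of
# width 10, rows are overwritten with stop_mel_token from positions 5 and 9
# onward respectively, marking the padded tail of each clip.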
428
+
429
+ def get_logits(
430
+ self,
431
+ speech_conditioning_inputs,
432
+ first_inputs,
433
+ first_head,
434
+ second_inputs=None,
435
+ second_head=None,
436
+ get_attns=False,
437
+ return_latent=False,
438
+ ):
439
+ if second_inputs is not None:
440
+ emb = torch.cat(
441
+ [speech_conditioning_inputs, first_inputs, second_inputs], dim=1
442
+ )
443
+ else:
444
+ emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
445
+
446
+ gpt_out = self.gpt(
447
+ inputs_embeds=emb, return_dict=True, output_attentions=get_attns
448
+ )
449
+ if get_attns:
450
+ return gpt_out.attentions
451
+
452
+ enc = gpt_out.last_hidden_state[
453
+ :, 1:
454
+ ] # The first logit is tied to the speech_conditioning_input
455
+ enc = self.final_norm(enc)
456
+
457
+ if return_latent:
458
+ return (
459
+ enc[
460
+ :,
461
+ speech_conditioning_inputs.shape[
462
+ 1
463
+ ] : speech_conditioning_inputs.shape[1]
464
+ + first_inputs.shape[1],
465
+ ],
466
+ enc[:, -second_inputs.shape[1] :],
467
+ )
468
+
469
+ first_logits = enc[:, : first_inputs.shape[1]]
470
+ first_logits = first_head(first_logits)
471
+ first_logits = first_logits.permute(0, 2, 1)
472
+ if second_inputs is not None:
473
+ second_logits = enc[:, -second_inputs.shape[1] :]
474
+ second_logits = second_head(second_logits)
475
+ second_logits = second_logits.permute(0, 2, 1)
476
+ return first_logits, second_logits
477
+ else:
478
+ return first_logits
479
+
480
+ def get_conditioning(self, speech_conditioning_input):
481
+ speech_conditioning_input = (
482
+ speech_conditioning_input.unsqueeze(1)
483
+ if len(speech_conditioning_input.shape) == 3
484
+ else speech_conditioning_input
485
+ )
486
+ conds = []
487
+ for j in range(speech_conditioning_input.shape[1]):
488
+ conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
489
+ conds = torch.stack(conds, dim=1)
490
+ conds = conds.mean(dim=1)
491
+ return conds
492
+
493
+ def forward(
494
+ self,
495
+ speech_conditioning_latent,
496
+ text_inputs,
497
+ text_lengths,
498
+ mel_codes,
499
+ wav_lengths,
500
+ types=None,
501
+ text_first=True,
502
+ raw_mels=None,
503
+ return_attentions=False,
504
+ return_latent=False,
505
+ clip_inputs=True,
506
+ ):
507
+ """
508
+ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
509
+ (actuated by `text_first`).
510
+
511
+ speech_conditioning_input: MEL float tensor, (b,1024)
512
+ text_inputs: long tensor, (b,t)
513
+ text_lengths: long tensor, (b,)
514
+ mel_inputs: long tensor, (b,m)
515
+ wav_lengths: long tensor, (b,)
516
+ raw_mels: MEL float tensor (b,80,s)
517
+
518
+ If return_attentions is specified, only logits are returned.
519
+ If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
520
+ If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
521
+ """
522
+ # Types are expressed by expanding the text embedding space.
523
+ if types is not None:
524
+ text_inputs = text_inputs * (1 + types).unsqueeze(-1)
525
+
526
+ if clip_inputs:
527
+ # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
528
+ # chopping the inputs by the maximum actual length.
529
+ max_text_len = text_lengths.max()
530
+ text_inputs = text_inputs[:, :max_text_len]
531
+ max_mel_len = wav_lengths.max() // self.mel_length_compression
532
+ mel_codes = mel_codes[:, :max_mel_len]
533
+ if raw_mels is not None:
534
+ raw_mels = raw_mels[:, :, : max_mel_len * 4]
535
+ mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
536
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
537
+ mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
538
+
539
+ conds = speech_conditioning_latent.unsqueeze(1)
540
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(
541
+ text_inputs, self.start_text_token, self.stop_text_token
542
+ )
543
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
544
+ text_inputs
545
+ )
546
+ mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
547
+ mel_codes, self.start_mel_token, self.stop_mel_token
548
+ )
549
+ if raw_mels is not None:
550
+ mel_inp = F.pad(raw_mels, (0, 8))
551
+ else:
552
+ mel_inp = mel_codes
553
+ mel_emb = self.mel_embedding(mel_inp)
554
+ mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
555
+
556
+ if text_first:
557
+ text_logits, mel_logits = self.get_logits(
558
+ conds,
559
+ text_emb,
560
+ self.text_head,
561
+ mel_emb,
562
+ self.mel_head,
563
+ get_attns=return_attentions,
564
+ return_latent=return_latent,
565
+ )
566
+ if return_latent:
567
+ return mel_logits[
568
+ :, :-2
569
+ ] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
570
+ else:
571
+ mel_logits, text_logits = self.get_logits(
572
+ conds,
573
+ mel_emb,
574
+ self.mel_head,
575
+ text_emb,
576
+ self.text_head,
577
+ get_attns=return_attentions,
578
+ return_latent=return_latent,
579
+ )
580
+ if return_latent:
581
+ return text_logits[
582
+ :, :-2
583
+ ] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
584
+
585
+ if return_attentions:
586
+ return mel_logits
587
+ loss_text = F.cross_entropy(text_logits, text_targets.long())
588
+ loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
589
+ return loss_text.mean(), loss_mel.mean(), mel_logits
590
+
591
+ def inference_speech(
592
+ self,
593
+ speech_conditioning_latent,
594
+ text_inputs,
595
+ input_tokens=None,
596
+ num_return_sequences=1,
597
+ max_generate_length=None,
598
+ typical_sampling=False,
599
+ typical_mass=0.9,
600
+ **hf_generate_kwargs
601
+ ):
602
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
603
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(
604
+ text_inputs, self.start_text_token, self.stop_text_token
605
+ )
606
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
607
+ text_inputs
608
+ )
609
+
610
+ conds = speech_conditioning_latent.unsqueeze(1)
611
+ emb = torch.cat([conds, text_emb], dim=1)
612
+ self.inference_model.store_mel_emb(emb)
613
+
614
+ fake_inputs = torch.full(
615
+ (
616
+ emb.shape[0],
617
+ conds.shape[1] + emb.shape[1],
618
+ ),
619
+ fill_value=1,
620
+ dtype=torch.long,
621
+ device=text_inputs.device,
622
+ )
623
+ fake_inputs[:, -1] = self.start_mel_token
624
+ trunc_index = fake_inputs.shape[1]
625
+ if input_tokens is None:
626
+ inputs = fake_inputs
627
+ else:
628
+ assert (
629
+ num_return_sequences % input_tokens.shape[0] == 0
630
+ ), "The number of return sequences must be divisible by the number of input sequences"
631
+ fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
632
+ input_tokens = input_tokens.repeat(
633
+ num_return_sequences // input_tokens.shape[0], 1
634
+ )
635
+ inputs = torch.cat([fake_inputs, input_tokens], dim=1)
636
+
637
+ logits_processor = (
638
+ LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)])
639
+ if typical_sampling
640
+ else LogitsProcessorList()
641
+ ) # TODO disable this
642
+ max_length = (
643
+ trunc_index + self.max_mel_tokens - 1
644
+ if max_generate_length is None
645
+ else trunc_index + max_generate_length
646
+ )
647
+ gen = self.inference_model.generate(
648
+ inputs,
649
+ bos_token_id=self.start_mel_token,
650
+ pad_token_id=self.stop_mel_token,
651
+ eos_token_id=self.stop_mel_token,
652
+ max_length=max_length,
653
+ logits_processor=logits_processor,
654
+ num_return_sequences=num_return_sequences,
655
+ **hf_generate_kwargs
656
+ )
657
+ return gen[:, trunc_index:]
658
+
659
+
660
+ class PrunedGPT2InferenceModel(GPT2PreTrainedModel):
661
+ def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
662
+ super().__init__(config)
663
+ self.transformer = gpt
664
+ self.text_pos_embedding = text_pos_emb
665
+ self.embeddings = embeddings
666
+ self.lm_head = nn.Sequential(norm, linear)
667
+
668
+ def store_mel_emb(self, mel_emb):
669
+ self.cached_mel_emb = mel_emb
670
+
671
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
672
+ token_type_ids = kwargs.get("token_type_ids", None)
673
+ # only the last token for input_ids if past is defined in kwargs
+ if past:
676
+ input_ids = input_ids[:, -1].unsqueeze(-1)
677
+ if token_type_ids is not None:
678
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
679
+
680
+ attention_mask = kwargs.get("attention_mask", None)
681
+ position_ids = kwargs.get("position_ids", None)
682
+
683
+ if attention_mask is not None and position_ids is None:
684
+ # create position_ids on the fly for batch generation
685
+ position_ids = attention_mask.long().cumsum(-1) - 1
687
+ position_ids.masked_fill_(attention_mask == 0, 1)
688
+ if past:
690
+ position_ids = position_ids[:, -1].unsqueeze(-1)
691
+ else:
692
+ position_ids = None
693
+ return {
694
+ "input_ids": input_ids,
695
+ "past_key_values": past,
696
+ "use_cache": kwargs.get("use_cache"),
697
+ "position_ids": position_ids,
698
+ "attention_mask": attention_mask,
699
+ "token_type_ids": token_type_ids,
700
+ }
701
+
702
+ def forward(self, input_ids=None, attention_mask=None, position_ids=None, **kwargs):
703
+ past_key_values = None
704
+ token_type_ids = None
705
+ head_mask = None
706
+ inputs_embeds = None
707
+ encoder_hidden_states = None
708
+ encoder_attention_mask = None
709
+ labels = None
710
+ use_cache = True
711
+ output_attentions = False
712
+ output_hidden_states = False
713
+ return_dict = True
714
+ #
715
+ assert self.cached_mel_emb is not None
716
+ assert inputs_embeds is None # Not supported by this inference model.
717
+ assert labels is None # Training not supported by this inference model.
718
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
736
+ # Create embedding
737
+ mel_len = self.cached_mel_emb.shape[1]
738
+ text_inputs = input_ids[:, mel_len:]
739
+ text_emb = self.embeddings(text_inputs)
740
+ text_emb = text_emb + self.text_pos_embedding(text_emb)
741
+ mel_emb = self.cached_mel_emb.repeat_interleave(
742
+ text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
743
+ )
744
+ emb = torch.cat([mel_emb, text_emb], dim=1)
745
+
746
+ transformer_outputs = self.transformer(
747
+ inputs_embeds=emb,
748
+ past_key_values=past_key_values,
749
+ attention_mask=attention_mask,
750
+ token_type_ids=token_type_ids,
751
+ position_ids=position_ids,
752
+ head_mask=head_mask,
753
+ encoder_hidden_states=encoder_hidden_states,
754
+ encoder_attention_mask=encoder_attention_mask,
755
+ use_cache=use_cache,
756
+ output_attentions=output_attentions,
757
+ output_hidden_states=output_hidden_states,
758
+ return_dict=return_dict,
759
+ )
760
+ hidden_states = transformer_outputs[0]
761
+
762
+ lm_logits = self.lm_head(hidden_states)
763
+
764
+ if not return_dict:
765
+ return (lm_logits,) + transformer_outputs[1:]
766
+ return CausalLMOutputWithCrossAttentions(
767
+ loss=None,
768
+ logits=lm_logits,
769
+ past_key_values=transformer_outputs.past_key_values,
770
+ hidden_states=transformer_outputs.hidden_states,
771
+ attentions=transformer_outputs.attentions,
772
+ cross_attentions=transformer_outputs.cross_attentions,
773
+ )
774
+
775
+ @staticmethod
776
+ def _reorder_cache(past, beam_idx):
777
+ """
778
+ This function is used to re-order the :obj:`past_key_values` cache if
779
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
780
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
781
+ """
782
+ return tuple(
783
+ tuple(
784
+ past_state.index_select(0, beam_idx.to(past_state.device))
785
+ for past_state in layer_past
786
+ )
787
+ for layer_past in past
788
+ )
789
+
790
+
791
+ if __name__ == "__main__":
792
+ gpt = UnifiedVoice(
793
+ model_dim=256,
794
+ heads=4,
795
+ train_solo_embeddings=True,
796
+ use_mel_codes_as_input=True,
797
+ max_conditioning_inputs=4,
798
+ )
799
+ l = gpt(
800
+ torch.randn(2, 3, 80, 800),
801
+ torch.randint(high=120, size=(2, 120)),
802
+ torch.tensor([32, 120]),
803
+ torch.randint(high=8192, size=(2, 250)),
804
+ torch.tensor([250 * 256, 195 * 256]),
805
+ )
806
+ gpt.text_forward(
807
+ torch.randn(2, 80, 800),
808
+ torch.randint(high=50, size=(2, 80)),
809
+ torch.tensor([32, 80]),
810
+ )
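A minimal sketch of a training-style call to the forward pass above, assuming a precomputed speech conditioning latent of shape (batch, model_dim), as implied by the `unsqueeze(1)` / `torch.cat` in the body of `forward`. The optimizer choice, the random batch contents, and the plain sum of the two losses are illustrative assumptions, not part of this commit.

```python
import torch

# Sketch only: small configuration mirroring the smoke test above.
gpt = UnifiedVoice(
    model_dim=256,
    heads=4,
    train_solo_embeddings=True,
    use_mel_codes_as_input=True,
    max_conditioning_inputs=4,
)
opt = torch.optim.AdamW(gpt.parameters(), lr=1e-4)   # illustrative optimizer

cond_latent = torch.randn(2, 256)                    # assumed (batch, model_dim) conditioning latent
text = torch.randint(high=120, size=(2, 120))        # tokenized text
text_lens = torch.tensor([32, 120])
mel_codes = torch.randint(high=8192, size=(2, 250))  # discrete MEL codes
wav_lens = torch.tensor([250 * 256, 195 * 256])      # raw waveform lengths, as in the smoke test

# forward returns (text loss, MEL loss, MEL logits) when neither
# return_latent nor return_attentions is set.
loss_text, loss_mel, mel_logits = gpt(cond_latent, text, text_lens, mel_codes, wav_lens)
(loss_text + loss_mel).backward()
opt.step()
```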
tortoise/models/classifier.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from tortoise.models.arch_util import (
5
+ AttentionBlock,
6
+ Downsample,
7
+ Upsample,
8
+ normalization,
9
+ zero_module,
10
+ )
11
+
12
+
13
+ class ResBlock(nn.Module):
14
+ def __init__(
15
+ self,
16
+ channels,
17
+ dropout,
18
+ out_channels=None,
19
+ use_conv=False,
20
+ use_scale_shift_norm=False,
21
+ dims=2,
22
+ up=False,
23
+ down=False,
24
+ kernel_size=3,
25
+ do_checkpoint=True,
26
+ ):
27
+ super().__init__()
28
+ self.channels = channels
29
+ self.dropout = dropout
30
+ self.out_channels = out_channels or channels
31
+ self.use_conv = use_conv
32
+ self.use_scale_shift_norm = use_scale_shift_norm
33
+ self.do_checkpoint = do_checkpoint
34
+ padding = 1 if kernel_size == 3 else 2
35
+
36
+ self.in_layers = nn.Sequential(
37
+ normalization(channels),
38
+ nn.SiLU(),
39
+ nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
40
+ )
41
+
42
+ self.updown = up or down
43
+
44
+ if up:
45
+ self.h_upd = Upsample(channels, False, dims)
46
+ self.x_upd = Upsample(channels, False, dims)
47
+ elif down:
48
+ self.h_upd = Downsample(channels, False, dims)
49
+ self.x_upd = Downsample(channels, False, dims)
50
+ else:
51
+ self.h_upd = self.x_upd = nn.Identity()
52
+
53
+ self.out_layers = nn.Sequential(
54
+ normalization(self.out_channels),
55
+ nn.SiLU(),
56
+ nn.Dropout(p=dropout),
57
+ zero_module(
58
+ nn.Conv1d(
59
+ self.out_channels, self.out_channels, kernel_size, padding=padding
60
+ )
61
+ ),
62
+ )
63
+
64
+ if self.out_channels == channels:
65
+ self.skip_connection = nn.Identity()
66
+ elif use_conv:
67
+ self.skip_connection = nn.Conv1d(
68
+ channels, self.out_channels, kernel_size, padding=padding
69
+ )
70
+ else:
71
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
72
+
73
+ def forward(self, x):
74
+ if self.updown:
75
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
76
+ h = in_rest(x)
77
+ h = self.h_upd(h)
78
+ x = self.x_upd(x)
79
+ h = in_conv(h)
80
+ else:
81
+ h = self.in_layers(x)
82
+ h = self.out_layers(h)
83
+ return self.skip_connection(x) + h
84
+
85
+
86
+ class AudioMiniEncoder(nn.Module):
87
+ def __init__(
88
+ self,
89
+ spec_dim,
90
+ embedding_dim,
91
+ base_channels=128,
92
+ depth=2,
93
+ resnet_blocks=2,
94
+ attn_blocks=4,
95
+ num_attn_heads=4,
96
+ dropout=0,
97
+ downsample_factor=2,
98
+ kernel_size=3,
99
+ ):
100
+ super().__init__()
101
+ self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
102
+ ch = base_channels
103
+ res = []
104
+ self.layers = depth
105
+ for l in range(depth):
106
+ for r in range(resnet_blocks):
107
+ res.append(
108
+ ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)
109
+ )
110
+ res.append(
111
+ Downsample(
112
+ ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor
113
+ )
114
+ )
115
+ ch *= 2
116
+ self.res = nn.Sequential(*res)
117
+ self.final = nn.Sequential(
118
+ normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)
119
+ )
120
+ attn = []
121
+ for a in range(attn_blocks):
122
+ attn.append(
123
+ AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)
124
+ )
125
+ self.attn = nn.Sequential(*attn)
126
+ self.dim = embedding_dim
127
+
128
+ def forward(self, x):
129
+ h = self.init(x)
130
+ h = self.res(h)
131
+ h = self.final(h)
132
+ for blk in self.attn:
133
+ h = blk(h)
134
+ return h[:, :, 0]
135
+
136
+
137
+ class AudioMiniEncoderWithClassifierHead(nn.Module):
138
+ def __init__(self, classes, distribute_zero_label=True, **kwargs):
139
+ super().__init__()
140
+ self.enc = AudioMiniEncoder(**kwargs)
141
+ self.head = nn.Linear(self.enc.dim, classes)
142
+ self.num_classes = classes
143
+ self.distribute_zero_label = distribute_zero_label
144
+
145
+ def forward(self, x, labels=None):
146
+ h = self.enc(x)
147
+ logits = self.head(h)
148
+ if labels is None:
149
+ return logits
150
+ else:
151
+ if self.distribute_zero_label:
152
+ oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
153
+ zeros_indices = (labels == 0).unsqueeze(-1)
154
+ # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
155
+ zero_extra_mass = torch.full_like(
156
+ oh_labels,
157
+ dtype=torch.float,
158
+ fill_value=0.2 / (self.num_classes - 1),
159
+ )
160
+ zero_extra_mass[:, 0] = -0.2
161
+ zero_extra_mass = zero_extra_mass * zeros_indices
162
+ oh_labels = oh_labels + zero_extra_mass
163
+ else:
164
+ oh_labels = labels
165
+ loss = nn.functional.cross_entropy(logits, oh_labels)
166
+ return loss
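A brief sketch of how `AudioMiniEncoderWithClassifierHead` is meant to be called: without labels it returns logits, with labels it returns the (optionally label-smoothed) cross-entropy loss. The `spec_dim`, `embedding_dim`, class count, and batch shapes below are illustrative assumptions.

```python
import torch

# Sketch only: classify 80-bin mel spectrograms into two classes.
clf = AudioMiniEncoderWithClassifierHead(classes=2, spec_dim=80, embedding_dim=512)

mels = torch.randn(4, 80, 400)           # (batch, spec_dim, time)
logits = clf(mels)                       # no labels -> raw logits, shape (4, 2)

labels = torch.randint(0, 2, (4,))
loss = clf(mels, labels)                 # with labels -> scalar loss (zero label gets 20% mass spread)
```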
tortoise/models/clvp.py ADDED
@@ -0,0 +1,173 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import einsum
5
+
6
+ from tortoise.models.arch_util import CheckpointedXTransformerEncoder
7
+ from tortoise.models.transformer import Transformer
8
+ from tortoise.models.xtransformers import Encoder
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def masked_mean(t, mask, dim=1):
16
+ t = t.masked_fill(~mask[:, :, None], 0.0)
17
+ return t.sum(dim=1) / mask.sum(dim=1)[..., None]
18
+
19
+
20
+ class CLVP(nn.Module):
21
+ """
22
+ CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
23
+ transcribed text.
24
+
25
+ Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ *,
31
+ dim_text=512,
32
+ dim_speech=512,
33
+ dim_latent=512,
34
+ num_text_tokens=256,
35
+ text_enc_depth=6,
36
+ text_seq_len=120,
37
+ text_heads=8,
38
+ num_speech_tokens=8192,
39
+ speech_enc_depth=6,
40
+ speech_heads=8,
41
+ speech_seq_len=250,
42
+ text_mask_percentage=0,
43
+ voice_mask_percentage=0,
44
+ wav_token_compression=1024,
45
+ use_xformers=False,
46
+ ):
47
+ super().__init__()
48
+ self.text_emb = nn.Embedding(num_text_tokens, dim_text)
49
+ self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
50
+
51
+ self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
52
+ self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
53
+
54
+ if use_xformers:
55
+ self.text_transformer = CheckpointedXTransformerEncoder(
56
+ needs_permute=False,
57
+ exit_permute=False,
58
+ max_seq_len=-1,
59
+ attn_layers=Encoder(
60
+ dim=dim_text,
61
+ depth=text_enc_depth,
62
+ heads=text_heads,
63
+ ff_dropout=0.1,
64
+ ff_mult=2,
65
+ attn_dropout=0.1,
66
+ use_rmsnorm=True,
67
+ ff_glu=True,
68
+ rotary_pos_emb=True,
69
+ ),
70
+ )
71
+ self.speech_transformer = CheckpointedXTransformerEncoder(
72
+ needs_permute=False,
73
+ exit_permute=False,
74
+ max_seq_len=-1,
75
+ attn_layers=Encoder(
76
+ dim=dim_speech,
77
+ depth=speech_enc_depth,
78
+ heads=speech_heads,
79
+ ff_dropout=0.1,
80
+ ff_mult=2,
81
+ attn_dropout=0.1,
82
+ use_rmsnorm=True,
83
+ ff_glu=True,
84
+ rotary_pos_emb=True,
85
+ ),
86
+ )
87
+ else:
88
+ self.text_transformer = Transformer(
89
+ causal=False,
90
+ seq_len=text_seq_len,
91
+ dim=dim_text,
92
+ depth=text_enc_depth,
93
+ heads=text_heads,
94
+ )
95
+ self.speech_transformer = Transformer(
96
+ causal=False,
97
+ seq_len=speech_seq_len,
98
+ dim=dim_speech,
99
+ depth=speech_enc_depth,
100
+ heads=speech_heads,
101
+ )
102
+
103
+ self.temperature = nn.Parameter(torch.tensor(1.0))
104
+ self.text_mask_percentage = text_mask_percentage
105
+ self.voice_mask_percentage = voice_mask_percentage
106
+ self.wav_token_compression = wav_token_compression
107
+ self.xformers = use_xformers
108
+ if not use_xformers:
109
+ self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
110
+ self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
111
+
112
+ def forward(self, text, speech_tokens, return_loss=False):
113
+ b, device = text.shape[0], text.device
114
+ if self.training:
115
+ text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
116
+ voice_mask = (
117
+ torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
118
+ )
119
+ else:
120
+ text_mask = torch.ones_like(text.float()).bool()
121
+ voice_mask = torch.ones_like(speech_tokens.float()).bool()
122
+
123
+ text_emb = self.text_emb(text)
124
+ speech_emb = self.speech_emb(speech_tokens)
125
+
126
+ if not self.xformers:
127
+ text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
128
+ speech_emb += self.speech_pos_emb(
129
+ torch.arange(speech_emb.shape[1], device=device)
130
+ )
131
+
132
+ enc_text = self.text_transformer(text_emb, mask=text_mask)
133
+ enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
134
+
135
+ text_latents = masked_mean(enc_text, text_mask, dim=1)
136
+ speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
137
+
138
+ text_latents = self.to_text_latent(text_latents)
139
+ speech_latents = self.to_speech_latent(speech_latents)
140
+
141
+ text_latents, speech_latents = map(
142
+ lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)
143
+ )
144
+
145
+ temp = self.temperature.exp()
146
+
147
+ if not return_loss:
148
+ sim = einsum("n d, n d -> n", text_latents, speech_latents) * temp
149
+ return sim
150
+
151
+ sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp
152
+ labels = torch.arange(b, device=device)
153
+ loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
154
+ return loss
155
+
156
+
157
+ if __name__ == "__main__":
158
+ clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2)
159
+ clip(
+ torch.randint(0, 256, (2, 120)),
+ torch.randint(0, 8192, (2, 250)),
+ return_loss=True,
+ )
+ nonloss = clip(
+ torch.randint(0, 256, (2, 120)),
+ torch.randint(0, 8192, (2, 250)),
+ return_loss=False,
+ )
173
+ print(nonloss.shape)
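In Tortoise, CLVP is used to re-rank candidate speech-token sequences against the input text. The sketch below shows that pattern in its simplest form; the candidate count and the re-ranking loop are illustrative, with vocabulary sizes taken from the smoke test above.

```python
import torch

# Sketch only: score several candidate speech-token sequences against one text prompt.
clvp = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2).eval()

text = torch.randint(0, 256, (1, 120))          # one tokenized text prompt
candidates = torch.randint(0, 8192, (8, 250))   # eight candidate speech-token sequences

with torch.no_grad():
    scores = clvp(text.repeat(8, 1), candidates)  # (8,) similarity scores, higher is better
best = candidates[scores.argmax()]
```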
tortoise/models/cvvp.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import einsum
5
+
6
+ from tortoise.models.arch_util import AttentionBlock
7
+ from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
8
+
9
+
10
+ def exists(val):
11
+ return val is not None
12
+
13
+
14
+ def masked_mean(t, mask):
15
+ t = t.masked_fill(~mask, 0.0)
16
+ return t.sum(dim=1) / mask.sum(dim=1)
17
+
18
+
19
+ class CollapsingTransformer(nn.Module):
20
+ def __init__(
21
+ self,
22
+ model_dim,
23
+ output_dims,
24
+ heads,
25
+ dropout,
26
+ depth,
27
+ mask_percentage=0,
28
+ **encoder_kwargs
29
+ ):
30
+ super().__init__()
31
+ self.transformer = ContinuousTransformerWrapper(
32
+ max_seq_len=-1,
33
+ use_pos_emb=False,
34
+ attn_layers=Encoder(
35
+ dim=model_dim,
36
+ depth=depth,
37
+ heads=heads,
38
+ ff_dropout=dropout,
39
+ ff_mult=1,
40
+ attn_dropout=dropout,
41
+ use_rmsnorm=True,
42
+ ff_glu=True,
43
+ rotary_pos_emb=True,
44
+ **encoder_kwargs,
45
+ ),
46
+ )
47
+ self.pre_combiner = nn.Sequential(
48
+ nn.Conv1d(model_dim, output_dims, 1),
49
+ AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False),
50
+ nn.Conv1d(output_dims, output_dims, 1),
51
+ )
52
+ self.mask_percentage = mask_percentage
53
+
54
+ def forward(self, x, **transformer_kwargs):
55
+ h = self.transformer(x, **transformer_kwargs)
56
+ h = h.permute(0, 2, 1)
57
+ h = self.pre_combiner(h).permute(0, 2, 1)
58
+ if self.training:
59
+ mask = torch.rand_like(h.float()) > self.mask_percentage
60
+ else:
61
+ mask = torch.ones_like(h.float()).bool()
62
+ return masked_mean(h, mask)
63
+
64
+
65
+ class ConvFormatEmbedding(nn.Module):
66
+ def __init__(self, *args, **kwargs):
67
+ super().__init__()
68
+ self.emb = nn.Embedding(*args, **kwargs)
69
+
70
+ def forward(self, x):
71
+ y = self.emb(x)
72
+ return y.permute(0, 2, 1)
73
+
74
+
75
+ class CVVP(nn.Module):
76
+ def __init__(
77
+ self,
78
+ model_dim=512,
79
+ transformer_heads=8,
80
+ dropout=0.1,
81
+ conditioning_enc_depth=8,
82
+ cond_mask_percentage=0,
83
+ mel_channels=80,
84
+ mel_codes=None,
85
+ speech_enc_depth=8,
86
+ speech_mask_percentage=0,
87
+ latent_multiplier=1,
88
+ ):
89
+ super().__init__()
90
+ latent_dim = latent_multiplier * model_dim
91
+ self.temperature = nn.Parameter(torch.tensor(1.0))
92
+
93
+ self.cond_emb = nn.Sequential(
94
+ nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
95
+ nn.Conv1d(model_dim // 2, model_dim, kernel_size=3, stride=2, padding=1),
96
+ )
97
+ self.conditioning_transformer = CollapsingTransformer(
98
+ model_dim,
99
+ model_dim,
100
+ transformer_heads,
101
+ dropout,
102
+ conditioning_enc_depth,
103
+ cond_mask_percentage,
104
+ )
105
+ self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
106
+
107
+ if mel_codes is None:
108
+ self.speech_emb = nn.Conv1d(
109
+ mel_channels, model_dim, kernel_size=5, padding=2
110
+ )
111
+ else:
112
+ self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
113
+ self.speech_transformer = CollapsingTransformer(
114
+ model_dim,
115
+ latent_dim,
116
+ transformer_heads,
117
+ dropout,
118
+ speech_enc_depth,
119
+ speech_mask_percentage,
120
+ )
121
+ self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
122
+
123
+ def get_grad_norm_parameter_groups(self):
124
+ return {
125
+ "conditioning": list(self.conditioning_transformer.parameters()),
126
+ "speech": list(self.speech_transformer.parameters()),
127
+ }
128
+
129
+ def forward(self, mel_cond, mel_input, return_loss=False):
130
+ cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1)
131
+ enc_cond = self.conditioning_transformer(cond_emb)
132
+ cond_latents = self.to_conditioning_latent(enc_cond)
133
+
134
+ speech_emb = self.speech_emb(mel_input).permute(0, 2, 1)
135
+ enc_speech = self.speech_transformer(speech_emb)
136
+ speech_latents = self.to_speech_latent(enc_speech)
137
+
138
+ cond_latents, speech_latents = map(
139
+ lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents)
140
+ )
141
+ temp = self.temperature.exp()
142
+
143
+ if not return_loss:
144
+ sim = einsum("n d, n d -> n", cond_latents, speech_latents) * temp
145
+ return sim
146
+
147
+ sim = einsum("i d, j d -> i j", cond_latents, speech_latents) * temp
148
+ labels = torch.arange(cond_latents.shape[0], device=mel_input.device)
149
+ loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
150
+
151
+ return loss
152
+
153
+
154
+ if __name__ == "__main__":
155
+ clvp = CVVP()
156
+ clvp(torch.randn(2, 80, 100), torch.randn(2, 80, 95), return_loss=True)
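CVVP plays a similar re-ranking role to CLVP, but compares a candidate clip against the conditioning voice rather than against the text. A small sketch, with illustrative frame counts and the default `mel_channels=80`:

```python
import torch

# Sketch only: voice-similarity score between a conditioning mel and a candidate mel.
cvvp = CVVP().eval()

cond_mel = torch.randn(1, 80, 300)        # conditioning clip, (batch, mel_channels, frames)
candidate_mel = torch.randn(1, 80, 220)   # candidate to score

with torch.no_grad():
    similarity = cvvp(cond_mel, candidate_mel)   # (1,) similarity, higher is better
```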
tortoise/models/diffusion_decoder.py ADDED
@@ -0,0 +1,445 @@
1
+ import math
2
+ import random
3
+ from abc import abstractmethod
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import autocast
9
+
10
+ from tortoise.models.arch_util import AttentionBlock, normalization
11
+
12
+
13
+ def is_latent(t):
14
+ return t.dtype == torch.float
15
+
16
+
17
+ def is_sequence(t):
18
+ return t.dtype == torch.long
19
+
20
+
21
+ def timestep_embedding(timesteps, dim, max_period=10000):
22
+ """
23
+ Create sinusoidal timestep embeddings.
24
+
25
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
26
+ These may be fractional.
27
+ :param dim: the dimension of the output.
28
+ :param max_period: controls the minimum frequency of the embeddings.
29
+ :return: an [N x dim] Tensor of positional embeddings.
30
+ """
31
+ half = dim // 2
32
+ freqs = torch.exp(
33
+ -math.log(max_period)
34
+ * torch.arange(start=0, end=half, dtype=torch.float32)
35
+ / half
36
+ ).to(device=timesteps.device)
37
+ args = timesteps[:, None].float() * freqs[None]
38
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
39
+ if dim % 2:
40
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
41
+ return embedding
42
+
43
+
44
+ class TimestepBlock(nn.Module):
45
+ @abstractmethod
46
+ def forward(self, x, emb):
47
+ """
48
+ Apply the module to `x` given `emb` timestep embeddings.
49
+ """
50
+
51
+
52
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
53
+ def forward(self, x, emb):
54
+ for layer in self:
55
+ if isinstance(layer, TimestepBlock):
56
+ x = layer(x, emb)
57
+ else:
58
+ x = layer(x)
59
+ return x
60
+
61
+
62
+ class ResBlock(TimestepBlock):
63
+ def __init__(
64
+ self,
65
+ channels,
66
+ emb_channels,
67
+ dropout,
68
+ out_channels=None,
69
+ dims=2,
70
+ kernel_size=3,
71
+ efficient_config=True,
72
+ use_scale_shift_norm=False,
73
+ ):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.emb_channels = emb_channels
77
+ self.dropout = dropout
78
+ self.out_channels = out_channels or channels
79
+ self.use_scale_shift_norm = use_scale_shift_norm
80
+ padding = {1: 0, 3: 1, 5: 2}[kernel_size]
81
+ eff_kernel = 1 if efficient_config else 3
82
+ eff_padding = 0 if efficient_config else 1
83
+
84
+ self.in_layers = nn.Sequential(
85
+ normalization(channels),
86
+ nn.SiLU(),
87
+ nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
88
+ )
89
+
90
+ self.emb_layers = nn.Sequential(
91
+ nn.SiLU(),
92
+ nn.Linear(
93
+ emb_channels,
94
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
95
+ ),
96
+ )
97
+ self.out_layers = nn.Sequential(
98
+ normalization(self.out_channels),
99
+ nn.SiLU(),
100
+ nn.Dropout(p=dropout),
101
+ nn.Conv1d(
102
+ self.out_channels, self.out_channels, kernel_size, padding=padding
103
+ ),
104
+ )
105
+
106
+ if self.out_channels == channels:
107
+ self.skip_connection = nn.Identity()
108
+ else:
109
+ self.skip_connection = nn.Conv1d(
110
+ channels, self.out_channels, eff_kernel, padding=eff_padding
111
+ )
112
+
113
+ def forward(self, x, emb):
114
+ h = self.in_layers(x)
115
+ emb_out = self.emb_layers(emb).type(h.dtype)
116
+ while len(emb_out.shape) < len(h.shape):
117
+ emb_out = emb_out[..., None]
118
+ if self.use_scale_shift_norm:
119
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
120
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
121
+ h = out_norm(h) * (1 + scale) + shift
122
+ h = out_rest(h)
123
+ else:
124
+ h = h + emb_out
125
+ h = self.out_layers(h)
126
+ return self.skip_connection(x) + h
127
+
128
+
129
+ class DiffusionLayer(TimestepBlock):
130
+ def __init__(self, model_channels, dropout, num_heads):
131
+ super().__init__()
132
+ self.resblk = ResBlock(
133
+ model_channels,
134
+ model_channels,
135
+ dropout,
136
+ model_channels,
137
+ dims=1,
138
+ use_scale_shift_norm=True,
139
+ )
140
+ self.attn = AttentionBlock(
141
+ model_channels, num_heads, relative_pos_embeddings=True
142
+ )
143
+
144
+ def forward(self, x, time_emb):
145
+ y = self.resblk(x, time_emb)
146
+ return self.attn(y)
147
+
148
+
149
+ class DiffusionTts(nn.Module):
150
+ def __init__(
151
+ self,
152
+ model_channels=512,
153
+ num_layers=8,
154
+ in_channels=100,
155
+ in_latent_channels=512,
156
+ in_tokens=8193,
157
+ out_channels=200, # mean and variance
158
+ dropout=0,
159
+ use_fp16=False,
160
+ num_heads=16,
161
+ # Parameters for regularization.
162
+ layer_drop=0.1,
163
+ unconditioned_percentage=0.1, # This implements a mechanism similar to what is used in classifier-free training.
164
+ ):
165
+ super().__init__()
166
+
167
+ self.in_channels = in_channels
168
+ self.model_channels = model_channels
169
+ self.out_channels = out_channels
170
+ self.dropout = dropout
171
+ self.num_heads = num_heads
172
+ self.unconditioned_percentage = unconditioned_percentage
173
+ self.enable_fp16 = use_fp16
174
+ self.layer_drop = layer_drop
175
+
176
+ self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
177
+ self.time_embed = nn.Sequential(
178
+ nn.Linear(model_channels, model_channels),
179
+ nn.SiLU(),
180
+ nn.Linear(model_channels, model_channels),
181
+ )
182
+
183
+ # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
184
+ # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
185
+ # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
186
+ # transformer network.
187
+ self.code_embedding = nn.Embedding(in_tokens, model_channels)
188
+ self.code_converter = nn.Sequential(
189
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
190
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
191
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
192
+ )
193
+ self.code_norm = normalization(model_channels)
194
+ self.latent_conditioner = nn.Sequential(
195
+ nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
196
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
197
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
198
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
199
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
200
+ )
201
+ self.contextual_embedder = nn.Sequential(
202
+ nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
203
+ nn.Conv1d(model_channels, model_channels * 2, 3, padding=1, stride=2),
204
+ AttentionBlock(
205
+ model_channels * 2,
206
+ num_heads,
207
+ relative_pos_embeddings=True,
208
+ do_checkpoint=False,
209
+ ),
210
+ AttentionBlock(
211
+ model_channels * 2,
212
+ num_heads,
213
+ relative_pos_embeddings=True,
214
+ do_checkpoint=False,
215
+ ),
216
+ AttentionBlock(
217
+ model_channels * 2,
218
+ num_heads,
219
+ relative_pos_embeddings=True,
220
+ do_checkpoint=False,
221
+ ),
222
+ AttentionBlock(
223
+ model_channels * 2,
224
+ num_heads,
225
+ relative_pos_embeddings=True,
226
+ do_checkpoint=False,
227
+ ),
228
+ AttentionBlock(
229
+ model_channels * 2,
230
+ num_heads,
231
+ relative_pos_embeddings=True,
232
+ do_checkpoint=False,
233
+ ),
234
+ )
235
+ self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
236
+ self.conditioning_timestep_integrator = TimestepEmbedSequential(
237
+ DiffusionLayer(model_channels, dropout, num_heads),
238
+ DiffusionLayer(model_channels, dropout, num_heads),
239
+ DiffusionLayer(model_channels, dropout, num_heads),
240
+ )
241
+
242
+ self.integrating_conv = nn.Conv1d(
243
+ model_channels * 2, model_channels, kernel_size=1
244
+ )
245
+ self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
246
+
247
+ self.layers = nn.ModuleList(
248
+ [
249
+ DiffusionLayer(model_channels, dropout, num_heads)
250
+ for _ in range(num_layers)
251
+ ]
252
+ + [
253
+ ResBlock(
254
+ model_channels,
255
+ model_channels,
256
+ dropout,
257
+ dims=1,
258
+ use_scale_shift_norm=True,
259
+ )
260
+ for _ in range(3)
261
+ ]
262
+ )
263
+
264
+ self.out = nn.Sequential(
265
+ normalization(model_channels),
266
+ nn.SiLU(),
267
+ nn.Conv1d(model_channels, out_channels, 3, padding=1),
268
+ )
269
+
270
+ def get_grad_norm_parameter_groups(self):
271
+ groups = {
272
+ "minicoder": list(self.contextual_embedder.parameters()),
273
+ "layers": list(self.layers.parameters()),
274
+ "code_converters": list(self.code_embedding.parameters())
275
+ + list(self.code_converter.parameters())
276
+ + list(self.latent_conditioner.parameters()),
278
+ "timestep_integrator": list(
279
+ self.conditioning_timestep_integrator.parameters()
280
+ )
281
+ + list(self.integrating_conv.parameters()),
282
+ "time_embed": list(self.time_embed.parameters()),
283
+ }
284
+ return groups
285
+
286
+ def get_conditioning(self, conditioning_input):
287
+ speech_conditioning_input = (
288
+ conditioning_input.unsqueeze(1)
289
+ if len(conditioning_input.shape) == 3
290
+ else conditioning_input
291
+ )
292
+ conds = []
293
+ for j in range(speech_conditioning_input.shape[1]):
294
+ conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
295
+ conds = torch.cat(conds, dim=-1)
296
+ conds = conds.mean(dim=-1)
297
+ return conds
298
+
299
+ def timestep_independent(
300
+ self,
301
+ aligned_conditioning,
302
+ conditioning_latent,
303
+ expected_seq_len,
304
+ return_code_pred,
305
+ ):
306
+ # Shuffle aligned_latent to BxCxS format
307
+ if is_latent(aligned_conditioning):
308
+ aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
309
+
310
+ cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
311
+ if is_latent(aligned_conditioning):
312
+ code_emb = self.latent_conditioner(aligned_conditioning)
313
+ else:
314
+ code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
315
+ code_emb = self.code_converter(code_emb)
316
+ code_emb = self.code_norm(code_emb) * (
317
+ 1 + cond_scale.unsqueeze(-1)
318
+ ) + cond_shift.unsqueeze(-1)
319
+
320
+ unconditioned_batches = torch.zeros(
321
+ (code_emb.shape[0], 1, 1), device=code_emb.device
322
+ )
323
+ # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
324
+ if self.training and self.unconditioned_percentage > 0:
325
+ unconditioned_batches = (
326
+ torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device)
327
+ < self.unconditioned_percentage
328
+ )
329
+ code_emb = torch.where(
330
+ unconditioned_batches,
331
+ self.unconditioned_embedding.repeat(
332
+ aligned_conditioning.shape[0], 1, 1
333
+ ),
334
+ code_emb,
335
+ )
336
+ expanded_code_emb = F.interpolate(
337
+ code_emb, size=expected_seq_len, mode="nearest"
338
+ )
339
+
340
+ if not return_code_pred:
341
+ return expanded_code_emb
342
+ else:
343
+ mel_pred = self.mel_head(expanded_code_emb)
344
+ # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
345
+ mel_pred = mel_pred * unconditioned_batches.logical_not()
346
+ return expanded_code_emb, mel_pred
347
+
348
+ def forward(
349
+ self,
350
+ x,
351
+ timesteps,
352
+ aligned_conditioning=None,
353
+ conditioning_latent=None,
354
+ precomputed_aligned_embeddings=None,
355
+ conditioning_free=False,
356
+ return_code_pred=False,
357
+ ):
358
+ """
359
+ Apply the model to an input batch.
360
+
361
+ :param x: an [N x C x ...] Tensor of inputs.
362
+ :param timesteps: a 1-D batch of timesteps.
363
+ :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
364
+ :param conditioning_latent: a pre-computed conditioning latent; see get_conditioning().
365
+ :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
366
+ :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
367
+ :return: an [N x C x ...] Tensor of outputs.
368
+ """
369
+ assert precomputed_aligned_embeddings is not None or (
370
+ aligned_conditioning is not None and conditioning_latent is not None
371
+ )
372
+ assert not (
373
+ return_code_pred and precomputed_aligned_embeddings is not None
374
+ ) # These two are mutually exclusive.
375
+
376
+ unused_params = []
377
+ if conditioning_free:
378
+ code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
379
+ unused_params.extend(
380
+ list(self.code_converter.parameters())
381
+ + list(self.code_embedding.parameters())
382
+ )
383
+ unused_params.extend(list(self.latent_conditioner.parameters()))
384
+ else:
385
+ if precomputed_aligned_embeddings is not None:
386
+ code_emb = precomputed_aligned_embeddings
387
+ else:
388
+ code_emb, mel_pred = self.timestep_independent(
389
+ aligned_conditioning, conditioning_latent, x.shape[-1], True
390
+ )
391
+ if is_latent(aligned_conditioning):
392
+ unused_params.extend(
393
+ list(self.code_converter.parameters())
394
+ + list(self.code_embedding.parameters())
395
+ )
396
+ else:
397
+ unused_params.extend(list(self.latent_conditioner.parameters()))
398
+
399
+ unused_params.append(self.unconditioned_embedding)
400
+
401
+ time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
402
+ code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
403
+ x = self.inp_block(x)
404
+ x = torch.cat([x, code_emb], dim=1)
405
+ x = self.integrating_conv(x)
406
+ for i, lyr in enumerate(self.layers):
407
+ # Do layer drop where applicable. Do not drop first and last layers.
408
+ if (
409
+ self.training
410
+ and self.layer_drop > 0
411
+ and i != 0
412
+ and i != (len(self.layers) - 1)
413
+ and random.random() < self.layer_drop
414
+ ):
415
+ unused_params.extend(list(lyr.parameters()))
416
+ else:
417
+ # First and last blocks will have autocast disabled for improved precision.
418
+ with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
419
+ x = lyr(x, time_emb)
420
+
421
+ x = x.float()
422
+ out = self.out(x)
423
+
424
+ # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
425
+ extraneous_addition = 0
426
+ for p in unused_params:
427
+ extraneous_addition = extraneous_addition + p.mean()
428
+ out = out + extraneous_addition * 0
429
+
430
+ if return_code_pred:
431
+ return out, mel_pred
432
+ return out
433
+
434
+
435
+ if __name__ == "__main__":
436
+ clip = torch.randn(2, 100, 400)
437
+ aligned_latent = torch.randn(2, 388, 512)
438
+ aligned_sequence = torch.randint(0, 8192, (2, 100))
439
+ cond = torch.randn(2, 100, 400)
440
+ ts = torch.LongTensor([600, 600])
441
+ model = DiffusionTts(512, layer_drop=0.3, unconditioned_percentage=0.5)
442
+ # Test with latent aligned conditioning
443
+ # o = model(clip, ts, aligned_latent, cond)
444
+ # Test with sequence aligned conditioning
445
+ o = model(clip, ts, aligned_sequence, cond)
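The `conditioning_free` path and `unconditioned_embedding` above exist so a sampler can apply classifier-free-style guidance: run the network once with conditioning and once without, then push the prediction away from the unconditioned one. Below is a hedged sketch of that blending step; the guidance weight, tensor shapes, and the use of `timestep_independent` to precompute embeddings are illustrative, and the actual sampling loop in Tortoise lives outside this file.

```python
import torch

model = DiffusionTts(model_channels=512).eval()

x_t = torch.randn(2, 100, 400)            # noisy mel at timestep t, (batch, in_channels, frames)
ts = torch.LongTensor([600, 600])
latents = torch.randn(2, 388, 512)        # autoregressive latents, (batch, seq, in_latent_channels)
cond = torch.randn(2, 1024)               # conditioning latent, chunked into scale/shift above

with torch.no_grad():
    emb = model.timestep_independent(latents, cond, x_t.shape[-1], False)
    cond_out = model(x_t, ts, precomputed_aligned_embeddings=emb)
    uncond_out = model(x_t, ts, precomputed_aligned_embeddings=emb, conditioning_free=True)

guidance = 2.0                            # illustrative guidance weight
mean_c, _ = torch.chunk(cond_out, 2, dim=1)     # out_channels = 200 holds mean and variance
mean_u, _ = torch.chunk(uncond_out, 2, dim=1)
guided_mean = mean_u + guidance * (mean_c - mean_u)
```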
tortoise/models/random_latent_generator.py ADDED
@@ -0,0 +1,56 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
9
+ if bias is not None:
10
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
11
+ return (
12
+ F.leaky_relu(
13
+ input + bias.view(1, bias.shape[0], *rest_dim),
14
+ negative_slope=negative_slope,
15
+ )
16
+ * scale
17
+ )
18
+ else:
19
+ return F.leaky_relu(input, negative_slope=0.2) * scale
20
+
21
+
22
+ class EqualLinear(nn.Module):
23
+ def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1):
24
+ super().__init__()
25
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
26
+ if bias:
27
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
28
+ else:
29
+ self.bias = None
30
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
31
+ self.lr_mul = lr_mul
32
+
33
+ def forward(self, input):
34
+ out = F.linear(input, self.weight * self.scale)
35
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
36
+ return out
37
+
38
+
39
+ class RandomLatentConverter(nn.Module):
40
+ def __init__(self, channels):
41
+ super().__init__()
42
+ self.layers = nn.Sequential(
43
+ *[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)],
44
+ nn.Linear(channels, channels)
45
+ )
46
+ self.channels = channels
47
+
48
+ def forward(self, ref):
49
+ r = torch.randn(ref.shape[0], self.channels, device=ref.device)
50
+ y = self.layers(r)
51
+ return y
52
+
53
+
54
+ if __name__ == "__main__":
55
+ model = RandomLatentConverter(512)
56
+ model(torch.randn(5, 512))
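`RandomLatentConverter` maps Gaussian noise through a small MLP so that a random vector can stand in for a conditioning latent when no reference audio is available (presumably the `rlg_auto.pth` / `rlg_diffuser.pth` checkpoints listed in utils.py). A minimal sketch; only the batch size and device of the reference tensor are actually used, and the width 512 is illustrative:

```python
import torch

# Sketch only: generate a stand-in conditioning latent from noise.
rlg = RandomLatentConverter(512).eval()

ref = torch.zeros(1, 512)        # only ref.shape[0] and ref.device are used by forward()
with torch.no_grad():
    fake_latent = rlg(ref)       # shape (1, 512)
```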
tortoise/models/transformer.py ADDED
@@ -0,0 +1,237 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from einops import rearrange
4
+ from torch import nn
5
+
6
+ # helpers
7
+
8
+
9
+ def exists(val):
10
+ return val is not None
11
+
12
+
13
+ def default(val, d):
14
+ return val if exists(val) else d
15
+
16
+
17
+ def cast_tuple(val, depth=1):
18
+ if isinstance(val, list):
19
+ val = tuple(val)
20
+ return val if isinstance(val, tuple) else (val,) * depth
21
+
22
+
23
+ def max_neg_value(t):
24
+ return -torch.finfo(t.dtype).max
25
+
26
+
27
+ def stable_softmax(t, dim=-1, alpha=32**2):
28
+ t = t / alpha
29
+ t = t - torch.amax(t, dim=dim, keepdim=True).detach()
30
+ return (t * alpha).softmax(dim=dim)
31
+
32
+
33
+ def route_args(router, args, depth):
34
+ routed_args = [(dict(), dict()) for _ in range(depth)]
35
+ matched_keys = [key for key in args.keys() if key in router]
36
+
37
+ for key in matched_keys:
38
+ val = args[key]
39
+ for depth, ((f_args, g_args), routes) in enumerate(
40
+ zip(routed_args, router[key])
41
+ ):
42
+ new_f_args, new_g_args = map(
43
+ lambda route: ({key: val} if route else {}), routes
44
+ )
45
+ routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
46
+ return routed_args
47
+
48
+
49
+ # classes
50
+ class SequentialSequence(nn.Module):
51
+ def __init__(self, layers, args_route={}, layer_dropout=0.0):
52
+ super().__init__()
53
+ assert all(
54
+ len(route) == len(layers) for route in args_route.values()
55
+ ), "each argument route map must have the same depth as the number of sequential layers"
56
+ self.layers = layers
57
+ self.args_route = args_route
58
+ self.layer_dropout = layer_dropout
59
+
60
+ def forward(self, x, **kwargs):
61
+ args = route_args(self.args_route, kwargs, len(self.layers))
62
+ layers_and_args = list(zip(self.layers, args))
63
+
64
+ for (f, g), (f_args, g_args) in layers_and_args:
65
+ x = x + f(x, **f_args)
66
+ x = x + g(x, **g_args)
67
+ return x
68
+
69
+
70
+ class DivideMax(nn.Module):
71
+ def __init__(self, dim):
72
+ super().__init__()
73
+ self.dim = dim
74
+
75
+ def forward(self, x):
76
+ maxes = x.amax(dim=self.dim, keepdim=True).detach()
77
+ return x / maxes
78
+
79
+
80
+ # https://arxiv.org/abs/2103.17239
81
+ class LayerScale(nn.Module):
82
+ def __init__(self, dim, depth, fn):
83
+ super().__init__()
84
+ if depth <= 18:
85
+ init_eps = 0.1
86
+ elif depth > 18 and depth <= 24:
87
+ init_eps = 1e-5
88
+ else:
89
+ init_eps = 1e-6
90
+
91
+ scale = torch.zeros(1, 1, dim).fill_(init_eps)
92
+ self.scale = nn.Parameter(scale)
93
+ self.fn = fn
94
+
95
+ def forward(self, x, **kwargs):
96
+ return self.fn(x, **kwargs) * self.scale
97
+
98
+
99
+ # layer norm
100
+
101
+
102
+ class PreNorm(nn.Module):
103
+ def __init__(self, dim, fn, sandwich=False):
104
+ super().__init__()
105
+ self.norm = nn.LayerNorm(dim)
106
+ self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
107
+ self.fn = fn
108
+
109
+ def forward(self, x, **kwargs):
110
+ x = self.norm(x)
111
+ x = self.fn(x, **kwargs)
112
+ return self.norm_out(x)
113
+
114
+
115
+ # feed forward
116
+
117
+
118
+ class GEGLU(nn.Module):
119
+ def forward(self, x):
120
+ x, gates = x.chunk(2, dim=-1)
121
+ return x * F.gelu(gates)
122
+
123
+
124
+ class FeedForward(nn.Module):
125
+ def __init__(self, dim, dropout=0.0, mult=4.0):
126
+ super().__init__()
127
+ self.net = nn.Sequential(
128
+ nn.Linear(dim, dim * mult * 2),
129
+ GEGLU(),
130
+ nn.Dropout(dropout),
131
+ nn.Linear(dim * mult, dim),
132
+ )
133
+
134
+ def forward(self, x):
135
+ return self.net(x)
136
+
137
+
138
+ # Attention
139
+
140
+
141
+ class Attention(nn.Module):
142
+ def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.0):
143
+ super().__init__()
144
+ inner_dim = dim_head * heads
145
+ self.heads = heads
146
+ self.seq_len = seq_len
147
+ self.scale = dim_head**-0.5
148
+
149
+ self.causal = causal
150
+
151
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
152
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
153
+
154
+ def forward(self, x, mask=None):
155
+ b, n, _, h, device = *x.shape, self.heads, x.device
156
+ softmax = torch.softmax
157
+
158
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
159
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
160
+
161
+ q = q * self.scale
162
+
163
+ dots = torch.einsum("b h i d, b h j d -> b h i j", q, k)
164
+ mask_value = max_neg_value(dots)
165
+
166
+ if exists(mask):
167
+ mask = rearrange(mask, "b j -> b () () j")
168
+ dots.masked_fill_(~mask, mask_value)
169
+ del mask
170
+
171
+ if self.causal:
172
+ i, j = dots.shape[-2:]
173
+ mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
174
+ dots.masked_fill_(mask, mask_value)
175
+
176
+ attn = softmax(dots, dim=-1)
177
+
178
+ out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
179
+ out = rearrange(out, "b h n d -> b n (h d)")
180
+ out = self.to_out(out)
181
+ return out
182
+
183
+
184
+ # main transformer class
185
+ class Transformer(nn.Module):
186
+ def __init__(
187
+ self,
188
+ *,
189
+ dim,
190
+ depth,
191
+ seq_len,
192
+ causal=True,
193
+ heads=8,
194
+ dim_head=64,
195
+ ff_mult=4,
196
+ attn_dropout=0.0,
197
+ ff_dropout=0.0,
198
+ sparse_attn=False,
199
+ sandwich_norm=False,
200
+ ):
201
+ super().__init__()
202
+ layers = nn.ModuleList([])
203
+ sparse_layer = cast_tuple(sparse_attn, depth)
204
+
205
+ for ind, sparse_attn in zip(range(depth), sparse_layer):
206
+ attn = Attention(
207
+ dim,
208
+ causal=causal,
209
+ seq_len=seq_len,
210
+ heads=heads,
211
+ dim_head=dim_head,
212
+ dropout=attn_dropout,
213
+ )
214
+
215
+ ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
216
+
217
+ layers.append(
218
+ nn.ModuleList(
219
+ [
220
+ LayerScale(
221
+ dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)
222
+ ),
223
+ LayerScale(
224
+ dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)
225
+ ),
226
+ ]
227
+ )
228
+ )
229
+
230
+ execute_type = SequentialSequence
231
+ route_attn = ((True, False),) * depth
232
+ attn_route_map = {"mask": route_attn}
233
+
234
+ self.layers = execute_type(layers, args_route=attn_route_map)
235
+
236
+ def forward(self, x, **kwargs):
237
+ return self.layers(x, **kwargs)
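This is the dense (non-xformers) encoder that CLVP in `clvp.py` falls back to. A short sketch of driving it directly: embed a token sequence and pass a boolean padding mask, which `route_args` forwards to every attention layer. Dimensions and the embedding table are illustrative.

```python
import torch
from torch import nn

dim, seq_len = 512, 120
enc = Transformer(dim=dim, depth=6, seq_len=seq_len, causal=False, heads=8)
emb = nn.Embedding(256, dim)                      # illustrative token embedding

tokens = torch.randint(0, 256, (2, seq_len))
mask = torch.ones(2, seq_len, dtype=torch.bool)   # True = attend to this position
out = enc(emb(tokens), mask=mask)                 # (2, 120, 512)
```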
tortoise/models/utils.py ADDED
@@ -0,0 +1,80 @@
1
+ import os
2
+ try: import gdown
3
+ except ImportError:
4
+ raise ImportError(
5
+ "Sorry, gdown is required in order to download the new BigVGAN vocoder.\n"
6
+ "Please install it with `pip install gdown` and try again."
7
+ )
8
+ from urllib import request
9
+
10
+ import progressbar
11
+
12
+ D_STEM = "https://drive.google.com/uc?id="
13
+
14
+ DEFAULT_MODELS_DIR = os.path.join(
15
+ os.path.expanduser("~"), ".cache", "tortoise", "models"
16
+ )
17
+ MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
18
+ MODELS = {
19
+ "autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
20
+ "classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth",
21
+ "clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth",
22
+ "cvvp.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth",
23
+ "diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth",
24
+ "vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth",
25
+ "rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth",
26
+ "rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth",
27
+ # these links are from the nvidia gdrive
28
+ "bigvgan_base_24khz_100band_g.pth": "https://drive.google.com/uc?id=1_cKskUDuvxQJUEBwdgjAxKuDTUW6kPdY",
29
+ "bigvgan_24khz_100band_g.pth": "https://drive.google.com/uc?id=1wmP_mAs7d00KHVfVEl8B5Gb72Kzpcavp",
30
+ }
31
+
32
+
33
+ pbar = None
34
+ def download_models(specific_models=None):
35
+ """
36
+ Call to download all the models that Tortoise uses.
37
+ """
38
+ os.makedirs(MODELS_DIR, exist_ok=True)
39
+
40
+ def show_progress(block_num, block_size, total_size):
41
+ global pbar
42
+ if pbar is None:
43
+ pbar = progressbar.ProgressBar(maxval=total_size)
44
+ pbar.start()
45
+
46
+ downloaded = block_num * block_size
47
+ if downloaded < total_size:
48
+ pbar.update(downloaded)
49
+ else:
50
+ pbar.finish()
51
+ pbar = None
52
+
53
+ for model_name, url in MODELS.items():
54
+ if specific_models is not None and model_name not in specific_models:
55
+ continue
56
+ model_path = os.path.join(MODELS_DIR, model_name)
57
+ if os.path.exists(model_path):
58
+ continue
59
+ print(f"Downloading {model_name} from {url}...")
60
+ if D_STEM in url:
61
+ gdown.download(url, model_path, quiet=False)
62
+ else:
63
+ request.urlretrieve(url, model_path, show_progress)
64
+ print("Done.")
65
+
66
+
67
+ def get_model_path(model_name, models_dir=MODELS_DIR):
68
+ """
69
+ Get path to given model, download it if it doesn't exist.
70
+ """
71
+ if model_name not in MODELS:
72
+ raise ValueError(f"Model {model_name} not found in available models.")
73
+ model_path = os.path.join(models_dir, model_name)
74
+ if not os.path.exists(model_path) and models_dir == MODELS_DIR:
75
+ download_models([model_name])
76
+ return model_path
77
+
78
+ if __name__ == "__main__":
79
+ download_models() # to download all models
80
+
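Typical use goes through `get_model_path`, which resolves into `~/.cache/tortoise/models` (or `$TORTOISE_MODELS_DIR` when set) and downloads on first access. A small sketch; loading the checkpoint into a model is shown only schematically.

```python
import torch

ar_path = get_model_path("autoregressive.pth")       # downloads the file on first use
state_dict = torch.load(ar_path, map_location="cpu")
# gpt = UnifiedVoice(...); gpt.load_state_dict(state_dict)   # model construction omitted here
```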
tortoise/models/vocoder.py ADDED
@@ -0,0 +1,440 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import json
6
+ from enum import Enum
7
+ from typing import Optional, Callable
8
+ from dataclasses import dataclass
9
+ try:
10
+ from BigVGAN.models import BigVGAN as BVGModel
11
+ from BigVGAN.env import AttrDict
12
+ except ImportError:
13
+ raise ImportError(
14
+ "BigVGAN not installed, can't use BigVGAN vocoder\n"
15
+ "Please see the installation instructions on README."
16
+ )
17
+
18
+ MAX_WAV_VALUE = 32768.0
19
+
20
+
21
+ class KernelPredictor(torch.nn.Module):
22
+ """Kernel predictor for the location-variable convolutions"""
23
+
24
+ def __init__(
25
+ self,
26
+ cond_channels,
27
+ conv_in_channels,
28
+ conv_out_channels,
29
+ conv_layers,
30
+ conv_kernel_size=3,
31
+ kpnet_hidden_channels=64,
32
+ kpnet_conv_size=3,
33
+ kpnet_dropout=0.0,
34
+ kpnet_nonlinear_activation="LeakyReLU",
35
+ kpnet_nonlinear_activation_params={"negative_slope": 0.1},
36
+ ):
37
+ """
38
+ Args:
39
+ cond_channels (int): number of channel for the conditioning sequence,
40
+ conv_in_channels (int): number of channel for the input sequence,
41
+ conv_out_channels (int): number of channel for the output sequence,
42
+ conv_layers (int): number of layers
43
+ """
44
+ super().__init__()
45
+
46
+ self.conv_in_channels = conv_in_channels
47
+ self.conv_out_channels = conv_out_channels
48
+ self.conv_kernel_size = conv_kernel_size
49
+ self.conv_layers = conv_layers
50
+
51
+ kpnet_kernel_channels = (
52
+ conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
53
+ ) # l_w
54
+ kpnet_bias_channels = conv_out_channels * conv_layers # l_b
55
+
56
+ self.input_conv = nn.Sequential(
57
+ nn.utils.weight_norm(
58
+ nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
59
+ ),
60
+ getattr(nn, kpnet_nonlinear_activation)(
61
+ **kpnet_nonlinear_activation_params
62
+ ),
63
+ )
64
+
65
+ self.residual_convs = nn.ModuleList()
66
+ padding = (kpnet_conv_size - 1) // 2
67
+ for _ in range(3):
68
+ self.residual_convs.append(
69
+ nn.Sequential(
70
+ nn.Dropout(kpnet_dropout),
71
+ nn.utils.weight_norm(
72
+ nn.Conv1d(
73
+ kpnet_hidden_channels,
74
+ kpnet_hidden_channels,
75
+ kpnet_conv_size,
76
+ padding=padding,
77
+ bias=True,
78
+ )
79
+ ),
80
+ getattr(nn, kpnet_nonlinear_activation)(
81
+ **kpnet_nonlinear_activation_params
82
+ ),
83
+ nn.utils.weight_norm(
84
+ nn.Conv1d(
85
+ kpnet_hidden_channels,
86
+ kpnet_hidden_channels,
87
+ kpnet_conv_size,
88
+ padding=padding,
89
+ bias=True,
90
+ )
91
+ ),
92
+ getattr(nn, kpnet_nonlinear_activation)(
93
+ **kpnet_nonlinear_activation_params
94
+ ),
95
+ )
96
+ )
97
+ self.kernel_conv = nn.utils.weight_norm(
98
+ nn.Conv1d(
99
+ kpnet_hidden_channels,
100
+ kpnet_kernel_channels,
101
+ kpnet_conv_size,
102
+ padding=padding,
103
+ bias=True,
104
+ )
105
+ )
106
+ self.bias_conv = nn.utils.weight_norm(
107
+ nn.Conv1d(
108
+ kpnet_hidden_channels,
109
+ kpnet_bias_channels,
110
+ kpnet_conv_size,
111
+ padding=padding,
112
+ bias=True,
113
+ )
114
+ )
115
+
116
+ def forward(self, c):
117
+ """
118
+ Args:
119
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
120
+ """
121
+ batch, _, cond_length = c.shape
122
+ c = self.input_conv(c)
123
+ for residual_conv in self.residual_convs:
124
+ residual_conv.to(c.device)
125
+ c = c + residual_conv(c)
126
+ k = self.kernel_conv(c)
127
+ b = self.bias_conv(c)
128
+ kernels = k.contiguous().view(
129
+ batch,
130
+ self.conv_layers,
131
+ self.conv_in_channels,
132
+ self.conv_out_channels,
133
+ self.conv_kernel_size,
134
+ cond_length,
135
+ )
136
+ bias = b.contiguous().view(
137
+ batch,
138
+ self.conv_layers,
139
+ self.conv_out_channels,
140
+ cond_length,
141
+ )
142
+
143
+ return kernels, bias
144
+
145
+ def remove_weight_norm(self):
146
+ nn.utils.remove_weight_norm(self.input_conv[0])
147
+ nn.utils.remove_weight_norm(self.kernel_conv)
148
+ nn.utils.remove_weight_norm(self.bias_conv)
149
+ for block in self.residual_convs:
150
+ nn.utils.remove_weight_norm(block[1])
151
+ nn.utils.remove_weight_norm(block[3])
152
+
153
+
154
+ class LVCBlock(torch.nn.Module):
155
+ """the location-variable convolutions"""
156
+
157
+ def __init__(
158
+ self,
159
+ in_channels,
160
+ cond_channels,
161
+ stride,
162
+ dilations=[1, 3, 9, 27],
163
+ lReLU_slope=0.2,
164
+ conv_kernel_size=3,
165
+ cond_hop_length=256,
166
+ kpnet_hidden_channels=64,
167
+ kpnet_conv_size=3,
168
+ kpnet_dropout=0.0,
169
+ ):
170
+ super().__init__()
171
+
172
+ self.cond_hop_length = cond_hop_length
173
+ self.conv_layers = len(dilations)
174
+ self.conv_kernel_size = conv_kernel_size
175
+
176
+ self.kernel_predictor = KernelPredictor(
177
+ cond_channels=cond_channels,
178
+ conv_in_channels=in_channels,
179
+ conv_out_channels=2 * in_channels,
180
+ conv_layers=len(dilations),
181
+ conv_kernel_size=conv_kernel_size,
182
+ kpnet_hidden_channels=kpnet_hidden_channels,
183
+ kpnet_conv_size=kpnet_conv_size,
184
+ kpnet_dropout=kpnet_dropout,
185
+ kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
186
+ )
187
+
188
+ self.convt_pre = nn.Sequential(
189
+ nn.LeakyReLU(lReLU_slope),
190
+ nn.utils.weight_norm(
191
+ nn.ConvTranspose1d(
192
+ in_channels,
193
+ in_channels,
194
+ 2 * stride,
195
+ stride=stride,
196
+ padding=stride // 2 + stride % 2,
197
+ output_padding=stride % 2,
198
+ )
199
+ ),
200
+ )
201
+
202
+ self.conv_blocks = nn.ModuleList()
203
+ for dilation in dilations:
204
+ self.conv_blocks.append(
205
+ nn.Sequential(
206
+ nn.LeakyReLU(lReLU_slope),
207
+ nn.utils.weight_norm(
208
+ nn.Conv1d(
209
+ in_channels,
210
+ in_channels,
211
+ conv_kernel_size,
212
+ padding=dilation * (conv_kernel_size - 1) // 2,
213
+ dilation=dilation,
214
+ )
215
+ ),
216
+ nn.LeakyReLU(lReLU_slope),
217
+ )
218
+ )
219
+
220
+ def forward(self, x, c):
221
+ """forward propagation of the location-variable convolutions.
222
+ Args:
223
+ x (Tensor): the input sequence (batch, in_channels, in_length)
224
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
225
+
226
+ Returns:
227
+ Tensor: the output sequence (batch, in_channels, in_length)
228
+ """
229
+ _, in_channels, _ = x.shape # (B, c_g, L')
230
+
231
+ x = self.convt_pre(x) # (B, c_g, stride * L')
232
+ kernels, bias = self.kernel_predictor(c)
233
+
234
+ for i, conv in enumerate(self.conv_blocks):
235
+ output = conv(x) # (B, c_g, stride * L')
236
+
237
+ k = kernels[:, i, :, :, :, :] # (B, c_g, 2 * c_g, kernel_size, cond_length)
238
+ b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
239
+
240
+ output = self.location_variable_convolution(
241
+ output, k, b, hop_size=self.cond_hop_length
242
+ ) # (B, 2 * c_g, stride * L'): LVC
243
+ x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
244
+ output[:, in_channels:, :]
245
+ ) # (B, c_g, stride * L'): GAU
246
+
247
+ return x
248
+
249
+ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
250
+ """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
251
+ Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), tested on an NVIDIA V100.
252
+ Args:
253
+ x (Tensor): the input sequence (batch, in_channels, in_length).
254
+ kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
255
+ bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
256
+ dilation (int): the dilation of convolution.
257
+ hop_size (int): the hop_size of the conditioning sequence.
258
+ Returns:
259
+ (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
260
+ """
261
+ batch, _, in_length = x.shape
262
+ batch, _, out_channels, kernel_size, kernel_length = kernel.shape
263
+ assert in_length == (
264
+ kernel_length * hop_size
265
+ ), "length of (x, kernel) is not matched"
266
+
267
+ padding = dilation * int((kernel_size - 1) / 2)
268
+ x = F.pad(
269
+ x, (padding, padding), "constant", 0
270
+ ) # (batch, in_channels, in_length + 2*padding)
271
+ x = x.unfold(
272
+ 2, hop_size + 2 * padding, hop_size
273
+ ) # (batch, in_channels, kernel_length, hop_size + 2*padding)
274
+
275
+ if hop_size < dilation:
276
+ x = F.pad(x, (0, dilation), "constant", 0)
277
+ x = x.unfold(
278
+ 3, dilation, dilation
279
+ ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
280
+ x = x[:, :, :, :, :hop_size]
281
+ x = x.transpose(
282
+ 3, 4
283
+ ) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
284
+ x = x.unfold(
285
+ 4, kernel_size, 1
286
+ ) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
287
+
288
+ o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
289
+ o = o.to(memory_format=torch.channels_last_3d)
290
+ bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
291
+ o = o + bias
292
+ o = o.contiguous().view(batch, out_channels, -1)
293
+
294
+ return o
295
+
296
+ def remove_weight_norm(self):
297
+ self.kernel_predictor.remove_weight_norm()
298
+ nn.utils.remove_weight_norm(self.convt_pre[1])
299
+ for block in self.conv_blocks:
300
+ nn.utils.remove_weight_norm(block[1])
301
+
302
+
303
+ class UnivNetGenerator(nn.Module):
304
+ """
305
+ UnivNet Generator
306
+
307
+ Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.
308
+ """
309
+
310
+ def __init__(
311
+ self,
312
+ noise_dim=64,
313
+ channel_size=32,
314
+ dilations=[1, 3, 9, 27],
315
+ strides=[8, 8, 4],
316
+ lReLU_slope=0.2,
317
+ kpnet_conv_size=3,
318
+ # Below are the MEL configuration options that this generator requires.
319
+ hop_length=256,
320
+ n_mel_channels=100,
321
+ ):
322
+ super(UnivNetGenerator, self).__init__()
323
+ self.mel_channel = n_mel_channels
324
+ self.noise_dim = noise_dim
325
+ self.hop_length = hop_length
326
+ channel_size = channel_size
327
+ kpnet_conv_size = kpnet_conv_size
328
+
329
+ self.res_stack = nn.ModuleList()
330
+ hop_length = 1
331
+ for stride in strides:
332
+ hop_length = stride * hop_length
333
+ self.res_stack.append(
334
+ LVCBlock(
335
+ channel_size,
336
+ n_mel_channels,
337
+ stride=stride,
338
+ dilations=dilations,
339
+ lReLU_slope=lReLU_slope,
340
+ cond_hop_length=hop_length,
341
+ kpnet_conv_size=kpnet_conv_size,
342
+ )
343
+ )
344
+
345
+ self.conv_pre = nn.utils.weight_norm(
346
+ nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
347
+ )
348
+
349
+ self.conv_post = nn.Sequential(
350
+ nn.LeakyReLU(lReLU_slope),
351
+ nn.utils.weight_norm(
352
+ nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")
353
+ ),
354
+ nn.Tanh(),
355
+ )
356
+
357
+ def forward(self, c, z):
358
+ """
359
+ Args:
360
+ c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
361
+ z (Tensor): the noise sequence (batch, noise_dim, in_length)
362
+
363
+ """
364
+ z = self.conv_pre(z) # (B, c_g, L)
365
+
366
+ for res_block in self.res_stack:
367
+ res_block.to(z.device)
368
+ z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
369
+
370
+ z = self.conv_post(z) # (B, 1, L * 256)
371
+
372
+ return z
373
+
374
+ def eval(self, inference=False):
375
+ super(UnivNetGenerator, self).eval()
376
+ # don't remove weight norm during validation in the training loop
377
+ if inference:
378
+ self.remove_weight_norm()
379
+
380
+ def remove_weight_norm(self):
381
+ nn.utils.remove_weight_norm(self.conv_pre)
382
+
383
+ for layer in self.conv_post:
384
+ if len(layer.state_dict()) != 0:
385
+ nn.utils.remove_weight_norm(layer)
386
+
387
+ for res_block in self.res_stack:
388
+ res_block.remove_weight_norm()
389
+
390
+ def inference(self, c, z=None):
391
+ # pad the input mel with "silence" frames (the log-mel floor value, not literal zeros) to cut trailing artifacts
392
+ # see https://github.com/seungwonpark/melgan/issues/8
393
+ zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
394
+ mel = torch.cat((c, zero), dim=2)
395
+
396
+ if z is None:
397
+ z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
398
+
399
+ audio = self.forward(mel, z)
400
+ audio = audio[:, :, : -(self.hop_length * 10)]
401
+ audio = audio.clamp(min=-1, max=1)
402
+ return audio
403
+
404
+ from pathlib import Path
405
+ STATIC_DIR = Path(__file__).parent.parent.parent/'static'
406
+ assert STATIC_DIR.is_dir()
407
+ def BVGWithConf(fname: str):
408
+ json_config = json.loads(
409
+ (STATIC_DIR/fname).read_text()
410
+ )
411
+ return lambda: BVGModel(AttrDict(json_config))
412
+
413
+ @dataclass
414
+ class VocType:
415
+ constructor: Callable[[], nn.Module]
416
+ model_path: str
417
+ subkey: Optional[str] = None
418
+ def optionally_index(self, model_dict):
419
+ if self.subkey is not None:
420
+ return model_dict[self.subkey]
421
+ return model_dict
422
+ class VocConf(Enum):
423
+ Univnet = VocType(UnivNetGenerator, "vocoder.pth", 'model_g')
424
+ BigVGAN_Base = VocType(BVGWithConf("bigvgan_base_24khz_100band_config.json"), "bigvgan_base_24khz_100band_g.pth", 'generator')
425
+ BigVGAN = VocType(BVGWithConf("bigvgan_24khz_100band_config.json"), "bigvgan_24khz_100band_g.pth", 'generator')
426
+
427
+
428
+ if __name__ == "__main__":
429
+ model = UnivNetGenerator()
430
+
431
+ c = torch.randn(3, 100, 10)
432
+ z = torch.randn(3, 64, 10)
433
+ print(c.shape)
434
+
435
+ y = model(c, z)
436
+ print(y.shape)
437
+ assert y.shape == torch.Size([3, 1, 2560])
438
+
439
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
440
+ print(pytorch_total_params)
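A minimal usage sketch of the UnivNetGenerator defined above (not part of the committed file): it assumes the module lands at tortoise/models/vocoder.py and that a vocoder.pth checkpoint with a 'model_g' key, as referenced by VocConf.Univnet earlier in this file, is available; the random mel tensor merely stands in for real conditioning.

import torch

from tortoise.models.vocoder import UnivNetGenerator  # assumed module path for this file

vocoder = UnivNetGenerator()
# state = torch.load("vocoder.pth", map_location="cpu")   # hypothetical checkpoint location
# vocoder.load_state_dict(state["model_g"])               # 'model_g' key taken from VocConf.Univnet
vocoder.eval(inference=True)  # eval(inference=True) strips weight norm before generation

mel = torch.randn(1, 100, 80)  # (batch, n_mel_channels=100, frames); dummy conditioning
with torch.no_grad():
    audio = vocoder.inference(mel)  # (batch, 1, frames * hop_length) = (1, 1, 20480)
print(audio.shape)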
tortoise/models/xtransformers.py ADDED
@@ -0,0 +1,1436 @@
1
+ import math
2
+ from collections import namedtuple
3
+ from functools import partial
4
+ from inspect import isfunction
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from torch import einsum, nn
10
+
11
+ DEFAULT_DIM_HEAD = 64
12
+
13
+ Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
14
+
15
+ LayerIntermediates = namedtuple(
16
+ "Intermediates",
17
+ [
18
+ "hiddens",
19
+ "attn_intermediates",
20
+ "past_key_values",
21
+ ],
22
+ )
23
+
24
+
25
+ # helpers
26
+
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+
32
+ def default(val, d):
33
+ if exists(val):
34
+ return val
35
+ return d() if isfunction(d) else d
36
+
37
+
38
+ def cast_tuple(val, depth):
39
+ return val if isinstance(val, tuple) else (val,) * depth
40
+
41
+
42
+ class always:
43
+ def __init__(self, val):
44
+ self.val = val
45
+
46
+ def __call__(self, *args, **kwargs):
47
+ return self.val
48
+
49
+
50
+ class not_equals:
51
+ def __init__(self, val):
52
+ self.val = val
53
+
54
+ def __call__(self, x, *args, **kwargs):
55
+ return x != self.val
56
+
57
+
58
+ class equals:
59
+ def __init__(self, val):
60
+ self.val = val
61
+
62
+ def __call__(self, x, *args, **kwargs):
63
+ return x == self.val
64
+
65
+
66
+ def max_neg_value(tensor):
67
+ return -torch.finfo(tensor.dtype).max
68
+
69
+
70
+ def l2norm(t):
71
+ return F.normalize(t, p=2, dim=-1)
72
+
73
+
74
+ # init helpers
75
+
76
+
77
+ def init_zero_(layer):
78
+ nn.init.constant_(layer.weight, 0.0)
79
+ if exists(layer.bias):
80
+ nn.init.constant_(layer.bias, 0.0)
81
+
82
+
83
+ # keyword argument helpers
84
+
85
+
86
+ def pick_and_pop(keys, d):
87
+ values = list(map(lambda key: d.pop(key), keys))
88
+ return dict(zip(keys, values))
89
+
90
+
91
+ def group_dict_by_key(cond, d):
92
+ return_val = [dict(), dict()]
93
+ for key in d.keys():
94
+ match = bool(cond(key))
95
+ ind = int(not match)
96
+ return_val[ind][key] = d[key]
97
+ return (*return_val,)
98
+
99
+
100
+ def string_begins_with(prefix, str):
101
+ return str.startswith(prefix)
102
+
103
+
104
+ def group_by_key_prefix(prefix, d):
105
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
106
+
107
+
108
+ def groupby_prefix_and_trim(prefix, d):
109
+ kwargs_with_prefix, kwargs = group_dict_by_key(
110
+ partial(string_begins_with, prefix), d
111
+ )
112
+ kwargs_without_prefix = dict(
113
+ map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
114
+ )
115
+ return kwargs_without_prefix, kwargs
116
+
117
+
118
+ # activations
119
+
120
+
121
+ class ReluSquared(nn.Module):
122
+ def forward(self, x):
123
+ return F.relu(x) ** 2
124
+
125
+
126
+ # positional embeddings
127
+
128
+
129
+ class AbsolutePositionalEmbedding(nn.Module):
130
+ def __init__(self, dim, max_seq_len):
131
+ super().__init__()
132
+ self.scale = dim**-0.5
133
+ self.emb = nn.Embedding(max_seq_len, dim)
134
+
135
+ def forward(self, x):
136
+ n = torch.arange(x.shape[1], device=x.device)
137
+ pos_emb = self.emb(n)
138
+ pos_emb = rearrange(pos_emb, "n d -> () n d")
139
+ return pos_emb * self.scale
140
+
141
+
142
+ class FixedPositionalEmbedding(nn.Module):
143
+ def __init__(self, dim):
144
+ super().__init__()
145
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
146
+ self.register_buffer("inv_freq", inv_freq)
147
+
148
+ def forward(self, x, seq_dim=1, offset=0):
149
+ t = (
150
+ torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
151
+ + offset
152
+ )
153
+ sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
154
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
155
+ return rearrange(emb, "n d -> () n d")
156
+
157
+
158
+ class RelativePositionBias(nn.Module):
159
+ def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
160
+ super().__init__()
161
+ self.scale = scale
162
+ self.causal = causal
163
+ self.num_buckets = num_buckets
164
+ self.max_distance = max_distance
165
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
166
+
167
+ @staticmethod
168
+ def _relative_position_bucket(
169
+ relative_position, causal=True, num_buckets=32, max_distance=128
170
+ ):
171
+ ret = 0
172
+ n = -relative_position
173
+ if not causal:
174
+ num_buckets //= 2
175
+ ret += (n < 0).long() * num_buckets
176
+ n = torch.abs(n)
177
+ else:
178
+ n = torch.max(n, torch.zeros_like(n))
179
+
180
+ max_exact = num_buckets // 2
181
+ is_small = n < max_exact
182
+
183
+ val_if_large = (
184
+ max_exact
185
+ + (
186
+ torch.log(n.float() / max_exact)
187
+ / math.log(max_distance / max_exact)
188
+ * (num_buckets - max_exact)
189
+ ).long()
190
+ )
191
+ val_if_large = torch.min(
192
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
193
+ )
194
+
195
+ ret += torch.where(is_small, n, val_if_large)
196
+ return ret
197
+
198
+ def forward(self, qk_dots):
199
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
200
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
201
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
202
+ rel_pos = k_pos[None, :] - q_pos[:, None]
203
+ rp_bucket = self._relative_position_bucket(
204
+ rel_pos,
205
+ causal=self.causal,
206
+ num_buckets=self.num_buckets,
207
+ max_distance=self.max_distance,
208
+ )
209
+ values = self.relative_attention_bias(rp_bucket)
210
+ bias = rearrange(values, "i j h -> () h i j")
211
+ return qk_dots + (bias * self.scale)
212
+
213
+
214
+ class AlibiPositionalBias(nn.Module):
215
+ def __init__(self, heads, **kwargs):
216
+ super().__init__()
217
+ self.heads = heads
218
+ slopes = torch.Tensor(self._get_slopes(heads))
219
+ slopes = rearrange(slopes, "h -> () h () ()")
220
+ self.register_buffer("slopes", slopes, persistent=False)
221
+ self.register_buffer("bias", None, persistent=False)
222
+
223
+ @staticmethod
224
+ def _get_slopes(heads):
225
+ def get_slopes_power_of_2(n):
226
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
227
+ ratio = start
228
+ return [start * ratio**i for i in range(n)]
229
+
230
+ if math.log2(heads).is_integer():
231
+ return get_slopes_power_of_2(heads)
232
+
233
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
234
+ return (
235
+ get_slopes_power_of_2(closest_power_of_2)
236
+ + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
237
+ : heads - closest_power_of_2
238
+ ]
239
+ )
240
+
241
+ def forward(self, qk_dots):
242
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
243
+
244
+ if exists(self.bias) and self.bias.shape[-1] >= j:
245
+ return qk_dots + self.bias[..., :j]
246
+
247
+ bias = torch.arange(j, device=device)
248
+ bias = rearrange(bias, "j -> () () () j")
249
+ bias = bias * self.slopes
250
+
251
+ num_heads_unalibied = h - bias.shape[1]
252
+ bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
253
+
254
+ self.register_buffer("bias", bias, persistent=False)
255
+ return qk_dots + self.bias
256
+
257
+
258
+ class LearnedAlibiPositionalBias(AlibiPositionalBias):
259
+ def __init__(self, heads, bidirectional=False):
260
+ super().__init__(heads)
261
+ los_slopes = torch.log(self.slopes)
262
+ self.learned_logslopes = nn.Parameter(los_slopes)
263
+
264
+ self.bidirectional = bidirectional
265
+ if self.bidirectional:
266
+ self.learned_logslopes_future = nn.Parameter(los_slopes)
267
+
268
+ def forward(self, qk_dots):
269
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
270
+
271
+ def get_slopes(param):
272
+ return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))
273
+
274
+ if exists(self.bias) and self.bias.shape[-1] >= j:
275
+ bias = self.bias[..., :i, :j]
276
+ else:
277
+ i_arange = torch.arange(i, device=device)
278
+ j_arange = torch.arange(j, device=device)
279
+ bias = rearrange(j_arange, "j -> 1 1 1 j") - rearrange(
280
+ i_arange, "i -> 1 1 i 1"
281
+ )
282
+ self.register_buffer("bias", bias, persistent=False)
283
+
284
+ if self.bidirectional:
285
+ past_slopes = get_slopes(self.learned_logslopes)
286
+ future_slopes = get_slopes(self.learned_logslopes_future)
287
+ bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
288
+ else:
289
+ slopes = get_slopes(self.learned_logslopes)
290
+ bias = bias * slopes
291
+
292
+ return qk_dots + bias
293
+
294
+
295
+ class RotaryEmbedding(nn.Module):
296
+ def __init__(self, dim):
297
+ super().__init__()
298
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
299
+ self.register_buffer("inv_freq", inv_freq)
300
+
301
+ def forward(self, max_seq_len, device):
302
+ t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
303
+ freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
304
+ emb = torch.cat((freqs, freqs), dim=-1)
305
+ return rearrange(emb, "n d -> () () n d")
306
+
307
+
308
+ def rotate_half(x):
309
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
310
+ x1, x2 = x.unbind(dim=-2)
311
+ return torch.cat((-x2, x1), dim=-1)
312
+
313
+
314
+ def apply_rotary_pos_emb(t, freqs):
315
+ seq_len = t.shape[-2]
316
+ freqs = freqs[:, :, -seq_len:]
317
+ return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
318
+
319
+
320
+ # norms
321
+
322
+
323
+ class Scale(nn.Module):
324
+ def __init__(self, value, fn):
325
+ super().__init__()
326
+ self.value = value
327
+ self.fn = fn
328
+
329
+ def forward(self, x, **kwargs):
330
+ out = self.fn(x, **kwargs)
331
+
332
+ def scale_fn(t):
333
+ return t * self.value
334
+
335
+ if not isinstance(out, tuple):
336
+ return scale_fn(out)
337
+
338
+ return (scale_fn(out[0]), *out[1:])
339
+
340
+
341
+ class Rezero(nn.Module):
342
+ def __init__(self, fn):
343
+ super().__init__()
344
+ self.fn = fn
345
+ self.g = nn.Parameter(torch.zeros(1))
346
+
347
+ def forward(self, x, **kwargs):
348
+ out = self.fn(x, **kwargs)
349
+
350
+ def rezero_fn(t):
351
+ return t * self.g
352
+
353
+ if not isinstance(out, tuple):
354
+ return rezero_fn(out)
355
+
356
+ return (rezero_fn(out[0]), *out[1:])
357
+
358
+
359
+ class ScaleNorm(nn.Module):
360
+ def __init__(self, dim, eps=1e-5):
361
+ super().__init__()
362
+ self.scale = dim**-0.5
363
+ self.eps = eps
364
+ self.g = nn.Parameter(torch.ones(1))
365
+
366
+ def forward(self, x):
367
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
368
+ return x / norm.clamp(min=self.eps) * self.g
369
+
370
+
371
+ class RMSNorm(nn.Module):
372
+ def __init__(self, dim, eps=1e-8):
373
+ super().__init__()
374
+ self.scale = dim**-0.5
375
+ self.eps = eps
376
+ self.g = nn.Parameter(torch.ones(dim))
377
+
378
+ def forward(self, x):
379
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
380
+ return x / norm.clamp(min=self.eps) * self.g
381
+
382
+
383
+ class RMSScaleShiftNorm(nn.Module):
384
+ def __init__(self, dim, eps=1e-8):
385
+ super().__init__()
386
+ self.scale = dim**-0.5
387
+ self.eps = eps
388
+ self.g = nn.Parameter(torch.ones(dim))
389
+ self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
390
+
391
+ def forward(self, x, norm_scale_shift_inp):
392
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
393
+ norm = x / norm.clamp(min=self.eps) * self.g
394
+
395
+ ss_emb = self.scale_shift_process(norm_scale_shift_inp)
396
+ scale, shift = torch.chunk(ss_emb, 2, dim=1)
397
+ h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
398
+ return h
399
+
400
+
401
+ # residual and residual gates
402
+
403
+
404
+ class Residual(nn.Module):
405
+ def __init__(self, dim, scale_residual=False):
406
+ super().__init__()
407
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
408
+
409
+ def forward(self, x, residual):
410
+ if exists(self.residual_scale):
411
+ residual = residual * self.residual_scale
412
+
413
+ return x + residual
414
+
415
+
416
+ class GRUGating(nn.Module):
417
+ def __init__(self, dim, scale_residual=False):
418
+ super().__init__()
419
+ self.gru = nn.GRUCell(dim, dim)
420
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
421
+
422
+ def forward(self, x, residual):
423
+ if exists(self.residual_scale):
424
+ residual = residual * self.residual_scale
425
+
426
+ gated_output = self.gru(
427
+ rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")
428
+ )
429
+
430
+ return gated_output.reshape_as(x)
431
+
432
+
433
+ # token shifting
434
+
435
+
436
+ def shift(t, amount, mask=None):
437
+ if amount == 0:
438
+ return t
439
+
440
+ if exists(mask):
441
+ t = t.masked_fill(~mask[..., None], 0.0)
442
+
443
+ return F.pad(t, (0, 0, amount, -amount), value=0.0)
444
+
445
+
446
+ class ShiftTokens(nn.Module):
447
+ def __init__(self, shifts, fn):
448
+ super().__init__()
449
+ self.fn = fn
450
+ self.shifts = tuple(shifts)
451
+
452
+ def forward(self, x, **kwargs):
453
+ mask = kwargs.get("mask", None)
454
+ shifts = self.shifts
455
+ segments = len(shifts)
456
+ feats_per_shift = x.shape[-1] // segments
457
+ splitted = x.split(feats_per_shift, dim=-1)
458
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
459
+ segments_to_shift = list(
460
+ map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts))
461
+ )
462
+ x = torch.cat((*segments_to_shift, *rest), dim=-1)
463
+ return self.fn(x, **kwargs)
464
+
465
+
466
+ # feedforward
467
+
468
+
469
+ class GLU(nn.Module):
470
+ def __init__(self, dim_in, dim_out, activation):
471
+ super().__init__()
472
+ self.act = activation
473
+ self.proj = nn.Linear(dim_in, dim_out * 2)
474
+
475
+ def forward(self, x):
476
+ x, gate = self.proj(x).chunk(2, dim=-1)
477
+ return x * self.act(gate)
478
+
479
+
480
+ class FeedForward(nn.Module):
481
+ def __init__(
482
+ self,
483
+ dim,
484
+ dim_out=None,
485
+ mult=4,
486
+ glu=False,
487
+ relu_squared=False,
488
+ post_act_ln=False,
489
+ dropout=0.0,
490
+ zero_init_output=False,
491
+ ):
492
+ super().__init__()
493
+ inner_dim = int(dim * mult)
494
+ dim_out = default(dim_out, dim)
495
+ activation = ReluSquared() if relu_squared else nn.GELU()
496
+
497
+ project_in = (
498
+ nn.Sequential(nn.Linear(dim, inner_dim), activation)
499
+ if not glu
500
+ else GLU(dim, inner_dim, activation)
501
+ )
502
+
503
+ self.net = nn.Sequential(
504
+ project_in,
505
+ nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
506
+ nn.Dropout(dropout),
507
+ nn.Linear(inner_dim, dim_out),
508
+ )
509
+
510
+ # init last linear layer to 0
511
+ if zero_init_output:
512
+ init_zero_(self.net[-1])
513
+
514
+ def forward(self, x):
515
+ return self.net(x)
516
+
517
+
518
+ # attention.
519
+
520
+
521
+ class Attention(nn.Module):
522
+ def __init__(
523
+ self,
524
+ dim,
525
+ dim_head=DEFAULT_DIM_HEAD,
526
+ heads=8,
527
+ causal=False,
528
+ talking_heads=False,
529
+ head_scale=False,
530
+ collab_heads=False,
531
+ collab_compression=0.3,
532
+ sparse_topk=None,
533
+ use_entmax15=False,
534
+ num_mem_kv=0,
535
+ dropout=0.0,
536
+ on_attn=False,
537
+ gate_values=False,
538
+ zero_init_output=False,
539
+ max_attend_past=None,
540
+ qk_norm=False,
541
+ scale_init_value=None,
542
+ rel_pos_bias=False,
543
+ rel_pos_num_buckets=32,
544
+ rel_pos_max_distance=128,
545
+ ):
546
+ super().__init__()
547
+ self.scale = dim_head**-0.5
548
+
549
+ self.heads = heads
550
+ self.causal = causal
551
+ self.max_attend_past = max_attend_past
552
+
553
+ qk_dim = v_dim = dim_head * heads
554
+
555
+ # collaborative heads
556
+ self.collab_heads = collab_heads
557
+ if self.collab_heads:
558
+ qk_dim = int(collab_compression * qk_dim)
559
+ self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
560
+
561
+ self.to_q = nn.Linear(dim, qk_dim, bias=False)
562
+ self.to_k = nn.Linear(dim, qk_dim, bias=False)
563
+ self.to_v = nn.Linear(dim, v_dim, bias=False)
564
+
565
+ self.dropout = nn.Dropout(dropout)
566
+
567
+ # add GLU gating for aggregated values, from alphafold2
568
+ self.to_v_gate = None
569
+ if gate_values:
570
+ self.to_v_gate = nn.Linear(dim, v_dim)
571
+ nn.init.constant_(self.to_v_gate.weight, 0)
572
+ nn.init.constant_(self.to_v_gate.bias, 1)
573
+
574
+ # cosine sim attention
575
+ self.qk_norm = qk_norm
576
+ if qk_norm:
577
+ scale_init_value = default(
578
+ scale_init_value, -3
579
+ ) # if not provided, initialize as though it were sequence length of 1024
580
+ self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
581
+
582
+ # talking heads
583
+ self.talking_heads = talking_heads
584
+ if talking_heads:
585
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
586
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
587
+
588
+ # head scaling
589
+ self.head_scale = head_scale
590
+ if head_scale:
591
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
592
+
593
+ # explicit topk sparse attention
594
+ self.sparse_topk = sparse_topk
595
+
596
+ # entmax
597
+ self.attn_fn = F.softmax
598
+
599
+ # add memory key / values
600
+ self.num_mem_kv = num_mem_kv
601
+ if num_mem_kv > 0:
602
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
603
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
604
+
605
+ # attention on attention
606
+ self.attn_on_attn = on_attn
607
+ self.to_out = (
608
+ nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU())
609
+ if on_attn
610
+ else nn.Linear(v_dim, dim)
611
+ )
612
+
613
+ self.rel_pos_bias = rel_pos_bias
614
+ if rel_pos_bias:
615
+ assert (
616
+ rel_pos_num_buckets <= rel_pos_max_distance
617
+ ), "number of relative position buckets must be less than the relative position max distance"
618
+ self.rel_pos = RelativePositionBias(
619
+ scale=dim_head**0.5,
620
+ causal=causal,
621
+ heads=heads,
622
+ num_buckets=rel_pos_num_buckets,
623
+ max_distance=rel_pos_max_distance,
624
+ )
625
+
626
+ # init output projection 0
627
+ if zero_init_output:
628
+ init_zero_(self.to_out)
629
+
630
+ def forward(
631
+ self,
632
+ x,
633
+ context=None,
634
+ mask=None,
635
+ context_mask=None,
636
+ attn_mask=None,
637
+ sinusoidal_emb=None,
638
+ rotary_pos_emb=None,
639
+ prev_attn=None,
640
+ mem=None,
641
+ layer_past=None,
642
+ ):
643
+ (
644
+ b,
645
+ n,
646
+ _,
647
+ h,
648
+ talking_heads,
649
+ collab_heads,
650
+ head_scale,
651
+ scale,
652
+ device,
653
+ has_context,
654
+ ) = (
655
+ *x.shape,
656
+ self.heads,
657
+ self.talking_heads,
658
+ self.collab_heads,
659
+ self.head_scale,
660
+ self.scale,
661
+ x.device,
662
+ exists(context),
663
+ )
664
+ kv_input = default(context, x)
665
+
666
+ q_input = x
667
+ k_input = kv_input
668
+ v_input = kv_input
669
+
670
+ if exists(mem):
671
+ k_input = torch.cat((mem, k_input), dim=-2)
672
+ v_input = torch.cat((mem, v_input), dim=-2)
673
+
674
+ if exists(sinusoidal_emb):
675
+ # in shortformer, the query would start at a position offset depending on the past cached memory
676
+ offset = k_input.shape[-2] - q_input.shape[-2]
677
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
678
+ k_input = k_input + sinusoidal_emb(k_input)
679
+
680
+ q = self.to_q(q_input)
681
+ k = self.to_k(k_input)
682
+ v = self.to_v(v_input)
683
+
684
+ if not collab_heads:
685
+ q, k, v = map(
686
+ lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)
687
+ )
688
+ else:
689
+ q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
690
+ k = rearrange(k, "b n d -> b () n d")
691
+ v = rearrange(v, "b n (h d) -> b h n d", h=h)
692
+
693
+ if layer_past is not None:
694
+ past_key, past_value = layer_past
695
+ k = torch.cat([past_key, k], dim=-2)
696
+ v = torch.cat([past_value, v], dim=-2)
697
+ k_cache = k
698
+ v_cache = v
699
+
700
+ if exists(rotary_pos_emb) and not has_context:
701
+ l = rotary_pos_emb.shape[-1]
702
+ (ql, qr), (kl, kr), (vl, vr) = map(
703
+ lambda t: (t[..., :l], t[..., l:]), (q, k, v)
704
+ )
705
+ ql, kl, vl = map(
706
+ lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)
707
+ )
708
+ q, k, v = map(
709
+ lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))
710
+ )
711
+
712
+ input_mask = None
713
+ if any(map(exists, (mask, context_mask))):
714
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
715
+ k_mask = q_mask if not exists(context) else context_mask
716
+ k_mask = default(
717
+ k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()
718
+ )
719
+ q_mask = rearrange(q_mask, "b i -> b () i ()")
720
+ k_mask = rearrange(k_mask, "b j -> b () () j")
721
+ input_mask = q_mask * k_mask
722
+
723
+ if self.num_mem_kv > 0:
724
+ mem_k, mem_v = map(
725
+ lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)
726
+ )
727
+ k = torch.cat((mem_k, k), dim=-2)
728
+ v = torch.cat((mem_v, v), dim=-2)
729
+ if exists(input_mask):
730
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
731
+
732
+ if collab_heads:
733
+ k = k.expand(-1, h, -1, -1)
734
+
735
+ if self.qk_norm:
736
+ q, k = map(l2norm, (q, k))
737
+ scale = 1 / (self.scale.exp().clamp(min=1e-2))
738
+
739
+ dots = einsum("b h i d, b h j d -> b h i j", q, k) * scale
740
+ mask_value = max_neg_value(dots)
741
+
742
+ if exists(prev_attn):
743
+ dots = dots + prev_attn
744
+
745
+ pre_softmax_attn = dots.clone()
746
+
747
+ if talking_heads:
748
+ dots = einsum(
749
+ "b h i j, h k -> b k i j", dots, self.pre_softmax_proj
750
+ ).contiguous()
751
+
752
+ if self.rel_pos_bias:
753
+ dots = self.rel_pos(dots)
754
+
755
+ if exists(input_mask):
756
+ dots.masked_fill_(~input_mask, mask_value)
757
+ del input_mask
758
+
759
+ if exists(attn_mask):
760
+ assert (
761
+ 2 <= attn_mask.ndim <= 4
762
+ ), "attention mask must have greater than 2 dimensions but less than or equal to 4"
763
+ if attn_mask.ndim == 2:
764
+ attn_mask = rearrange(attn_mask, "i j -> () () i j")
765
+ elif attn_mask.ndim == 3:
766
+ attn_mask = rearrange(attn_mask, "h i j -> () h i j")
767
+ dots.masked_fill_(~attn_mask, mask_value)
768
+
769
+ if exists(self.max_attend_past):
770
+ i, j = dots.shape[-2:]
771
+ range_q = torch.arange(j - i, j, device=device)
772
+ range_k = torch.arange(j, device=device)
773
+ dist = rearrange(range_q, "i -> () () i ()") - rearrange(
774
+ range_k, "j -> () () () j"
775
+ )
776
+ mask = dist > self.max_attend_past
777
+ dots.masked_fill_(mask, mask_value)
778
+ del mask
779
+
780
+ if self.causal:
781
+ i, j = dots.shape[-2:]
782
+ r = torch.arange(i, device=device)
783
+ mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
784
+ mask = F.pad(mask, (j - i, 0), value=False)
785
+ dots.masked_fill_(mask, mask_value)
786
+ del mask
787
+
788
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
789
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
790
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
791
+ mask = dots < vk
792
+ dots.masked_fill_(mask, mask_value)
793
+ del mask
794
+
795
+ attn = self.attn_fn(dots, dim=-1)
796
+ post_softmax_attn = attn.clone()
797
+
798
+ attn = self.dropout(attn)
799
+
800
+ if talking_heads:
801
+ attn = einsum(
802
+ "b h i j, h k -> b k i j", attn, self.post_softmax_proj
803
+ ).contiguous()
804
+
805
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
806
+
807
+ if head_scale:
808
+ out = out * self.head_scale_params
809
+
810
+ out = rearrange(out, "b h n d -> b n (h d)")
811
+
812
+ if exists(self.to_v_gate):
813
+ gates = self.to_v_gate(x)
814
+ out = out * gates.sigmoid()
815
+
816
+ intermediates = Intermediates(
817
+ pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn
818
+ )
819
+
820
+ return self.to_out(out), intermediates, k_cache, v_cache
821
+
822
+
823
+ class AttentionLayers(nn.Module):
824
+ def __init__(
825
+ self,
826
+ dim,
827
+ depth,
828
+ heads=8,
829
+ causal=False,
830
+ cross_attend=False,
831
+ only_cross=False,
832
+ use_scalenorm=False,
833
+ use_rms_scaleshift_norm=False,
834
+ use_rmsnorm=False,
835
+ use_rezero=False,
836
+ alibi_pos_bias=False,
837
+ alibi_num_heads=None,
838
+ alibi_learned=False,
839
+ position_infused_attn=False,
840
+ rotary_pos_emb=False,
841
+ rotary_emb_dim=None,
842
+ custom_layers=None,
843
+ sandwich_coef=None,
844
+ par_ratio=None,
845
+ residual_attn=False,
846
+ cross_residual_attn=False,
847
+ macaron=False,
848
+ pre_norm=True,
849
+ gate_residual=False,
850
+ scale_residual=False,
851
+ shift_tokens=0,
852
+ sandwich_norm=False,
853
+ use_qk_norm_attn=False,
854
+ qk_norm_attn_seq_len=None,
855
+ zero_init_branch_output=False,
856
+ **kwargs,
857
+ ):
858
+ super().__init__()
859
+ ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
860
+ attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)
861
+
862
+ dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)
863
+
864
+ self.dim = dim
865
+ self.depth = depth
866
+ self.layers = nn.ModuleList([])
867
+ self.causal = causal
868
+
869
+ rel_pos_bias = "rel_pos_bias" in attn_kwargs
870
+ self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
871
+ self.pia_pos_emb = (
872
+ FixedPositionalEmbedding(dim) if position_infused_attn else None
873
+ )
874
+
875
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
876
+ self.rotary_pos_emb = (
877
+ RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
878
+ )
879
+
880
+ assert not (
881
+ alibi_pos_bias and rel_pos_bias
882
+ ), "you can only choose Alibi positional bias or T5 relative positional bias, not both"
883
+
884
+ if alibi_pos_bias:
885
+ alibi_num_heads = default(alibi_num_heads, heads)
886
+ assert (
887
+ alibi_num_heads <= heads
888
+ ), "number of ALiBi heads must be less than the total number of heads"
889
+ alibi_pos_klass = (
890
+ LearnedAlibiPositionalBias
891
+ if alibi_learned or not causal
892
+ else AlibiPositionalBias
893
+ )
894
+ self.rel_pos = alibi_pos_klass(
895
+ heads=alibi_num_heads, bidirectional=not causal
896
+ )
897
+ else:
898
+ self.rel_pos = None
899
+
900
+ assert not (
901
+ not pre_norm and sandwich_norm
902
+ ), "sandwich norm cannot be used when not using prenorm"
903
+ self.pre_norm = pre_norm
904
+ self.sandwich_norm = sandwich_norm
905
+
906
+ self.residual_attn = residual_attn
907
+ self.cross_residual_attn = cross_residual_attn
908
+ self.cross_attend = cross_attend
909
+
910
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
911
+ norm_class = RMSNorm if use_rmsnorm else norm_class
912
+ norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
913
+ norm_fn = partial(norm_class, dim)
914
+
915
+ norm_fn = nn.Identity if use_rezero else norm_fn
916
+ branch_fn = Rezero if use_rezero else None
917
+
918
+ if cross_attend and not only_cross:
919
+ default_block = ("a", "c", "f")
920
+ elif cross_attend and only_cross:
921
+ default_block = ("c", "f")
922
+ else:
923
+ default_block = ("a", "f")
924
+
925
+ if macaron:
926
+ default_block = ("f",) + default_block
927
+
928
+ # qk normalization
929
+
930
+ if use_qk_norm_attn:
931
+ attn_scale_init_value = (
932
+ -math.log(math.log2(qk_norm_attn_seq_len**2 - qk_norm_attn_seq_len))
933
+ if exists(qk_norm_attn_seq_len)
934
+ else None
935
+ )
936
+ attn_kwargs = {
937
+ **attn_kwargs,
938
+ "qk_norm": True,
939
+ "scale_init_value": attn_scale_init_value,
940
+ }
941
+
942
+ # zero init
943
+
944
+ if zero_init_branch_output:
945
+ attn_kwargs = {**attn_kwargs, "zero_init_output": True}
946
+ ff_kwargs = {**ff_kwargs, "zero_init_output": True}
947
+
948
+ # calculate layer block order
949
+
950
+ if exists(custom_layers):
951
+ layer_types = custom_layers
952
+ elif exists(par_ratio):
953
+ par_depth = depth * len(default_block)
954
+ assert 1 < par_ratio <= par_depth, "par ratio out of range"
955
+ default_block = tuple(filter(not_equals("f"), default_block))
956
+ par_attn = par_depth // par_ratio
957
+ depth_cut = (
958
+ par_depth * 2 // 3
959
+ ) # 2 / 3 attention layer cutoff suggested by PAR paper
960
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
961
+ assert (
962
+ len(default_block) <= par_width
963
+ ), "default block is too large for par_ratio"
964
+ par_block = default_block + ("f",) * (par_width - len(default_block))
965
+ par_head = par_block * par_attn
966
+ layer_types = par_head + ("f",) * (par_depth - len(par_head))
967
+ elif exists(sandwich_coef):
968
+ assert (
969
+ sandwich_coef > 0 and sandwich_coef <= depth
970
+ ), "sandwich coefficient should be less than the depth"
971
+ layer_types = (
972
+ ("a",) * sandwich_coef
973
+ + default_block * (depth - sandwich_coef)
974
+ + ("f",) * sandwich_coef
975
+ )
976
+ else:
977
+ layer_types = default_block * depth
978
+
979
+ self.layer_types = layer_types
980
+ self.num_attn_layers = len(list(filter(equals("a"), layer_types)))
981
+
982
+ # calculate token shifting
983
+
984
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
985
+
986
+ # iterate and construct layers
987
+
988
+ for ind, (layer_type, layer_shift_tokens) in enumerate(
989
+ zip(self.layer_types, shift_tokens)
990
+ ):
991
+ is_last_layer = ind == (len(self.layer_types) - 1)
992
+
993
+ if layer_type == "a":
994
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
995
+ elif layer_type == "c":
996
+ layer = Attention(dim, heads=heads, **attn_kwargs)
997
+ elif layer_type == "f":
998
+ layer = FeedForward(dim, **ff_kwargs)
999
+ layer = layer if not macaron else Scale(0.5, layer)
1000
+ else:
1001
+ raise Exception(f"invalid layer type {layer_type}")
1002
+
1003
+ if layer_shift_tokens > 0:
1004
+ shift_range_upper = layer_shift_tokens + 1
1005
+ shift_range_lower = -layer_shift_tokens if not causal else 0
1006
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1007
+
1008
+ if exists(branch_fn):
1009
+ layer = branch_fn(layer)
1010
+
1011
+ residual_fn = GRUGating if gate_residual else Residual
1012
+ residual = residual_fn(dim, scale_residual=scale_residual)
1013
+
1014
+ layer_uses_qk_norm = use_qk_norm_attn and layer_type in ("a", "c")
1015
+
1016
+ pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
1017
+ post_branch_norm = (
1018
+ norm_fn() if sandwich_norm or layer_uses_qk_norm else None
1019
+ )
1020
+ post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
1021
+
1022
+ norms = nn.ModuleList([pre_branch_norm, post_branch_norm, post_main_norm])
1023
+
1024
+ self.layers.append(nn.ModuleList([norms, layer, residual]))
1025
+
1026
+ def forward(
1027
+ self,
1028
+ x,
1029
+ context=None,
1030
+ full_context=None, # for passing a list of hidden states from an encoder
1031
+ mask=None,
1032
+ context_mask=None,
1033
+ attn_mask=None,
1034
+ mems=None,
1035
+ return_hiddens=False,
1036
+ norm_scale_shift_inp=None,
1037
+ past_key_values=None,
1038
+ expected_seq_len=None,
1039
+ ):
1040
+
1041
+ assert not (
1042
+ self.cross_attend ^ (exists(context) or exists(full_context))
1043
+ ), "context must be passed in if cross_attend is set to True"
1044
+ assert (
1045
+ context is None or full_context is None
1046
+ ), "only one of full_context or context can be provided"
1047
+
1048
+ hiddens = []
1049
+ intermediates = []
1050
+ prev_attn = None
1051
+ prev_cross_attn = None
1052
+
1053
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1054
+ norm_args = {}
1055
+ if exists(norm_scale_shift_inp):
1056
+ norm_args["norm_scale_shift_inp"] = norm_scale_shift_inp
1057
+
1058
+ rotary_pos_emb = None
1059
+ if exists(self.rotary_pos_emb):
1060
+ if not self.training and self.causal:
1061
+ assert (
1062
+ expected_seq_len is not None
1063
+ ), "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
1064
+ elif expected_seq_len is None:
1065
+ expected_seq_len = 0
1066
+ seq_len = x.shape[1]
1067
+ if past_key_values is not None:
1068
+ seq_len += past_key_values[0][0].shape[-2]
1069
+ max_rotary_emb_length = max(
1070
+ list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems))
1071
+ + [expected_seq_len]
1072
+ )
1073
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
1074
+
1075
+ present_key_values = []
1076
+ cross_attn_count = 0
1077
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
1078
+ zip(self.layer_types, self.layers)
1079
+ ):
1080
+ if layer_type == "a":
1081
+ layer_mem = mems.pop(0) if mems else None
1082
+
1083
+ residual = x
1084
+
1085
+ pre_branch_norm, post_branch_norm, post_main_norm = norm
1086
+
1087
+ if exists(pre_branch_norm):
1088
+ x = pre_branch_norm(x, **norm_args)
1089
+
1090
+ if layer_type == "a" or layer_type == "c":
1091
+ if past_key_values is not None:
1092
+ layer_kv = past_key_values.pop(0)
1093
+ layer_past = tuple(s.to(x.device) for s in layer_kv)
1094
+ else:
1095
+ layer_past = None
1096
+
1097
+ if layer_type == "a":
1098
+ out, inter, k, v = block(
1099
+ x,
1100
+ None,
1101
+ mask,
1102
+ None,
1103
+ attn_mask,
1104
+ self.pia_pos_emb,
1105
+ rotary_pos_emb,
1106
+ prev_attn,
1107
+ layer_mem,
1108
+ layer_past,
1109
+ )
1110
+ elif layer_type == "c":
1111
+ if exists(full_context):
1112
+ out, inter, k, v = block(
1113
+ x,
1114
+ full_context[cross_attn_count],
1115
+ mask,
1116
+ context_mask,
1117
+ None,
1118
+ None,
1119
+ None,
1120
+ prev_attn,
1121
+ None,
1122
+ layer_past,
1123
+ )
1124
+ else:
1125
+ out, inter, k, v = block(
1126
+ x,
1127
+ context,
1128
+ mask,
1129
+ context_mask,
1130
+ None,
1131
+ None,
1132
+ None,
1133
+ prev_attn,
1134
+ None,
1135
+ layer_past,
1136
+ )
1137
+ elif layer_type == "f":
1138
+ out = block(x)
1139
+
1140
+ if (
1141
+ layer_type == "a"
1142
+ or layer_type == "c"
1143
+ and present_key_values is not None
1144
+ ):
1145
+ present_key_values.append((k.detach(), v.detach()))
1146
+
1147
+ if exists(post_branch_norm):
1148
+ out = post_branch_norm(out, **norm_args)
1149
+
1150
+ x = residual_fn(out, residual)
1151
+
1152
+ if layer_type in ("a", "c"):
1153
+ intermediates.append(inter)
1154
+
1155
+ if layer_type == "a" and self.residual_attn:
1156
+ prev_attn = inter.pre_softmax_attn
1157
+ elif layer_type == "c" and self.cross_residual_attn:
1158
+ prev_cross_attn = inter.pre_softmax_attn
1159
+
1160
+ if exists(post_main_norm):
1161
+ x = post_main_norm(x, **norm_args)
1162
+
1163
+ if layer_type == "c":
1164
+ cross_attn_count += 1
1165
+
1166
+ if layer_type == "f":
1167
+ hiddens.append(x)
1168
+
1169
+ if return_hiddens:
1170
+ intermediates = LayerIntermediates(
1171
+ hiddens=hiddens,
1172
+ attn_intermediates=intermediates,
1173
+ past_key_values=present_key_values,
1174
+ )
1175
+
1176
+ return x, intermediates
1177
+
1178
+ return x
1179
+
1180
+
1181
+ class Encoder(AttentionLayers):
1182
+ def __init__(self, **kwargs):
1183
+ assert "causal" not in kwargs, "cannot set causality on encoder"
1184
+ super().__init__(causal=False, **kwargs)
1185
+
1186
+
1187
+ class Decoder(AttentionLayers):
1188
+ def __init__(self, **kwargs):
1189
+ assert "causal" not in kwargs, "cannot set causality on decoder"
1190
+ super().__init__(causal=True, **kwargs)
1191
+
1192
+
1193
+ class CrossAttender(AttentionLayers):
1194
+ def __init__(self, **kwargs):
1195
+ super().__init__(cross_attend=True, only_cross=True, **kwargs)
1196
+
1197
+
1198
+ class ViTransformerWrapper(nn.Module):
1199
+ def __init__(
1200
+ self,
1201
+ *,
1202
+ image_size,
1203
+ patch_size,
1204
+ attn_layers,
1205
+ num_classes=None,
1206
+ dropout=0.0,
1207
+ emb_dropout=0.0,
1208
+ ):
1209
+ super().__init__()
1210
+ assert isinstance(attn_layers, Encoder), "attention layers must be an Encoder"
1211
+ assert (
1212
+ image_size % patch_size == 0
1213
+ ), "image dimensions must be divisible by the patch size"
1214
+ dim = attn_layers.dim
1215
+ num_patches = (image_size // patch_size) ** 2
1216
+ patch_dim = 3 * patch_size**2
1217
+
1218
+ self.patch_size = patch_size
1219
+
1220
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
1221
+ self.patch_to_embedding = nn.Linear(patch_dim, dim)
1222
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
1223
+ self.dropout = nn.Dropout(emb_dropout)
1224
+
1225
+ self.attn_layers = attn_layers
1226
+ self.norm = nn.LayerNorm(dim)
1227
+ self.mlp_head = (
1228
+ FeedForward(dim, dim_out=num_classes, dropout=dropout)
1229
+ if exists(num_classes)
1230
+ else None
1231
+ )
1232
+
1233
+ def forward(self, img, return_embeddings=False):
1234
+ p = self.patch_size
1235
+
1236
+ x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)
1237
+ x = self.patch_to_embedding(x)
1238
+ b, n, _ = x.shape
1239
+
1240
+ cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
1241
+ x = torch.cat((cls_tokens, x), dim=1)
1242
+ x = x + self.pos_embedding[:, : (n + 1)]
1243
+ x = self.dropout(x)
1244
+
1245
+ x = self.attn_layers(x)
1246
+ x = self.norm(x)
1247
+
1248
+ if not exists(self.mlp_head) or return_embeddings:
1249
+ return x
1250
+
1251
+ return self.mlp_head(x[:, 0])
1252
+
1253
+
1254
+ class TransformerWrapper(nn.Module):
1255
+ def __init__(
1256
+ self,
1257
+ *,
1258
+ num_tokens,
1259
+ max_seq_len,
1260
+ attn_layers,
1261
+ emb_dim=None,
1262
+ max_mem_len=0.0,
1263
+ shift_mem_down=0,
1264
+ emb_dropout=0.0,
1265
+ num_memory_tokens=None,
1266
+ tie_embedding=False,
1267
+ use_pos_emb=True,
1268
+ ):
1269
+ super().__init__()
1270
+ assert isinstance(
1271
+ attn_layers, AttentionLayers
1272
+ ), "attention layers must be one of Encoder or Decoder"
1273
+
1274
+ dim = attn_layers.dim
1275
+ emb_dim = default(emb_dim, dim)
1276
+
1277
+ self.max_seq_len = max_seq_len
1278
+ self.max_mem_len = max_mem_len
1279
+ self.shift_mem_down = shift_mem_down
1280
+
1281
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
1282
+ self.pos_emb = (
1283
+ AbsolutePositionalEmbedding(emb_dim, max_seq_len)
1284
+ if (use_pos_emb and not attn_layers.has_pos_emb)
1285
+ else always(0)
1286
+ )
1287
+ self.emb_dropout = nn.Dropout(emb_dropout)
1288
+
1289
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
1290
+ self.attn_layers = attn_layers
1291
+ self.norm = nn.LayerNorm(dim)
1292
+
1293
+ self.init_()
1294
+
1295
+ self.to_logits = (
1296
+ nn.Linear(dim, num_tokens)
1297
+ if not tie_embedding
1298
+ else lambda t: t @ self.token_emb.weight.t()
1299
+ )
1300
+
1301
+ # memory tokens (like [cls]) from Memory Transformers paper
1302
+ num_memory_tokens = default(num_memory_tokens, 0)
1303
+ self.num_memory_tokens = num_memory_tokens
1304
+ if num_memory_tokens > 0:
1305
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
1306
+
1307
+ def init_(self):
1308
+ nn.init.kaiming_normal_(self.token_emb.weight)
1309
+
1310
+ def forward(
1311
+ self,
1312
+ x,
1313
+ return_embeddings=False,
1314
+ mask=None,
1315
+ return_hiddens=False,
1316
+ return_attn=False,
1317
+ mems=None,
1318
+ use_cache=False,
1319
+ **kwargs,
1320
+ ):
1321
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
1322
+ x = self.token_emb(x)
1323
+ x = x + self.pos_emb(x)
1324
+ x = self.emb_dropout(x)
1325
+
1326
+ x = self.project_emb(x)
1327
+
1328
+ if num_mem > 0:
1329
+ mem = repeat(self.memory_tokens, "n d -> b n d", b=b)
1330
+ x = torch.cat((mem, x), dim=1)
1331
+
1332
+ # auto-handle masking after appending memory tokens
1333
+ if exists(mask):
1334
+ mask = F.pad(mask, (num_mem, 0), value=True)
1335
+
1336
+ if self.shift_mem_down and exists(mems):
1337
+ mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :]
1338
+ mems = [*mems_r, *mems_l]
1339
+
1340
+ x, intermediates = self.attn_layers(
1341
+ x, mask=mask, mems=mems, return_hiddens=True, **kwargs
1342
+ )
1343
+ x = self.norm(x)
1344
+
1345
+ mem, x = x[:, :num_mem], x[:, num_mem:]
1346
+
1347
+ out = self.to_logits(x) if not return_embeddings else x
1348
+
1349
+ if return_hiddens:
1350
+ hiddens = intermediates.hiddens
1351
+ return out, hiddens
1352
+
1353
+ res = [out]
1354
+ if return_attn:
1355
+ attn_maps = list(
1356
+ map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
1357
+ )
1358
+ res.append(attn_maps)
1359
+ if use_cache:
1360
+ res.append(intermediates.past_key_values)
1361
+
1362
+ if len(res) > 1:
1363
+ return tuple(res)
1364
+ return res[0]
1365
+
1366
+
1367
+ class ContinuousTransformerWrapper(nn.Module):
1368
+ def __init__(
1369
+ self,
1370
+ *,
1371
+ max_seq_len,
1372
+ attn_layers,
1373
+ dim_in=None,
1374
+ dim_out=None,
1375
+ emb_dim=None,
1376
+ emb_dropout=0.0,
1377
+ use_pos_emb=True,
1378
+ ):
1379
+ super().__init__()
1380
+ assert isinstance(
1381
+ attn_layers, AttentionLayers
1382
+ ), "attention layers must be one of Encoder or Decoder"
1383
+
1384
+ dim = attn_layers.dim
1385
+
1386
+ self.max_seq_len = max_seq_len
1387
+
1388
+ self.pos_emb = (
1389
+ AbsolutePositionalEmbedding(dim, max_seq_len)
1390
+ if (use_pos_emb and not attn_layers.has_pos_emb)
1391
+ else always(0)
1392
+ )
1393
+ self.emb_dropout = nn.Dropout(emb_dropout)
1394
+
1395
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
1396
+
1397
+ self.attn_layers = attn_layers
1398
+ self.norm = nn.LayerNorm(dim)
1399
+
1400
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
1401
+
1402
+ def forward(
1403
+ self,
1404
+ x,
1405
+ return_embeddings=False,
1406
+ mask=None,
1407
+ return_attn=False,
1408
+ mems=None,
1409
+ use_cache=False,
1410
+ **kwargs,
1411
+ ):
1412
+ b, n, _, device = *x.shape, x.device
1413
+
1414
+ x = self.project_in(x)
1415
+ x = x + self.pos_emb(x)
1416
+ x = self.emb_dropout(x)
1417
+
1418
+ x, intermediates = self.attn_layers(
1419
+ x, mask=mask, mems=mems, return_hiddens=True, **kwargs
1420
+ )
1421
+ x = self.norm(x)
1422
+
1423
+ out = self.project_out(x) if not return_embeddings else x
1424
+
1425
+ res = [out]
1426
+ if return_attn:
1427
+ attn_maps = list(
1428
+ map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
1429
+ )
1430
+ res.append(attn_maps)
1431
+ if use_cache:
1432
+ res.append(intermediates.past_key_values)
1433
+
1434
+ if len(res) > 1:
1435
+ return tuple(res)
1436
+ return res[0]
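A small, hedged sketch of how the wrappers above compose (not part of the committed file); the import path follows the tortoise/models/xtransformers.py header of this diff, and every size below is illustrative rather than a value used by TorToiSe.

import torch

from tortoise.models.xtransformers import Decoder, TransformerWrapper

# Decoder-only stack built from the classes defined above.
model = TransformerWrapper(
    num_tokens=256,                                  # illustrative vocabulary size
    max_seq_len=128,
    attn_layers=Decoder(dim=64, depth=2, heads=4),
)

tokens = torch.randint(0, 256, (1, 32))  # (batch, seq_len)
logits = model(tokens)                   # (batch, seq_len, num_tokens)
print(logits.shape)                      # torch.Size([1, 32, 256])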
tortoise/utils/__init__.py ADDED
File without changes
tortoise/utils/audio.py ADDED
@@ -0,0 +1,238 @@
1
+ import os
2
+ from glob import glob
3
+ from typing import Dict, List
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import torch
8
+ import torchaudio
9
+ from scipy.io.wavfile import read
10
+
11
+ from tortoise.utils.stft import STFT
12
+
13
+ BUILTIN_VOICES_DIR = os.path.join(
14
+ os.path.dirname(os.path.realpath(__file__)), "../voices"
15
+ )
16
+
17
+
18
+ def load_wav_to_torch(full_path):
19
+ sampling_rate, data = read(full_path)
20
+ if data.dtype == np.int32:
21
+ norm_fix = 2**31
22
+ elif data.dtype == np.int16:
23
+ norm_fix = 2**15
24
+ elif data.dtype == np.float16 or data.dtype == np.float32:
25
+ norm_fix = 1.0
26
+ else:
27
+ raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}")
28
+ return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
29
+
30
+
31
+ def check_audio(audio, audiopath: str):
32
+ # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
33
+ # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
34
+ if torch.any(audio > 2) or not torch.any(audio < 0):
35
+ print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
36
+ audio.clip_(-1, 1)
37
+
38
+
39
+ def read_audio_file(audiopath: str):
40
+ if audiopath[-4:] == ".wav":
41
+ audio, lsr = load_wav_to_torch(audiopath)
42
+ elif audiopath[-4:] == ".mp3":
43
+ audio, lsr = librosa.load(audiopath, sr=None)
44
+ audio = torch.FloatTensor(audio)
45
+ else:
46
+ assert False, f"Unsupported audio format provided: {audiopath[-4:]}"
47
+
48
+ # Remove any channel data.
49
+ if len(audio.shape) > 1:
50
+ if audio.shape[0] < 5:
51
+ audio = audio[0]
52
+ else:
53
+ assert audio.shape[1] < 5
54
+ audio = audio[:, 0]
55
+
56
+ return audio, lsr
57
+
58
+
59
+ def load_required_audio(audiopath: str):
60
+ audio, lsr = read_audio_file(audiopath)
61
+
62
+ audios = [
63
+ torchaudio.functional.resample(audio, lsr, sampling_rate)
64
+ for sampling_rate in (22050, 24000)
65
+ ]
66
+ for audio in audios:
67
+ check_audio(audio, audiopath)
68
+
69
+ return [audio.unsqueeze(0) for audio in audios]
70
+
71
+
72
+ def load_audio(audiopath, sampling_rate):
73
+ audio, lsr = read_audio_file(audiopath)
74
+
75
+ if lsr != sampling_rate:
76
+ audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
77
+ check_audio(audio, audiopath)
78
+
79
+ return audio.unsqueeze(0)
80
+
81
+
82
+ TACOTRON_MEL_MAX = 2.3143386840820312
83
+ TACOTRON_MEL_MIN = -11.512925148010254
84
+
85
+
86
+ def denormalize_tacotron_mel(norm_mel):
87
+ return ((norm_mel + 1) / 2) * (
88
+ TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
89
+ ) + TACOTRON_MEL_MIN
90
+
91
+
92
+ def normalize_tacotron_mel(mel):
93
+ return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
94
+
95
+
96
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
97
+ """
98
+ PARAMS
99
+ ------
100
+ C: compression factor
101
+ """
102
+ return torch.log(torch.clamp(x, min=clip_val) * C)
103
+
104
+
105
+ def dynamic_range_decompression(x, C=1):
106
+ """
107
+ PARAMS
108
+ ------
109
+ C: compression factor used to compress
110
+ """
111
+ return torch.exp(x) / C
112
+
113
+
114
+ def get_voices(extra_voice_dirs: List[str] = []):
115
+ dirs = [BUILTIN_VOICES_DIR] + extra_voice_dirs
116
+ voices: Dict[str, List[str]] = {}
117
+ for d in dirs:
118
+ subs = os.listdir(d)
119
+ for sub in subs:
120
+ subj = os.path.join(d, sub)
121
+ if os.path.isdir(subj):
122
+ voices[sub] = (
123
+ list(glob(f"{subj}/*.wav"))
124
+ + list(glob(f"{subj}/*.mp3"))
125
+ + list(glob(f"{subj}/*.pth"))
126
+ )
127
+ return voices
128
+
129
+
130
+ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
131
+ if voice == "random":
132
+ return None, None
133
+
134
+ voices = get_voices(extra_voice_dirs)
135
+ paths = voices[voice]
136
+ if len(paths) == 1 and paths[0].endswith(".pth"):
137
+ return None, torch.load(paths[0])
138
+ else:
139
+ conds = []
140
+ for cond_path in paths:
141
+ c = load_required_audio(cond_path)
142
+ conds.append(c)
143
+ return conds, None
144
+
145
+
146
+ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
147
+ latents = []
148
+ clips = []
149
+ for voice in voices:
150
+ if voice == "random":
151
+ if len(voices) > 1:
152
+ print(
153
+ "Cannot combine a random voice with a non-random voice. Just using a random voice."
154
+ )
155
+ return None, None
156
+ clip, latent = load_voice(voice, extra_voice_dirs)
157
+ if latent is None:
158
+ assert (
159
+ len(latents) == 0
160
+ ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
161
+ clips.extend(clip)
162
+ elif clip is None:
163
+ assert (
164
+ len(clips) == 0
165
+ ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
166
+ latents.append(latent)
167
+ if len(latents) == 0:
168
+ return clips, None
169
+ else:
170
+ latents_0 = torch.stack([l[0] for l in latents], dim=0).mean(dim=0)
171
+ latents_1 = torch.stack([l[1] for l in latents], dim=0).mean(dim=0)
172
+ latents = (latents_0, latents_1)
173
+ return None, latents
174
+
175
+
176
+ class TacotronSTFT(torch.nn.Module):
177
+ def __init__(
178
+ self,
179
+ filter_length=1024,
180
+ hop_length=256,
181
+ win_length=1024,
182
+ n_mel_channels=80,
183
+ sampling_rate=22050,
184
+ mel_fmin=0.0,
185
+ mel_fmax=8000.0,
186
+ ):
187
+ super(TacotronSTFT, self).__init__()
188
+ self.n_mel_channels = n_mel_channels
189
+ self.sampling_rate = sampling_rate
190
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
191
+ from librosa.filters import mel as librosa_mel_fn
192
+
193
+ mel_basis = librosa_mel_fn(
194
+ sr=sampling_rate,
195
+ n_fft=filter_length,
196
+ n_mels=n_mel_channels,
197
+ fmin=mel_fmin,
198
+ fmax=mel_fmax,
199
+ )
200
+ mel_basis = torch.from_numpy(mel_basis).float()
201
+ self.register_buffer("mel_basis", mel_basis)
202
+
203
+ def spectral_normalize(self, magnitudes):
204
+ output = dynamic_range_compression(magnitudes)
205
+ return output
206
+
207
+ def spectral_de_normalize(self, magnitudes):
208
+ output = dynamic_range_decompression(magnitudes)
209
+ return output
210
+
211
+ def mel_spectrogram(self, y):
212
+ """Computes mel-spectrograms from a batch of waves
213
+ PARAMS
214
+ ------
215
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
216
+
217
+ RETURNS
218
+ -------
219
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
220
+ """
221
+ assert torch.min(y.data) >= -10
222
+ assert torch.max(y.data) <= 10
223
+ y = torch.clip(y, min=-1, max=1)
224
+
225
+ magnitudes, phases = self.stft_fn.transform(y)
226
+ magnitudes = magnitudes.data
227
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
228
+ mel_output = self.spectral_normalize(mel_output)
229
+ return mel_output
230
+
231
+
232
+ def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
233
+ stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
234
+ stft = stft.to(device)
235
+ mel = stft.mel_spectrogram(wav)
236
+ if do_normalization:
237
+ mel = normalize_tacotron_mel(mel)
238
+ return mel
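A minimal usage sketch for the helpers above (illustrative only, assuming this module is importable as tortoise.utils.audio and that a built-in voice folder named "tom" exists under tortoise/voices): load a voice's reference clips and turn the 24 kHz copy of one clip into a normalized 100-bin mel, the way wav_to_univnet_mel is meant to be called.

from tortoise.utils.audio import load_voice, wav_to_univnet_mel

# load_voice returns (clips, None) for raw-audio voices, or (None, latents) when the
# voice directory holds a single precomputed .pth latent file.
clips, latents = load_voice("tom")  # "tom" is an assumed voice folder name
if clips is not None:
    # load_required_audio resamples every clip to both 22.05 kHz and 24 kHz;
    # each entry is [wav_22k, wav_24k] with shape (1, T).
    wav_24k = clips[0][1]
    # wav_to_univnet_mel builds a TacotronSTFT with 100 mel bins at 24 kHz and,
    # with do_normalization=True, rescales the result via normalize_tacotron_mel.
    mel = wav_to_univnet_mel(wav_24k, do_normalization=True, device="cpu")
    print(mel.shape)  # (1, 100, number_of_frames)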
tortoise/utils/diffusion.py ADDED
@@ -0,0 +1,1469 @@
1
+ """
2
+ This is an almost carbon copy of gaussian_diffusion.py from OpenAI's ImprovedDiffusion repo, which itself notes:
3
+
4
+ This code started out as a PyTorch port of Ho et al's diffusion models:
5
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
6
+
7
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
8
+ """
9
+ # AGPL: a notification must be added stating that changes have been made to that file.
10
+
11
+ import enum
12
+ import math
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch as th
17
+ from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
18
+ from tqdm import tqdm
19
+
20
+ from tortoise.dpm_solver_pytorch import DPM_Solver, NoiseScheduleVP, model_wrapper
21
+
22
+ K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}
23
+ SAMPLERS = ["dpm++2m", "p", "ddim"]
24
+
25
+
26
+ def normal_kl(mean1, logvar1, mean2, logvar2):
27
+ """
28
+ Compute the KL divergence between two gaussians.
29
+
30
+ Shapes are automatically broadcasted, so batches can be compared to
31
+ scalars, among other use cases.
32
+ """
33
+ tensor = None
34
+ for obj in (mean1, logvar1, mean2, logvar2):
35
+ if isinstance(obj, th.Tensor):
36
+ tensor = obj
37
+ break
38
+ assert tensor is not None, "at least one argument must be a Tensor"
39
+
40
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
41
+ # Tensors, but it does not work for th.exp().
42
+ logvar1, logvar2 = [
43
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
44
+ for x in (logvar1, logvar2)
45
+ ]
46
+
47
+ return 0.5 * (
48
+ -1.0
49
+ + logvar2
50
+ - logvar1
51
+ + th.exp(logvar1 - logvar2)
52
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
53
+ )
54
+
55
+
56
+ def approx_standard_normal_cdf(x):
57
+ """
58
+ A fast approximation of the cumulative distribution function of the
59
+ standard normal.
60
+ """
61
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
62
+
63
+
64
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
65
+ """
66
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
67
+ given image.
68
+
69
+ :param x: the target images. It is assumed that these were uint8 values,
70
+ rescaled to the range [-1, 1].
71
+ :param means: the Gaussian mean Tensor.
72
+ :param log_scales: the Gaussian log stddev Tensor.
73
+ :return: a tensor like x of log probabilities (in nats).
74
+ """
75
+ assert x.shape == means.shape == log_scales.shape
76
+ centered_x = x - means
77
+ inv_stdv = th.exp(-log_scales)
78
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
79
+ cdf_plus = approx_standard_normal_cdf(plus_in)
80
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
81
+ cdf_min = approx_standard_normal_cdf(min_in)
82
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
83
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
84
+ cdf_delta = cdf_plus - cdf_min
85
+ log_probs = th.where(
86
+ x < -0.999,
87
+ log_cdf_plus,
88
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
89
+ )
90
+ assert log_probs.shape == x.shape
91
+ return log_probs
92
+
93
+
94
+ def mean_flat(tensor):
95
+ """
96
+ Take the mean over all non-batch dimensions.
97
+ """
98
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
99
+
100
+
101
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
102
+ """
103
+ Get a pre-defined beta schedule for the given name.
104
+
105
+ The beta schedule library consists of beta schedules which remain similar
106
+ in the limit of num_diffusion_timesteps.
107
+ Beta schedules may be added, but should not be removed or changed once
108
+ they are committed to maintain backwards compatibility.
109
+ """
110
+ if schedule_name == "linear":
111
+ # Linear schedule from Ho et al, extended to work for any number of
112
+ # diffusion steps.
113
+ scale = 1000 / num_diffusion_timesteps
114
+ beta_start = scale * 0.0001
115
+ beta_end = scale * 0.02
116
+ return np.linspace(
117
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
118
+ )
119
+ elif schedule_name == "cosine":
120
+ return betas_for_alpha_bar(
121
+ num_diffusion_timesteps,
122
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
123
+ )
124
+ else:
125
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
126
+
127
+
128
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
129
+ """
130
+ Create a beta schedule that discretizes the given alpha_t_bar function,
131
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
132
+
133
+ :param num_diffusion_timesteps: the number of betas to produce.
134
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
135
+ produces the cumulative product of (1-beta) up to that
136
+ part of the diffusion process.
137
+ :param max_beta: the maximum beta to use; use values lower than 1 to
138
+ prevent singularities.
139
+ """
140
+ betas = []
141
+ for i in range(num_diffusion_timesteps):
142
+ t1 = i / num_diffusion_timesteps
143
+ t2 = (i + 1) / num_diffusion_timesteps
144
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
145
+ return np.array(betas)
146
+
147
+
148
+ class ModelMeanType(enum.Enum):
149
+ """
150
+ Which type of output the model predicts.
151
+ """
152
+
153
+ PREVIOUS_X = "previous_x" # the model predicts x_{t-1}
154
+ START_X = "start_x" # the model predicts x_0
155
+ EPSILON = "epsilon" # the model predicts epsilon
156
+
157
+
158
+ class ModelVarType(enum.Enum):
159
+ """
160
+ What is used as the model's output variance.
161
+
162
+ The LEARNED_RANGE option has been added to allow the model to predict
163
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
164
+ """
165
+
166
+ LEARNED = "learned"
167
+ FIXED_SMALL = "fixed_small"
168
+ FIXED_LARGE = "fixed_large"
169
+ LEARNED_RANGE = "learned_range"
170
+
171
+
172
+ class LossType(enum.Enum):
173
+ MSE = "mse" # use raw MSE loss (and KL when learning variances)
174
+ RESCALED_MSE = (
175
+ "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances)
176
+ )
177
+ KL = "kl" # use the variational lower-bound
178
+ RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB
179
+
180
+ def is_vb(self):
181
+ return self == LossType.KL or self == LossType.RESCALED_KL
182
+
183
+
184
+ class GaussianDiffusion:
185
+ """
186
+ Utilities for training and sampling diffusion models.
187
+
188
+ Ported directly from here, and then adapted over time to further experimentation.
189
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
190
+
191
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
192
+ starting at T and going to 1.
193
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
194
+ :param model_var_type: a ModelVarType determining how variance is output.
195
+ :param loss_type: a LossType determining the loss function to use.
196
+ :param rescale_timesteps: if True, pass floating point timesteps into the
197
+ model so that they are always scaled like in the
198
+ original paper (0 to 1000).
199
+ """
200
+
201
+ def __init__(
202
+ self,
203
+ *,
204
+ betas,
205
+ model_mean_type,
206
+ model_var_type,
207
+ loss_type,
208
+ rescale_timesteps=False, # this is generally False
209
+ conditioning_free=False,
210
+ conditioning_free_k=1,
211
+ ramp_conditioning_free=True,
212
+ sampler="ddim",
213
+ ):
214
+ self.sampler = sampler
215
+ self.model_mean_type = ModelMeanType(model_mean_type)
216
+ self.model_var_type = ModelVarType(model_var_type)
217
+ self.loss_type = LossType(loss_type)
218
+ self.rescale_timesteps = rescale_timesteps
219
+ self.conditioning_free = conditioning_free
220
+ self.conditioning_free_k = conditioning_free_k
221
+ self.ramp_conditioning_free = ramp_conditioning_free
222
+
223
+ # Use float64 for accuracy.
224
+ betas = np.array(betas, dtype=np.float64)
225
+ self.betas = betas
226
+ assert len(betas.shape) == 1, "betas must be 1-D"
227
+ assert (betas > 0).all() and (betas <= 1).all()
228
+
229
+ self.num_timesteps = int(betas.shape[0])
230
+
231
+ alphas = 1.0 - betas
232
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
233
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
234
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
235
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
236
+
237
+ # calculations for diffusion q(x_t | x_{t-1}) and others
238
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
239
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
240
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
241
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
242
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
243
+
244
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
245
+ self.posterior_variance = (
246
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
247
+ )
248
+ # log calculation clipped because the posterior variance is 0 at the
249
+ # beginning of the diffusion chain.
250
+ self.posterior_log_variance_clipped = np.log(
251
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
252
+ )
253
+ self.posterior_mean_coef1 = (
254
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
255
+ )
256
+ self.posterior_mean_coef2 = (
257
+ (1.0 - self.alphas_cumprod_prev)
258
+ * np.sqrt(alphas)
259
+ / (1.0 - self.alphas_cumprod)
260
+ )
261
+
262
+ def q_mean_variance(self, x_start, t):
263
+ """
264
+ Get the distribution q(x_t | x_0).
265
+
266
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
267
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
268
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
269
+ """
270
+ mean = (
271
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
272
+ )
273
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
274
+ log_variance = _extract_into_tensor(
275
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
276
+ )
277
+ return mean, variance, log_variance
278
+
279
+ def q_sample(self, x_start, t, noise=None):
280
+ """
281
+ Diffuse the data for a given number of diffusion steps.
282
+
283
+ In other words, sample from q(x_t | x_0).
284
+
285
+ :param x_start: the initial data batch.
286
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
287
+ :param noise: if specified, the split-out normal noise.
288
+ :return: A noisy version of x_start.
289
+ """
290
+ if noise is None:
291
+ noise = th.randn_like(x_start)
292
+ assert noise.shape == x_start.shape
293
+ return (
294
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
295
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
296
+ * noise
297
+ )
298
+
299
+ def q_posterior_mean_variance(self, x_start, x_t, t):
300
+ """
301
+ Compute the mean and variance of the diffusion posterior:
302
+
303
+ q(x_{t-1} | x_t, x_0)
304
+
305
+ """
306
+ assert x_start.shape == x_t.shape
307
+ posterior_mean = (
308
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
309
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
310
+ )
311
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
312
+ posterior_log_variance_clipped = _extract_into_tensor(
313
+ self.posterior_log_variance_clipped, t, x_t.shape
314
+ )
315
+ assert (
316
+ posterior_mean.shape[0]
317
+ == posterior_variance.shape[0]
318
+ == posterior_log_variance_clipped.shape[0]
319
+ == x_start.shape[0]
320
+ )
321
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
322
+
323
+ def p_mean_variance(
324
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
325
+ ):
326
+ """
327
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
328
+ the initial x, x_0.
329
+
330
+ :param model: the model, which takes a signal and a batch of timesteps
331
+ as input.
332
+ :param x: the [N x C x ...] tensor at time t.
333
+ :param t: a 1-D Tensor of timesteps.
334
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
335
+ :param denoised_fn: if not None, a function which applies to the
336
+ x_start prediction before it is used to sample. Applies before
337
+ clip_denoised.
338
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
339
+ pass to the model. This can be used for conditioning.
340
+ :return: a dict with the following keys:
341
+ - 'mean': the model mean output.
342
+ - 'variance': the model variance output.
343
+ - 'log_variance': the log of 'variance'.
344
+ - 'pred_xstart': the prediction for x_0.
345
+ """
346
+ if model_kwargs is None:
347
+ model_kwargs = {}
348
+
349
+ assert self.model_var_type == ModelVarType.LEARNED_RANGE
350
+ assert self.model_mean_type == ModelMeanType.EPSILON
351
+ assert denoised_fn is None
352
+ assert clip_denoised is True
353
+ B, C = x.shape[:2]
354
+ assert t.shape == (B,)
355
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
356
+ if self.conditioning_free:
357
+ model_output_no_conditioning = model(
358
+ x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs
359
+ )
360
+
361
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
362
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
363
+ model_output, model_var_values = th.split(model_output, C, dim=1)
364
+ if self.conditioning_free:
365
+ model_output_no_conditioning, _ = th.split(
366
+ model_output_no_conditioning, C, dim=1
367
+ )
368
+ if self.model_var_type == ModelVarType.LEARNED:
369
+ assert False
370
+ model_log_variance = model_var_values
371
+ model_variance = th.exp(model_log_variance)
372
+ else:
373
+ min_log = _extract_into_tensor(
374
+ self.posterior_log_variance_clipped, t, x.shape
375
+ )
376
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
377
+ # The model_var_values is [-1, 1] for [min_var, max_var].
378
+ frac = (model_var_values + 1) / 2
379
+ model_log_variance = frac * max_log + (1 - frac) * min_log
380
+ model_variance = th.exp(model_log_variance)
381
+ else:
382
+ assert False
383
+ model_variance, model_log_variance = {
384
+ # for fixedlarge, we set the initial (log-)variance like so
385
+ # to get a better decoder log likelihood.
386
+ ModelVarType.FIXED_LARGE: (
387
+ np.append(self.posterior_variance[1], self.betas[1:]),
388
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
389
+ ),
390
+ ModelVarType.FIXED_SMALL: (
391
+ self.posterior_variance,
392
+ self.posterior_log_variance_clipped,
393
+ ),
394
+ }[self.model_var_type]
395
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
396
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
397
+
398
+ if self.conditioning_free:
399
+ if self.ramp_conditioning_free:
400
+ assert t.shape[0] == 1 # This should only be used in inference.
401
+ cfk = self.conditioning_free_k * (
402
+ 1 - self._scale_timesteps(t)[0].item() / self.num_timesteps
403
+ )
404
+ else:
405
+ cfk = self.conditioning_free_k
406
+ model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
407
+
408
+ def process_xstart(x):
409
+ if denoised_fn is not None:
410
+ assert False
411
+ x = denoised_fn(x)
412
+ if clip_denoised:
413
+ return x.clamp(-1, 1)
414
+ assert False
415
+ return x
416
+
417
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
418
+ assert False
419
+ pred_xstart = process_xstart(
420
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
421
+ )
422
+ model_mean = model_output
423
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
424
+ if self.model_mean_type == ModelMeanType.START_X:
425
+ assert False
426
+ pred_xstart = process_xstart(model_output)
427
+ else:
428
+ pred_xstart = process_xstart(
429
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
430
+ )
431
+ model_mean, _, _ = self.q_posterior_mean_variance(
432
+ x_start=pred_xstart, x_t=x, t=t
433
+ )
434
+ else:
435
+ raise NotImplementedError(self.model_mean_type)
436
+
437
+ assert (
438
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
439
+ )
440
+ return {
441
+ "mean": model_mean,
442
+ "variance": model_variance,
443
+ "log_variance": model_log_variance,
444
+ "pred_xstart": pred_xstart,
445
+ }
446
+
447
+ def _predict_xstart_from_eps(self, x_t, t, eps):
448
+ assert x_t.shape == eps.shape
449
+ return (
450
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
451
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
452
+ )
453
+
454
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
455
+ assert x_t.shape == xprev.shape
456
+ return ( # (xprev - coef2*x_t) / coef1
457
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
458
+ - _extract_into_tensor(
459
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
460
+ )
461
+ * x_t
462
+ )
463
+
464
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
465
+ return (
466
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
467
+ - pred_xstart
468
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
469
+
470
+ def _scale_timesteps(self, t):
471
+ if self.rescale_timesteps:
472
+ return t.float() * (1000.0 / self.num_timesteps)
473
+ return t
474
+
475
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
476
+ """
477
+ Compute the mean for the previous step, given a function cond_fn that
478
+ computes the gradient of a conditional log probability with respect to
479
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
480
+ condition on y.
481
+
482
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
483
+ """
484
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
485
+ new_mean = (
486
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
487
+ )
488
+ return new_mean
489
+
490
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
491
+ """
492
+ Compute what the p_mean_variance output would have been, should the
493
+ model's score function be conditioned by cond_fn.
494
+
495
+ See condition_mean() for details on cond_fn.
496
+
497
+ Unlike condition_mean(), this instead uses the conditioning strategy
498
+ from Song et al (2020).
499
+ """
500
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
501
+
502
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
503
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
504
+ x, self._scale_timesteps(t), **model_kwargs
505
+ )
506
+
507
+ out = p_mean_var.copy()
508
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
509
+ out["mean"], _, _ = self.q_posterior_mean_variance(
510
+ x_start=out["pred_xstart"], x_t=x, t=t
511
+ )
512
+ return out
513
+
514
+ def p_sample(
515
+ self,
516
+ model,
517
+ x,
518
+ t,
519
+ clip_denoised=True,
520
+ denoised_fn=None,
521
+ cond_fn=None,
522
+ model_kwargs=None,
523
+ ):
524
+ """
525
+ Sample x_{t-1} from the model at the given timestep.
526
+
527
+ :param model: the model to sample from.
528
+ :param x: the current tensor at x_{t-1}.
529
+ :param t: the value of t, starting at 0 for the first diffusion step.
530
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
531
+ :param denoised_fn: if not None, a function which applies to the
532
+ x_start prediction before it is used to sample.
533
+ :param cond_fn: if not None, this is a gradient function that acts
534
+ similarly to the model.
535
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
536
+ pass to the model. This can be used for conditioning.
537
+ :return: a dict containing the following keys:
538
+ - 'sample': a random sample from the model.
539
+ - 'pred_xstart': a prediction of x_0.
540
+ """
541
+ out = self.p_mean_variance(
542
+ model,
543
+ x,
544
+ t,
545
+ clip_denoised=clip_denoised,
546
+ denoised_fn=denoised_fn,
547
+ model_kwargs=model_kwargs,
548
+ )
549
+ noise = th.randn_like(x)
550
+ nonzero_mask = (
551
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
552
+ ) # no noise when t == 0
553
+ if cond_fn is not None:
554
+ out["mean"] = self.condition_mean(
555
+ cond_fn, out, x, t, model_kwargs=model_kwargs
556
+ )
557
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
558
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
559
+
560
+ def k_diffusion_sample_loop(
561
+ self,
562
+ k_sampler,
563
+ pbar,
564
+ model,
565
+ shape,
566
+ noise=None, # all given
567
+ clip_denoised=True,
568
+ denoised_fn=None,
569
+ cond_fn=None,
570
+ device=None, # ALL UNUSED
571
+ model_kwargs=None, # {'precomputed_aligned_embeddings': precomputed_embeddings},
572
+ progress=False, # unused as well
573
+ ):
574
+ assert isinstance(model_kwargs, dict)
575
+ if device is None:
576
+ device = next(model.parameters()).device
577
+ s_in = noise.new_ones([noise.shape[0]])
578
+
579
+ def model_split(*args, **kwargs):
580
+ model_output = model(*args, **kwargs)
581
+ model_epsilon, model_var = th.split(
582
+ model_output, model_output.shape[1] // 2, dim=1
583
+ )
584
+ return model_epsilon, model_var
585
+
586
+ #
587
+ """
588
+ print(self.betas)
589
+ print(th.tensor(self.betas))
590
+ noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas))
591
+ """
592
+ noise_schedule = NoiseScheduleVP(
593
+ schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4
594
+ )
595
+
596
+ def model_fn_prewrap(x, t, *args, **kwargs):
597
+ """
598
+ x_in = torch.cat([x] * 2)
599
+ t_in = torch.cat([t_continuous] * 2)
600
+ c_in = torch.cat([unconditional_condition, condition])
601
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
602
+ print(t)
603
+ print(self.timestep_map)
604
+ exit()
605
+ """
606
+ """
607
+ model_output = model(x, self._scale_timesteps(t*4000), **model_kwargs)
608
+ out = self.p_mean_variance(model, x, t*4000, model_kwargs=model_kwargs)
609
+ return out['pred_xstart']
610
+ """
611
+ x, _ = x.chunk(2)
612
+ t, _ = (t * 1000).chunk(2)
613
+ res = torch.cat(
614
+ [
615
+ model_split(x, t, conditioning_free=True, **model_kwargs)[0],
616
+ model_split(x, t, **model_kwargs)[0],
617
+ ]
618
+ )
619
+ pbar.update(1)
620
+ return res
621
+
622
+ model_fn = model_wrapper(
623
+ model_fn_prewrap,
624
+ noise_schedule,
625
+ model_type="noise", # "noise" or "x_start" or "v" or "score"
626
+ model_kwargs=model_kwargs,
627
+ guidance_type="classifier-free",
628
+ condition=th.Tensor(1),
629
+ unconditional_condition=th.Tensor(1),
630
+ guidance_scale=self.conditioning_free_k,
631
+ )
632
+ """
633
+ model_fn = model_wrapper(
634
+ model_fn_prewrap,
635
+ noise_schedule,
636
+ model_type='x_start',
637
+ model_kwargs={}
638
+ )
639
+ #
640
+ dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
641
+ x_sample = dpm_solver.sample(
642
+ noise,
643
+ steps=20,
644
+ order=3,
645
+ skip_type="time_uniform",
646
+ method="singlestep",
647
+ )
648
+ """
649
+ dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
650
+ x_sample = dpm_solver.sample(
651
+ noise,
652
+ steps=self.num_timesteps,
653
+ order=2,
654
+ skip_type="time_uniform",
655
+ method="multistep",
656
+ )
657
+ #'''
658
+ return x_sample
659
+
660
+ # HF DIFFUSION ATTEMPT
661
+ """
662
+ from .hf_diffusion import EulerAncestralDiscreteScheduler
663
+ Scheduler = EulerAncestralDiscreteScheduler()
664
+ Scheduler.set_timesteps(100)
665
+ for timestep in Scheduler.timesteps:
666
+ noise_input = Scheduler.scale_model_input(noise, timestep)
667
+ ts = s_in * timestep
668
+ model_output = model(noise_input, ts, **model_kwargs)
669
+ model_epsilon, _model_var = th.split(model_output, model_output.shape[1]//2, dim=1)
670
+ noise, _x0 = Scheduler.step(model_epsilon, timestep, noise)
671
+ return noise
672
+ """
673
+
674
+ # KARRAS DIFFUSION ATTEMPT
675
+ """
676
+ TRAINED_DIFFUSION_STEPS = 4000 # HARDCODED
677
+ ratio = TRAINED_DIFFUSION_STEPS/14.5
678
+ def call_model(*args, **kwargs):
679
+ model_output = model(*args, **kwargs)
680
+ model_output, model_var_values = th.split(model_output, model_output.shape[1]//2, dim=1)
681
+ return model_output
682
+ print(get_sigmas_karras(self.num_timesteps, sigma_min=0.0, sigma_max=4000, device=device))
683
+ exit()
684
+ sigmas = get_sigmas_karras(self.num_timesteps, sigma_min=0.03, sigma_max=14.5, device=device)
685
+ return k_sampler(call_model, noise, sigmas, extra_args=model_kwargs, disable=not progress)
686
+ '''
687
+ sigmas = get_sigmas_karras(self.num_timesteps, sigma_min=0.03, sigma_max=14.5, device=device)
688
+ step = 0 # LMAO
689
+ global_sigmas = None
690
+ #
691
+ def fakemodel(x, t, **model_kwargs):
692
+ print(t,global_sigmas*ratio)
693
+ return model(x, t, **model_kwargs)
694
+ def denoised(x, sigmas, **extra_args):
695
+ t = th.tensor([self.num_timesteps-step-1] * shape[0], device=device)
696
+ nonlocal global_sigmas
697
+ global_sigmas = sigmas
698
+ with th.no_grad():
699
+ out = self.p_sample(
700
+ fakemodel,
701
+ x,
702
+ t,
703
+ clip_denoised=clip_denoised,
704
+ denoised_fn=denoised_fn,
705
+ cond_fn=cond_fn,
706
+ model_kwargs=model_kwargs,
707
+ )
708
+ return out["sample"]
709
+ def callback(d):
710
+ nonlocal step
711
+ step += 1
712
+
713
+ return k_sampler(denoised, noise, sigmas, extra_args=model_kwargs, callback=callback, disable=not progress)
714
+ '''
715
+ """
716
+
717
+ def sample_loop(self, *args, **kwargs):
718
+ s = self.sampler
719
+ if s == "p":
720
+ return self.p_sample_loop(*args, **kwargs)
721
+ elif s == "ddim":
722
+ return self.ddim_sample_loop(*args, **kwargs)
723
+ elif s == "dpm++2m":
724
+ if self.conditioning_free is not True:
725
+ raise RuntimeError("cond_free must be true")
726
+ with tqdm(total=self.num_timesteps) as pbar:
727
+ return self.k_diffusion_sample_loop(
728
+ K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs
729
+ )
730
+ else:
731
+ raise RuntimeError("sampler not impl")
732
+
733
+ def p_sample_loop(
734
+ self,
735
+ model,
736
+ shape,
737
+ noise=None,
738
+ clip_denoised=True,
739
+ denoised_fn=None,
740
+ cond_fn=None,
741
+ model_kwargs=None,
742
+ device=None,
743
+ progress=False,
744
+ ):
745
+ """
746
+ Generate samples from the model.
747
+
748
+ :param model: the model module.
749
+ :param shape: the shape of the samples, (N, C, H, W).
750
+ :param noise: if specified, the noise from the encoder to sample.
751
+ Should be of the same shape as `shape`.
752
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
753
+ :param denoised_fn: if not None, a function which applies to the
754
+ x_start prediction before it is used to sample.
755
+ :param cond_fn: if not None, this is a gradient function that acts
756
+ similarly to the model.
757
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
758
+ pass to the model. This can be used for conditioning.
759
+ :param device: if specified, the device to create the samples on.
760
+ If not specified, use a model parameter's device.
761
+ :param progress: if True, show a tqdm progress bar.
762
+ :return: a non-differentiable batch of samples.
763
+ """
764
+ final = None
765
+ for sample in self.p_sample_loop_progressive(
766
+ model,
767
+ shape,
768
+ noise=noise,
769
+ clip_denoised=clip_denoised,
770
+ denoised_fn=denoised_fn,
771
+ cond_fn=cond_fn,
772
+ model_kwargs=model_kwargs,
773
+ device=device,
774
+ progress=progress,
775
+ ):
776
+ final = sample
777
+ return final["sample"]
778
+
779
+ def p_sample_loop_progressive(
780
+ self,
781
+ model,
782
+ shape,
783
+ noise=None,
784
+ clip_denoised=True,
785
+ denoised_fn=None,
786
+ cond_fn=None,
787
+ model_kwargs=None,
788
+ device=None,
789
+ progress=False,
790
+ ):
791
+ """
792
+ Generate samples from the model and yield intermediate samples from
793
+ each timestep of diffusion.
794
+
795
+ Arguments are the same as p_sample_loop().
796
+ Returns a generator over dicts, where each dict is the return value of
797
+ p_sample().
798
+ """
799
+ if device is None:
800
+ device = next(model.parameters()).device
801
+ assert isinstance(shape, (tuple, list))
802
+ if noise is not None:
803
+ img = noise
804
+ else:
805
+ img = th.randn(*shape, device=device)
806
+ indices = list(range(self.num_timesteps))[::-1]
807
+
808
+ for i in tqdm(indices, disable=not progress):
809
+ t = th.tensor([i] * shape[0], device=device)
810
+ with th.no_grad():
811
+ out = self.p_sample(
812
+ model,
813
+ img,
814
+ t,
815
+ clip_denoised=clip_denoised,
816
+ denoised_fn=denoised_fn,
817
+ cond_fn=cond_fn,
818
+ model_kwargs=model_kwargs,
819
+ )
820
+ yield out
821
+ img = out["sample"]
822
+
823
+ def ddim_sample(
824
+ self,
825
+ model,
826
+ x,
827
+ t,
828
+ clip_denoised=True,
829
+ denoised_fn=None,
830
+ cond_fn=None,
831
+ model_kwargs=None,
832
+ eta=0.0,
833
+ ):
834
+ """
835
+ Sample x_{t-1} from the model using DDIM.
836
+
837
+ Same usage as p_sample().
838
+ """
839
+ out = self.p_mean_variance(
840
+ model,
841
+ x,
842
+ t,
843
+ clip_denoised=clip_denoised,
844
+ denoised_fn=denoised_fn,
845
+ model_kwargs=model_kwargs,
846
+ )
847
+ if cond_fn is not None:
848
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
849
+
850
+ # Usually our model outputs epsilon, but we re-derive it
851
+ # in case we used x_start or x_prev prediction.
852
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
853
+
854
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
855
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
856
+ sigma = (
857
+ eta
858
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
859
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
860
+ )
861
+ # Equation 12.
862
+ noise = th.randn_like(x)
863
+ mean_pred = (
864
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
865
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
866
+ )
867
+ nonzero_mask = (
868
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
869
+ ) # no noise when t == 0
870
+ sample = mean_pred + nonzero_mask * sigma * noise
871
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
872
+
873
+ def ddim_reverse_sample(
874
+ self,
875
+ model,
876
+ x,
877
+ t,
878
+ clip_denoised=True,
879
+ denoised_fn=None,
880
+ model_kwargs=None,
881
+ eta=0.0,
882
+ ):
883
+ """
884
+ Sample x_{t+1} from the model using DDIM reverse ODE.
885
+ """
886
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
887
+ out = self.p_mean_variance(
888
+ model,
889
+ x,
890
+ t,
891
+ clip_denoised=clip_denoised,
892
+ denoised_fn=denoised_fn,
893
+ model_kwargs=model_kwargs,
894
+ )
895
+ # Usually our model outputs epsilon, but we re-derive it
896
+ # in case we used x_start or x_prev prediction.
897
+ eps = (
898
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
899
+ - out["pred_xstart"]
900
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
901
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
902
+
903
+ # Equation 12. reversed
904
+ mean_pred = (
905
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
906
+ + th.sqrt(1 - alpha_bar_next) * eps
907
+ )
908
+
909
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
910
+
911
+ def ddim_sample_loop(
912
+ self,
913
+ model,
914
+ shape,
915
+ noise=None,
916
+ clip_denoised=True,
917
+ denoised_fn=None,
918
+ cond_fn=None,
919
+ model_kwargs=None,
920
+ device=None,
921
+ progress=False,
922
+ eta=0.0,
923
+ ):
924
+ """
925
+ Generate samples from the model using DDIM.
926
+
927
+ Same usage as p_sample_loop().
928
+ """
929
+ final = None
930
+ for sample in self.ddim_sample_loop_progressive(
931
+ model,
932
+ shape,
933
+ noise=noise,
934
+ clip_denoised=clip_denoised,
935
+ denoised_fn=denoised_fn,
936
+ cond_fn=cond_fn,
937
+ model_kwargs=model_kwargs,
938
+ device=device,
939
+ progress=progress,
940
+ eta=eta,
941
+ ):
942
+ final = sample
943
+ return final["sample"]
944
+
945
+ def ddim_sample_loop_progressive(
946
+ self,
947
+ model,
948
+ shape,
949
+ noise=None,
950
+ clip_denoised=True,
951
+ denoised_fn=None,
952
+ cond_fn=None,
953
+ model_kwargs=None,
954
+ device=None,
955
+ progress=False,
956
+ eta=0.0,
957
+ ):
958
+ """
959
+ Use DDIM to sample from the model and yield intermediate samples from
960
+ each timestep of DDIM.
961
+
962
+ Same usage as p_sample_loop_progressive().
963
+ """
964
+ if device is None:
965
+ device = next(model.parameters()).device
966
+ assert isinstance(shape, (tuple, list))
967
+ if noise is not None:
968
+ img = noise
969
+ else:
970
+ img = th.randn(*shape, device=device)
971
+ indices = list(range(self.num_timesteps))[::-1]
972
+
973
+ if progress:
974
+ # Lazy import so that we don't depend on tqdm.
975
+ from tqdm.auto import tqdm
976
+
977
+ indices = tqdm(indices, disable=not progress)
978
+
979
+ for i in indices:
980
+ t = th.tensor([i] * shape[0], device=device)
981
+ with th.no_grad():
982
+ out = self.ddim_sample(
983
+ model,
984
+ img,
985
+ t,
986
+ clip_denoised=clip_denoised,
987
+ denoised_fn=denoised_fn,
988
+ cond_fn=cond_fn,
989
+ model_kwargs=model_kwargs,
990
+ eta=eta,
991
+ )
992
+ yield out
993
+ img = out["sample"]
994
+
995
+ def _vb_terms_bpd(
996
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
997
+ ):
998
+ """
999
+ Get a term for the variational lower-bound.
1000
+
1001
+ The resulting units are bits (rather than nats, as one might expect).
1002
+ This allows for comparison to other papers.
1003
+
1004
+ :return: a dict with the following keys:
1005
+ - 'output': a shape [N] tensor of NLLs or KLs.
1006
+ - 'pred_xstart': the x_0 predictions.
1007
+ """
1008
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
1009
+ x_start=x_start, x_t=x_t, t=t
1010
+ )
1011
+ out = self.p_mean_variance(
1012
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
1013
+ )
1014
+ kl = normal_kl(
1015
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
1016
+ )
1017
+ kl = mean_flat(kl) / np.log(2.0)
1018
+
1019
+ decoder_nll = -discretized_gaussian_log_likelihood(
1020
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
1021
+ )
1022
+ assert decoder_nll.shape == x_start.shape
1023
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
1024
+
1025
+ # At the first timestep return the decoder NLL,
1026
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
1027
+ output = th.where((t == 0), decoder_nll, kl)
1028
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
1029
+
1030
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
1031
+ """
1032
+ Compute training losses for a single timestep.
1033
+
1034
+ :param model: the model to evaluate loss on.
1035
+ :param x_start: the [N x C x ...] tensor of inputs.
1036
+ :param t: a batch of timestep indices.
1037
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1038
+ pass to the model. This can be used for conditioning.
1039
+ :param noise: if specified, the specific Gaussian noise to try to remove.
1040
+ :return: a dict with the key "loss" containing a tensor of shape [N].
1041
+ Some mean or variance settings may also have other keys.
1042
+ """
1043
+ if model_kwargs is None:
1044
+ model_kwargs = {}
1045
+ if noise is None:
1046
+ noise = th.randn_like(x_start)
1047
+ x_t = self.q_sample(x_start, t, noise=noise)
1048
+
1049
+ terms = {}
1050
+
1051
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
1052
+ # TODO: support multiple model outputs for this mode.
1053
+ terms["loss"] = self._vb_terms_bpd(
1054
+ model=model,
1055
+ x_start=x_start,
1056
+ x_t=x_t,
1057
+ t=t,
1058
+ clip_denoised=False,
1059
+ model_kwargs=model_kwargs,
1060
+ )["output"]
1061
+ if self.loss_type == LossType.RESCALED_KL:
1062
+ terms["loss"] *= self.num_timesteps
1063
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
1064
+ model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
1065
+ if isinstance(model_outputs, tuple):
1066
+ model_output = model_outputs[0]
1067
+ terms["extra_outputs"] = model_outputs[1:]
1068
+ else:
1069
+ model_output = model_outputs
1070
+
1071
+ if self.model_var_type in [
1072
+ ModelVarType.LEARNED,
1073
+ ModelVarType.LEARNED_RANGE,
1074
+ ]:
1075
+ B, C = x_t.shape[:2]
1076
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
1077
+ model_output, model_var_values = th.split(model_output, C, dim=1)
1078
+ # Learn the variance using the variational bound, but don't let
1079
+ # it affect our mean prediction.
1080
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
1081
+ terms["vb"] = self._vb_terms_bpd(
1082
+ model=lambda *args, r=frozen_out: r,
1083
+ x_start=x_start,
1084
+ x_t=x_t,
1085
+ t=t,
1086
+ clip_denoised=False,
1087
+ )["output"]
1088
+ if self.loss_type == LossType.RESCALED_MSE:
1089
+ # Divide by 1000 for equivalence with initial implementation.
1090
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
1091
+ terms["vb"] *= self.num_timesteps / 1000.0
1092
+
1093
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
1094
+ target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[
1095
+ 0
1096
+ ]
1097
+ x_start_pred = torch.zeros_like(x_start) # Not supported.
1098
+ elif self.model_mean_type == ModelMeanType.START_X:
1099
+ target = x_start
1100
+ x_start_pred = model_output
1101
+ elif self.model_mean_type == ModelMeanType.EPSILON:
1102
+ target = noise
1103
+ x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
1104
+ else:
1105
+ raise NotImplementedError(self.model_mean_type)
1106
+ assert model_output.shape == target.shape == x_start.shape
1107
+ terms["mse"] = mean_flat((target - model_output) ** 2)
1108
+ terms["x_start_predicted"] = x_start_pred
1109
+ if "vb" in terms:
1110
+ terms["loss"] = terms["mse"] + terms["vb"]
1111
+ else:
1112
+ terms["loss"] = terms["mse"]
1113
+ else:
1114
+ raise NotImplementedError(self.loss_type)
1115
+
1116
+ return terms
1117
+
1118
+ def autoregressive_training_losses(
1119
+ self,
1120
+ model,
1121
+ x_start,
1122
+ t,
1123
+ model_output_keys,
1124
+ gd_out_key,
1125
+ model_kwargs=None,
1126
+ noise=None,
1127
+ ):
1128
+ """
1129
+ Compute training losses for a single timestep.
1130
+
1131
+ :param model: the model to evaluate loss on.
1132
+ :param x_start: the [N x C x ...] tensor of inputs.
1133
+ :param t: a batch of timestep indices.
1134
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1135
+ pass to the model. This can be used for conditioning.
1136
+ :param noise: if specified, the specific Gaussian noise to try to remove.
1137
+ :return: a dict with the key "loss" containing a tensor of shape [N].
1138
+ Some mean or variance settings may also have other keys.
1139
+ """
1140
+ if model_kwargs is None:
1141
+ model_kwargs = {}
1142
+ if noise is None:
1143
+ noise = th.randn_like(x_start)
1144
+ x_t = self.q_sample(x_start, t, noise=noise)
1145
+ terms = {}
1146
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
1147
+ assert False # not currently supported for this type of diffusion.
1148
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
1149
+ model_outputs = model(
1150
+ x_t, x_start, self._scale_timesteps(t), **model_kwargs
1151
+ )
1152
+ terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
1153
+ model_output = terms[gd_out_key]
1154
+ if self.model_var_type in [
1155
+ ModelVarType.LEARNED,
1156
+ ModelVarType.LEARNED_RANGE,
1157
+ ]:
1158
+ B, C = x_t.shape[:2]
1159
+ assert model_output.shape == (B, C, 2, *x_t.shape[2:])
1160
+ model_output, model_var_values = (
1161
+ model_output[:, :, 0],
1162
+ model_output[:, :, 1],
1163
+ )
1164
+ # Learn the variance using the variational bound, but don't let
1165
+ # it affect our mean prediction.
1166
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
1167
+ terms["vb"] = self._vb_terms_bpd(
1168
+ model=lambda *args, r=frozen_out: r,
1169
+ x_start=x_start,
1170
+ x_t=x_t,
1171
+ t=t,
1172
+ clip_denoised=False,
1173
+ )["output"]
1174
+ if self.loss_type == LossType.RESCALED_MSE:
1175
+ # Divide by 1000 for equivalence with initial implementation.
1176
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
1177
+ terms["vb"] *= self.num_timesteps / 1000.0
1178
+
1179
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
1180
+ target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[
1181
+ 0
1182
+ ]
1183
+ x_start_pred = torch.zeros_like(x_start) # Not supported.
1184
+ elif self.model_mean_type == ModelMeanType.START_X:
1185
+ target = x_start
1186
+ x_start_pred = model_output
1187
+ elif self.model_mean_type == ModelMeanType.EPSILON:
1188
+ target = noise
1189
+ x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
1190
+ else:
1191
+ raise NotImplementedError(self.model_mean_type)
1192
+ assert model_output.shape == target.shape == x_start.shape
1193
+ terms["mse"] = mean_flat((target - model_output) ** 2)
1194
+ terms["x_start_predicted"] = x_start_pred
1195
+ if "vb" in terms:
1196
+ terms["loss"] = terms["mse"] + terms["vb"]
1197
+ else:
1198
+ terms["loss"] = terms["mse"]
1199
+ else:
1200
+ raise NotImplementedError(self.loss_type)
1201
+
1202
+ return terms
1203
+
1204
+ def _prior_bpd(self, x_start):
1205
+ """
1206
+ Get the prior KL term for the variational lower-bound, measured in
1207
+ bits-per-dim.
1208
+
1209
+ This term can't be optimized, as it only depends on the encoder.
1210
+
1211
+ :param x_start: the [N x C x ...] tensor of inputs.
1212
+ :return: a batch of [N] KL values (in bits), one per batch element.
1213
+ """
1214
+ batch_size = x_start.shape[0]
1215
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1216
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1217
+ kl_prior = normal_kl(
1218
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
1219
+ )
1220
+ return mean_flat(kl_prior) / np.log(2.0)
1221
+
1222
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
1223
+ """
1224
+ Compute the entire variational lower-bound, measured in bits-per-dim,
1225
+ as well as other related quantities.
1226
+
1227
+ :param model: the model to evaluate loss on.
1228
+ :param x_start: the [N x C x ...] tensor of inputs.
1229
+ :param clip_denoised: if True, clip denoised samples.
1230
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1231
+ pass to the model. This can be used for conditioning.
1232
+
1233
+ :return: a dict containing the following keys:
1234
+ - total_bpd: the total variational lower-bound, per batch element.
1235
+ - prior_bpd: the prior term in the lower-bound.
1236
+ - vb: an [N x T] tensor of terms in the lower-bound.
1237
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
1238
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
1239
+ """
1240
+ device = x_start.device
1241
+ batch_size = x_start.shape[0]
1242
+
1243
+ vb = []
1244
+ xstart_mse = []
1245
+ mse = []
1246
+ for t in list(range(self.num_timesteps))[::-1]:
1247
+ t_batch = th.tensor([t] * batch_size, device=device)
1248
+ noise = th.randn_like(x_start)
1249
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
1250
+ # Calculate VLB term at the current timestep
1251
+ with th.no_grad():
1252
+ out = self._vb_terms_bpd(
1253
+ model,
1254
+ x_start=x_start,
1255
+ x_t=x_t,
1256
+ t=t_batch,
1257
+ clip_denoised=clip_denoised,
1258
+ model_kwargs=model_kwargs,
1259
+ )
1260
+ vb.append(out["output"])
1261
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
1262
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
1263
+ mse.append(mean_flat((eps - noise) ** 2))
1264
+
1265
+ vb = th.stack(vb, dim=1)
1266
+ xstart_mse = th.stack(xstart_mse, dim=1)
1267
+ mse = th.stack(mse, dim=1)
1268
+
1269
+ prior_bpd = self._prior_bpd(x_start)
1270
+ total_bpd = vb.sum(dim=1) + prior_bpd
1271
+ return {
1272
+ "total_bpd": total_bpd,
1273
+ "prior_bpd": prior_bpd,
1274
+ "vb": vb,
1275
+ "xstart_mse": xstart_mse,
1276
+ "mse": mse,
1277
+ }
1278
+
1279
+
1280
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
1281
+ """
1282
+ Get a pre-defined beta schedule for the given name.
1283
+
1284
+ The beta schedule library consists of beta schedules which remain similar
1285
+ in the limit of num_diffusion_timesteps.
1286
+ Beta schedules may be added, but should not be removed or changed once
1287
+ they are committed to maintain backwards compatibility.
1288
+ """
1289
+ if schedule_name == "linear":
1290
+ # Linear schedule from Ho et al, extended to work for any number of
1291
+ # diffusion steps.
1292
+ scale = 1000 / num_diffusion_timesteps
1293
+ beta_start = scale * 0.0001
1294
+ beta_end = scale * 0.02
1295
+ return np.linspace(
1296
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
1297
+ )
1298
+ elif schedule_name == "cosine":
1299
+ return betas_for_alpha_bar(
1300
+ num_diffusion_timesteps,
1301
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
1302
+ )
1303
+ else:
1304
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
1305
+
1306
+
1307
+ class SpacedDiffusion(GaussianDiffusion):
1308
+ """
1309
+ A diffusion process which can skip steps in a base diffusion process.
1310
+
1311
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
1312
+ original diffusion process to retain.
1313
+ :param kwargs: the kwargs to create the base diffusion process.
1314
+ """
1315
+
1316
+ def __init__(self, use_timesteps, **kwargs):
1317
+ self.use_timesteps = set(use_timesteps)
1318
+ self.timestep_map = []
1319
+ self.original_num_steps = len(kwargs["betas"])
1320
+
1321
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
1322
+ last_alpha_cumprod = 1.0
1323
+ new_betas = []
1324
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
1325
+ if i in self.use_timesteps:
1326
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
1327
+ last_alpha_cumprod = alpha_cumprod
1328
+ self.timestep_map.append(i)
1329
+ kwargs["betas"] = np.array(new_betas)
1330
+ super().__init__(**kwargs)
1331
+
1332
+ def p_mean_variance(
1333
+ self, model, *args, **kwargs
1334
+ ): # pylint: disable=signature-differs
1335
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
1336
+
1337
+ def training_losses(
1338
+ self, model, *args, **kwargs
1339
+ ): # pylint: disable=signature-differs
1340
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
1341
+
1342
+ def autoregressive_training_losses(
1343
+ self, model, *args, **kwargs
1344
+ ): # pylint: disable=signature-differs
1345
+ return super().autoregressive_training_losses(
1346
+ self._wrap_model(model, True), *args, **kwargs
1347
+ )
1348
+
1349
+ def condition_mean(self, cond_fn, *args, **kwargs):
1350
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
1351
+
1352
+ def condition_score(self, cond_fn, *args, **kwargs):
1353
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
1354
+
1355
+ def _wrap_model(self, model, autoregressive=False):
1356
+ if isinstance(model, _WrappedModel) or isinstance(
1357
+ model, _WrappedAutoregressiveModel
1358
+ ):
1359
+ return model
1360
+ mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
1361
+ return mod(
1362
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
1363
+ )
1364
+
1365
+ def _scale_timesteps(self, t):
1366
+ # Scaling is done by the wrapped model.
1367
+ return t
1368
+
1369
+
1370
+ def space_timesteps(num_timesteps, section_counts):
1371
+ """
1372
+ Create a list of timesteps to use from an original diffusion process,
1373
+ given the number of timesteps we want to take from equally-sized portions
1374
+ of the original process.
1375
+
1376
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
1377
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
1378
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
1379
+
1380
+ If the stride is a string starting with "ddim", then the fixed striding
1381
+ from the DDIM paper is used, and only one section is allowed.
1382
+
1383
+ :param num_timesteps: the number of diffusion steps in the original
1384
+ process to divide up.
1385
+ :param section_counts: either a list of numbers, or a string containing
1386
+ comma-separated numbers, indicating the step count
1387
+ per section. As a special case, use "ddimN" where N
1388
+ is a number of steps to use the striding from the
1389
+ DDIM paper.
1390
+ :return: a set of diffusion steps from the original process to use.
1391
+ """
1392
+ if isinstance(section_counts, str):
1393
+ if section_counts.startswith("ddim"):
1394
+ desired_count = int(section_counts[len("ddim") :])
1395
+ for i in range(1, num_timesteps):
1396
+ if len(range(0, num_timesteps, i)) == desired_count:
1397
+ return set(range(0, num_timesteps, i))
1398
+ raise ValueError(
1399
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
1400
+ )
1401
+ section_counts = [int(x) for x in section_counts.split(",")]
1402
+ size_per = num_timesteps // len(section_counts)
1403
+ extra = num_timesteps % len(section_counts)
1404
+ start_idx = 0
1405
+ all_steps = []
1406
+ for i, section_count in enumerate(section_counts):
1407
+ size = size_per + (1 if i < extra else 0)
1408
+ if size < section_count:
1409
+ raise ValueError(
1410
+ f"cannot divide section of {size} steps into {section_count}"
1411
+ )
1412
+ if section_count <= 1:
1413
+ frac_stride = 1
1414
+ else:
1415
+ frac_stride = (size - 1) / (section_count - 1)
1416
+ cur_idx = 0.0
1417
+ taken_steps = []
1418
+ for _ in range(section_count):
1419
+ taken_steps.append(start_idx + round(cur_idx))
1420
+ cur_idx += frac_stride
1421
+ all_steps += taken_steps
1422
+ start_idx += size
1423
+ return set(all_steps)
1424
+
1425
+
1426
+ class _WrappedModel:
1427
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
1428
+ self.model = model
1429
+ self.timestep_map = timestep_map
1430
+ self.rescale_timesteps = rescale_timesteps
1431
+ self.original_num_steps = original_num_steps
1432
+
1433
+ def __call__(self, x, ts, **kwargs):
1434
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
1435
+ new_ts = map_tensor[ts]
1436
+ if self.rescale_timesteps:
1437
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
1438
+ return self.model(x, new_ts, **kwargs)
1439
+
1440
+
1441
+ class _WrappedAutoregressiveModel:
1442
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
1443
+ self.model = model
1444
+ self.timestep_map = timestep_map
1445
+ self.rescale_timesteps = rescale_timesteps
1446
+ self.original_num_steps = original_num_steps
1447
+
1448
+ def __call__(self, x, x0, ts, **kwargs):
1449
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
1450
+ new_ts = map_tensor[ts]
1451
+ if self.rescale_timesteps:
1452
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
1453
+ return self.model(x, x0, new_ts, **kwargs)
1454
+
1455
+
1456
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1457
+ """
1458
+ Extract values from a 1-D numpy array for a batch of indices.
1459
+
1460
+ :param arr: the 1-D numpy array.
1461
+ :param timesteps: a tensor of indices into the array to extract.
1462
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1463
+ dimension equal to the length of timesteps.
1464
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1465
+ """
1466
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1467
+ while len(res.shape) < len(broadcast_shape):
1468
+ res = res[..., None]
1469
+ return res.expand(broadcast_shape)
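A small sketch of how space_timesteps behaves, following its docstring (not part of the diff; the import path below is an assumption):

from tortoise.utils.diffusion import space_timesteps  # import path assumed

# 300 original steps, three equal sections kept at 10, 15 and 20 steps each.
steps = space_timesteps(300, [10, 15, 20])
assert len(steps) == 45

# "ddimN" picks a fixed integer stride that yields exactly N retained steps.
ddim_steps = space_timesteps(1000, "ddim50")
assert len(ddim_steps) == 50  # stride of 20: {0, 20, ..., 980}

# SpacedDiffusion(use_timesteps=ddim_steps, **base_kwargs) then runs the base
# process on only these retained timesteps, remapping t through timestep_map.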
tortoise/utils/stft.py ADDED
@@ -0,0 +1,215 @@
1
+ """
2
+ BSD 3-Clause License
3
+
4
+ Copyright (c) 2017, Prem Seetharaman
5
+ All rights reserved.
6
+
7
+ * Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice,
11
+ this list of conditions and the following disclaimer.
12
+
13
+ * Redistributions in binary form must reproduce the above copyright notice, this
14
+ list of conditions and the following disclaimer in the
15
+ documentation and/or other materials provided with the distribution.
16
+
17
+ * Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived from this
19
+ software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ """
32
+
33
+ import librosa.util as librosa_util
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn.functional as F
37
+ from librosa.util import pad_center, tiny
38
+ from scipy.signal import get_window
39
+ from torch.autograd import Variable
40
+
41
+
42
+ def window_sumsquare(
43
+ window,
44
+ n_frames,
45
+ hop_length=200,
46
+ win_length=800,
47
+ n_fft=800,
48
+ dtype=np.float32,
49
+ norm=None,
50
+ ):
51
+ """
52
+ # from librosa 0.6
53
+ Compute the sum-square envelope of a window function at a given hop length.
54
+
55
+ This is used to estimate modulation effects induced by windowing
56
+ observations in short-time fourier transforms.
57
+
58
+ Parameters
59
+ ----------
60
+ window : string, tuple, number, callable, or list-like
61
+ Window specification, as in `get_window`
62
+
63
+ n_frames : int > 0
64
+ The number of analysis frames
65
+
66
+ hop_length : int > 0
67
+ The number of samples to advance between frames
68
+
69
+ win_length : [optional]
70
+ The length of the window function. By default, this matches `n_fft`.
71
+
72
+ n_fft : int > 0
73
+ The length of each analysis frame.
74
+
75
+ dtype : np.dtype
76
+ The data type of the output
77
+
78
+ Returns
79
+ -------
80
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
81
+ The sum-squared envelope of the window function
82
+ """
83
+ if win_length is None:
84
+ win_length = n_fft
85
+
86
+ n = n_fft + hop_length * (n_frames - 1)
87
+ x = np.zeros(n, dtype=dtype)
88
+
89
+ # Compute the squared window at the desired length
90
+ win_sq = get_window(window, win_length, fftbins=True)
91
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
92
+ win_sq = librosa_util.pad_center(win_sq, size=n_fft)
93
+
94
+ # Fill the envelope
95
+ for i in range(n_frames):
96
+ sample = i * hop_length
97
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
98
+ return x
99
+
100
+
101
+ class STFT(torch.nn.Module):
102
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
103
+
104
+ def __init__(
105
+ self, filter_length=800, hop_length=200, win_length=800, window="hann"
106
+ ):
107
+ super(STFT, self).__init__()
108
+ self.filter_length = filter_length
109
+ self.hop_length = hop_length
110
+ self.win_length = win_length
111
+ self.window = window
112
+ self.forward_transform = None
113
+ scale = self.filter_length / self.hop_length
114
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
115
+
116
+ cutoff = int((self.filter_length / 2 + 1))
117
+ fourier_basis = np.vstack(
118
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
119
+ )
120
+
121
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
122
+ inverse_basis = torch.FloatTensor(
123
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
124
+ )
125
+
126
+ if window is not None:
127
+ assert filter_length >= win_length
128
+ # get window and zero center pad it to filter_length
129
+ fft_window = get_window(window, win_length, fftbins=True)
130
+ fft_window = pad_center(fft_window, size=filter_length)
131
+ fft_window = torch.from_numpy(fft_window).float()
132
+
133
+ # window the bases
134
+ forward_basis *= fft_window
135
+ inverse_basis *= fft_window
136
+
137
+ self.register_buffer("forward_basis", forward_basis.float())
138
+ self.register_buffer("inverse_basis", inverse_basis.float())
139
+
140
+ def transform(self, input_data):
141
+ num_batches = input_data.size(0)
142
+ num_samples = input_data.size(1)
143
+
144
+ self.num_samples = num_samples
145
+
146
+ # similar to librosa, reflect-pad the input
147
+ input_data = input_data.view(num_batches, 1, num_samples)
148
+ input_data = F.pad(
149
+ input_data.unsqueeze(1),
150
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
151
+ mode="reflect",
152
+ )
153
+ input_data = input_data.squeeze(1)
154
+
155
+ forward_transform = F.conv1d(
156
+ input_data,
157
+ Variable(self.forward_basis, requires_grad=False),
158
+ stride=self.hop_length,
159
+ padding=0,
160
+ )
161
+
162
+ cutoff = int((self.filter_length / 2) + 1)
163
+ real_part = forward_transform[:, :cutoff, :]
164
+ imag_part = forward_transform[:, cutoff:, :]
165
+
166
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
167
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
168
+
169
+ return magnitude, phase
170
+
171
+ def inverse(self, magnitude, phase):
172
+ recombine_magnitude_phase = torch.cat(
173
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
174
+ )
175
+
176
+ inverse_transform = F.conv_transpose1d(
177
+ recombine_magnitude_phase,
178
+ Variable(self.inverse_basis, requires_grad=False),
179
+ stride=self.hop_length,
180
+ padding=0,
181
+ )
182
+
183
+ if self.window is not None:
184
+ window_sum = window_sumsquare(
185
+ self.window,
186
+ magnitude.size(-1),
187
+ hop_length=self.hop_length,
188
+ win_length=self.win_length,
189
+ n_fft=self.filter_length,
190
+ dtype=np.float32,
191
+ )
192
+ # remove modulation effects
193
+ approx_nonzero_indices = torch.from_numpy(
194
+ np.where(window_sum > tiny(window_sum))[0]
195
+ )
196
+ window_sum = torch.autograd.Variable(
197
+ torch.from_numpy(window_sum), requires_grad=False
198
+ )
199
+ window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
200
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
201
+ approx_nonzero_indices
202
+ ]
203
+
204
+ # scale by hop ratio
205
+ inverse_transform *= float(self.filter_length) / self.hop_length
206
+
207
+ inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
208
+ inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
209
+
210
+ return inverse_transform
211
+
212
+ def forward(self, input_data):
213
+ self.magnitude, self.phase = self.transform(input_data)
214
+ reconstruction = self.inverse(self.magnitude, self.phase)
215
+ return reconstruction
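A minimal round-trip sketch for the STFT module above (not part of the diff; assumes tortoise/utils/stft.py is importable):

import torch
from tortoise.utils.stft import STFT

stft = STFT(filter_length=800, hop_length=200, win_length=800, window="hann")
audio = torch.randn(1, 24000)             # (batch, num_samples) waveform
magnitude, phase = stft.transform(audio)  # conv1d-based forward STFT
recon = stft.inverse(magnitude, phase)    # divides out the window modulation
# recon approximately reproduces `audio`, up to edge effects from the reflect padding.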
tortoise/utils/text.py ADDED
@@ -0,0 +1,144 @@
1
+ import re
2
+
3
+
4
+ def split_and_recombine_text(text, desired_length=200, max_length=300):
5
+ """Split text it into chunks of a desired length trying to keep sentences intact."""
6
+ # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii
7
+ text = re.sub(r"\n\n+", "\n", text)
8
+ text = re.sub(r"\s+", " ", text)
9
+ text = re.sub(r"[“”]", '"', text)
10
+
11
+ rv = []
12
+ in_quote = False
13
+ current = ""
14
+ split_pos = []
15
+ pos = -1
16
+ end_pos = len(text) - 1
17
+
18
+ def seek(delta):
19
+ nonlocal pos, in_quote, current
20
+ is_neg = delta < 0
21
+ for _ in range(abs(delta)):
22
+ if is_neg:
23
+ pos -= 1
24
+ current = current[:-1]
25
+ else:
26
+ pos += 1
27
+ current += text[pos]
28
+ if text[pos] == '"':
29
+ in_quote = not in_quote
30
+ return text[pos]
31
+
32
+ def peek(delta):
33
+ p = pos + delta
34
+ return text[p] if p < end_pos and p >= 0 else ""
35
+
36
+ def commit():
37
+ nonlocal rv, current, split_pos
38
+ rv.append(current)
39
+ current = ""
40
+ split_pos = []
41
+
42
+ while pos < end_pos:
43
+ c = seek(1)
44
+ # do we need to force a split?
45
+ if len(current) >= max_length:
46
+ if len(split_pos) > 0 and len(current) > (desired_length / 2):
47
+ # we have at least one sentence and we are over half the desired length, seek back to the last split
48
+ d = pos - split_pos[-1]
49
+ seek(-d)
50
+ else:
51
+ # no full sentences, seek back until we are not in the middle of a word and split there
52
+ while c not in "!?.\n " and pos > 0 and len(current) > desired_length:
53
+ c = seek(-1)
54
+ commit()
55
+ # check for sentence boundaries
56
+ elif not in_quote and (c in "!?\n" or (c == "." and peek(1) in "\n ")):
57
+ # seek forward if we have consecutive boundary markers but still within the max length
58
+ while (
59
+ pos < len(text) - 1 and len(current) < max_length and peek(1) in "!?."
60
+ ):
61
+ c = seek(1)
62
+ split_pos.append(pos)
63
+ if len(current) >= desired_length:
64
+ commit()
65
+ # treat end of quote as a boundary if its followed by a space or newline
66
+ elif in_quote and peek(1) == '"' and peek(2) in "\n ":
67
+ seek(2)
68
+ split_pos.append(pos)
69
+ rv.append(current)
70
+
71
+ # clean up, remove lines with only whitespace or punctuation
72
+ rv = [s.strip() for s in rv]
73
+ rv = [s for s in rv if len(s) > 0 and not re.match(r"^[\s\.,;:!?]*$", s)]
74
+
75
+ return rv
76
+
77
+
78
+ if __name__ == "__main__":
79
+ import os
80
+ import unittest
81
+
82
+ class Test(unittest.TestCase):
83
+ def test_split_and_recombine_text(self):
84
+ text = """
85
+ This is a sample sentence.
86
+ This is another sample sentence.
87
+ This is a longer sample sentence that should force a split inthemiddlebutinotinthislongword.
88
+ "Don't split my quote... please"
89
+ """
90
+ self.assertEqual(
91
+ split_and_recombine_text(text, desired_length=20, max_length=40),
92
+ [
93
+ "This is a sample sentence.",
94
+ "This is another sample sentence.",
95
+ "This is a longer sample sentence that",
96
+ "should force a split",
97
+ "inthemiddlebutinotinthislongword.",
98
+ '"Don\'t split my quote... please"',
99
+ ],
100
+ )
101
+
102
+ def test_split_and_recombine_text_2(self):
103
+ text = """
104
+ When you are really angry sometimes you use consecutive exclamation marks!!!!!! Is this a good thing to do?!?!?!
105
+ I don't know but we should handle this situation..........................
106
+ """
107
+ self.assertEqual(
108
+ split_and_recombine_text(text, desired_length=30, max_length=50),
109
+ [
110
+ "When you are really angry sometimes you use",
111
+ "consecutive exclamation marks!!!!!!",
112
+ "Is this a good thing to do?!?!?!",
113
+ "I don't know but we should handle this situation.",
114
+ ],
115
+ )
116
+
117
+ def test_split_and_recombine_text_3(self):
118
+ text_src = os.path.join(
119
+ os.path.dirname(__file__), "../data/riding_hood.txt"
120
+ )
121
+ with open(text_src, "r") as f:
122
+ text = f.read()
123
+ self.assertEqual(
124
+ split_and_recombine_text(text),
125
+ [
126
+ "Once upon a time there lived in a certain village a little country girl, the prettiest creature who was ever seen. Her mother was excessively fond of her; and her grandmother doted on her still more. This good woman had a little red riding hood made for her.",
127
+ 'It suited the girl so extremely well that everybody called her Little Red Riding Hood. One day her mother, having made some cakes, said to her, "Go, my dear, and see how your grandmother is doing, for I hear she has been very ill. Take her a cake, and this little pot of butter."',
128
+ "Little Red Riding Hood set out immediately to go to her grandmother, who lived in another village. As she was going through the wood, she met with a wolf, who had a very great mind to eat her up, but he dared not, because of some woodcutters working nearby in the forest.",
129
+ 'He asked her where she was going. The poor child, who did not know that it was dangerous to stay and talk to a wolf, said to him, "I am going to see my grandmother and carry her a cake and a little pot of butter from my mother." "Does she live far off?" said the wolf "Oh I say,"',
130
+ 'answered Little Red Riding Hood; "it is beyond that mill you see there, at the first house in the village." "Well," said the wolf, "and I\'ll go and see her too. I\'ll go this way and go you that, and we shall see who will be there first."',
131
+ "The wolf ran as fast as he could, taking the shortest path, and the little girl took a roundabout way, entertaining herself by gathering nuts, running after butterflies, and gathering bouquets of little flowers.",
132
+ 'It was not long before the wolf arrived at the old woman\'s house. He knocked at the door: tap, tap. "Who\'s there?" "Your grandchild, Little Red Riding Hood," replied the wolf, counterfeiting her voice; "who has brought you a cake and a little pot of butter sent you by mother."',
133
+ 'The good grandmother, who was in bed, because she was somewhat ill, cried out, "Pull the bobbin, and the latch will go up."',
134
+ "The wolf pulled the bobbin, and the door opened, and then he immediately fell upon the good woman and ate her up in a moment, for it been more than three days since he had eaten.",
135
+ "He then shut the door and got into the grandmother's bed, expecting Little Red Riding Hood, who came some time afterwards and knocked at the door: tap, tap. \"Who's there?\"",
136
+ 'Little Red Riding Hood, hearing the big voice of the wolf, was at first afraid; but believing her grandmother had a cold and was hoarse, answered, "It is your grandchild Little Red Riding Hood, who has brought you a cake and a little pot of butter mother sends you."',
137
+ 'The wolf cried out to her, softening his voice as much as he could, "Pull the bobbin, and the latch will go up." Little Red Riding Hood pulled the bobbin, and the door opened.',
138
+ 'The wolf, seeing her come in, said to her, hiding himself under the bedclothes, "Put the cake and the little pot of butter upon the stool, and come get into bed with me." Little Red Riding Hood took off her clothes and got into bed.',
139
+ 'She was greatly amazed to see how her grandmother looked in her nightclothes, and said to her, "Grandmother, what big arms you have!" "All the better to hug you with, my dear." "Grandmother, what big legs you have!" "All the better to run with, my child." "Grandmother, what big ears you have!"',
140
+ '"All the better to hear with, my child." "Grandmother, what big eyes you have!" "All the better to see with, my child." "Grandmother, what big teeth you have got!" "All the better to eat you up with." And, saying these words, this wicked wolf fell upon Little Red Riding Hood, and ate her all up.',
141
+ ],
142
+ )
143
+
144
+ unittest.main()
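A minimal sketch of calling split_and_recombine_text with its defaults (not part of the diff; the import path is an assumption):

from tortoise.utils.text import split_and_recombine_text

# With desired_length=200 and max_length=300, long passages are cut at sentence
# boundaries where possible and recombined into roughly 200-character chunks.
chunks = split_and_recombine_text("First sentence. " * 30)
print(len(chunks), chunks[0])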
tortoise/utils/tokenizer.py ADDED
@@ -0,0 +1,201 @@
1
+ import os
2
+ import re
3
+
4
+ import inflect
5
+ import torch
6
+ from tokenizers import Tokenizer
7
+
8
+ # Regular expression matching whitespace:
9
+ from unidecode import unidecode
10
+
11
+ _whitespace_re = re.compile(r"\s+")
12
+
13
+
14
+ # List of (regular expression, replacement) pairs for abbreviations:
15
+ _abbreviations = [
16
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
17
+ for x in [
18
+ ("mrs", "misess"),
19
+ ("mr", "mister"),
20
+ ("dr", "doctor"),
21
+ ("st", "saint"),
22
+ ("co", "company"),
23
+ ("jr", "junior"),
24
+ ("maj", "major"),
25
+ ("gen", "general"),
26
+ ("drs", "doctors"),
27
+ ("rev", "reverend"),
28
+ ("lt", "lieutenant"),
29
+ ("hon", "honorable"),
30
+ ("sgt", "sergeant"),
31
+ ("capt", "captain"),
32
+ ("esq", "esquire"),
33
+ ("ltd", "limited"),
34
+ ("col", "colonel"),
35
+ ("ft", "fort"),
36
+ ]
37
+ ]
38
+
39
+
40
+ def expand_abbreviations(text):
41
+ for regex, replacement in _abbreviations:
42
+ text = re.sub(regex, replacement, text)
43
+ return text
44
+
45
+
46
+ _inflect = inflect.engine()
47
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
48
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
49
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
50
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
51
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
52
+ _number_re = re.compile(r"[0-9]+")
53
+
54
+
55
+ def _remove_commas(m):
56
+ return m.group(1).replace(",", "")
57
+
58
+
59
+ def _expand_decimal_point(m):
60
+ return m.group(1).replace(".", " point ")
61
+
62
+
63
+ def _expand_dollars(m):
64
+ match = m.group(1)
65
+ parts = match.split(".")
66
+ if len(parts) > 2:
67
+ return match + " dollars" # Unexpected format
68
+ dollars = int(parts[0]) if parts[0] else 0
69
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
70
+ if dollars and cents:
71
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
72
+ cent_unit = "cent" if cents == 1 else "cents"
73
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
74
+ elif dollars:
75
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
76
+ return "%s %s" % (dollars, dollar_unit)
77
+ elif cents:
78
+ cent_unit = "cent" if cents == 1 else "cents"
79
+ return "%s %s" % (cents, cent_unit)
80
+ else:
81
+ return "zero dollars"
82
+
83
+
84
+ def _expand_ordinal(m):
85
+ return _inflect.number_to_words(m.group(0))
86
+
87
+
88
+ def _expand_number(m):
89
+ num = int(m.group(0))
90
+ if num > 1000 and num < 3000:
91
+ if num == 2000:
92
+ return "two thousand"
93
+ elif num > 2000 and num < 2010:
94
+ return "two thousand " + _inflect.number_to_words(num % 100)
95
+ elif num % 100 == 0:
96
+ return _inflect.number_to_words(num // 100) + " hundred"
97
+ else:
98
+ return _inflect.number_to_words(
99
+ num, andword="", zero="oh", group=2
100
+ ).replace(", ", " ")
101
+ else:
102
+ return _inflect.number_to_words(num, andword="")
103
+
104
+
105
+ def normalize_numbers(text):
106
+ text = re.sub(_comma_number_re, _remove_commas, text)
107
+ text = re.sub(_pounds_re, r"\1 pounds", text)
108
+ text = re.sub(_dollars_re, _expand_dollars, text)
109
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
110
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
111
+ text = re.sub(_number_re, _expand_number, text)
112
+ return text
113
+
114
+
115
+ def expand_numbers(text):
116
+ return normalize_numbers(text)
117
+
118
+
119
+ def lowercase(text):
120
+ return text.lower()
121
+
122
+
123
+ def collapse_whitespace(text):
124
+ return re.sub(_whitespace_re, " ", text)
125
+
126
+
127
+ def convert_to_ascii(text):
128
+ return unidecode(text)
129
+
130
+
131
+ def basic_cleaners(text):
132
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
133
+ text = lowercase(text)
134
+ text = collapse_whitespace(text)
135
+ return text
136
+
137
+
138
+ def transliteration_cleaners(text):
139
+ """Pipeline for non-English text that transliterates to ASCII."""
140
+ text = convert_to_ascii(text)
141
+ text = lowercase(text)
142
+ text = collapse_whitespace(text)
143
+ return text
144
+
145
+
146
+ def english_cleaners(text):
147
+ """Pipeline for English text, including number and abbreviation expansion."""
148
+ text = convert_to_ascii(text)
149
+ text = lowercase(text)
150
+ text = expand_numbers(text)
151
+ text = expand_abbreviations(text)
152
+ text = collapse_whitespace(text)
153
+ text = text.replace('"', "")
154
+ return text
155
+
156
+
157
+ def lev_distance(s1, s2):
158
+ if len(s1) > len(s2):
159
+ s1, s2 = s2, s1
160
+
161
+ distances = range(len(s1) + 1)
162
+ for i2, c2 in enumerate(s2):
163
+ distances_ = [i2 + 1]
164
+ for i1, c1 in enumerate(s1):
165
+ if c1 == c2:
166
+ distances_.append(distances[i1])
167
+ else:
168
+ distances_.append(
169
+ 1 + min((distances[i1], distances[i1 + 1], distances_[-1]))
170
+ )
171
+ distances = distances_
172
+ return distances[-1]
173
+
174
+
175
+ DEFAULT_VOCAB_FILE = os.path.join(
176
+ os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json"
177
+ )
178
+
179
+
180
+ class VoiceBpeTokenizer:
181
+ def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
182
+ if vocab_file is not None:
183
+ self.tokenizer = Tokenizer.from_file(vocab_file)
184
+
185
+ def preprocess_text(self, txt):
186
+ txt = english_cleaners(txt)
187
+ return txt
188
+
189
+ def encode(self, txt):
190
+ txt = self.preprocess_text(txt)
191
+ txt = txt.replace(" ", "[SPACE]")
192
+ return self.tokenizer.encode(txt).ids
193
+
194
+ def decode(self, seq):
195
+ if isinstance(seq, torch.Tensor):
196
+ seq = seq.cpu().numpy()
197
+ txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "")
198
+ txt = txt.replace("[SPACE]", " ")
199
+ txt = txt.replace("[STOP]", "")
200
+ txt = txt.replace("[UNK]", "")
201
+ return txt
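A brief sketch of the text-normalization pipeline above (not part of the diff; the import path is an assumption):

from tortoise.utils.tokenizer import VoiceBpeTokenizer, english_cleaners

# Abbreviations, currency and numbers are expanded to words before tokenization.
print(english_cleaners("Mr. Smith spent $3.50 on 2 cakes."))
# -> mister smith spent three dollars, fifty cents on two cakes.

# VoiceBpeTokenizer wraps the BPE vocab at tortoise/data/tokenizer.json by default
# and maps spaces to the [SPACE] token before encoding.
tok = VoiceBpeTokenizer()
ids = tok.encode("hello world")
print(tok.decode(ids))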
tortoise/utils/typical_sampling.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+ from transformers import LogitsWarper
3
+
4
+
5
+ class TypicalLogitsWarper(LogitsWarper):
6
+ def __init__(
7
+ self,
8
+ mass: float = 0.9,
9
+ filter_value: float = -float("Inf"),
10
+ min_tokens_to_keep: int = 1,
11
+ ):
12
+ self.filter_value = filter_value
13
+ self.mass = mass
14
+ self.min_tokens_to_keep = min_tokens_to_keep
15
+
16
+ def __call__(
17
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
18
+ ) -> torch.FloatTensor:
19
+ # calculate entropy
20
+ normalized = torch.nn.functional.log_softmax(scores, dim=-1)
21
+ p = torch.exp(normalized)
22
+ ent = -(normalized * p).nansum(-1, keepdim=True)
23
+
24
+ # shift and sort
25
+ shifted_scores = torch.abs((-normalized) - ent)
26
+ sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
27
+ sorted_logits = scores.gather(-1, sorted_indices)
28
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
29
+
30
+ # Remove tokens with cumulative mass above the threshold
31
+ last_ind = (cumulative_probs < self.mass).sum(dim=1)
32
+ last_ind[last_ind < 0] = 0
33
+ sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
34
+ 1, last_ind.view(-1, 1)
35
+ )
36
+ if self.min_tokens_to_keep > 1:
37
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
38
+ sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
39
+ indices_to_remove = sorted_indices_to_remove.scatter(
40
+ 1, sorted_indices, sorted_indices_to_remove
41
+ )
42
+
43
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
44
+ return scores
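A minimal sketch of applying the typical-sampling warper above to a batch of logits (not part of the diff; the import path is an assumption):

import torch
from tortoise.utils.typical_sampling import TypicalLogitsWarper

warper = TypicalLogitsWarper(mass=0.9)
input_ids = torch.zeros(2, 5, dtype=torch.long)  # prefix tokens; unused by the warper
scores = torch.randn(2, 100)                     # (batch, vocab_size) raw logits
filtered = warper(input_ids, scores)
# Tokens whose surprisal is far from the distribution's entropy fall outside the
# 0.9 "typical" mass and are masked to -inf before sampling.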
tortoise/utils/wav2vec_alignment.py ADDED
@@ -0,0 +1,164 @@
1
+ import torch
2
+ import torchaudio
3
+ from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC
4
+
5
+
6
+ def max_alignment(s1, s2, skip_character="~", record=None):
7
+ """
8
+ A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is
9
+ used to replace that character.
10
+
11
+ Finally got to use my DP skills!
12
+ """
13
+ if record is None:
14
+ record = {}
15
+ assert (
16
+ skip_character not in s1
17
+ ), f"Found the skip character {skip_character} in the provided string, {s1}"
18
+ if len(s1) == 0:
19
+ return ""
20
+ if len(s2) == 0:
21
+ return skip_character * len(s1)
22
+ if s1 == s2:
23
+ return s1
24
+ if s1[0] == s2[0]:
25
+ return s1[0] + max_alignment(s1[1:], s2[1:], skip_character, record)
26
+
27
+ take_s1_key = (len(s1), len(s2) - 1)
28
+ if take_s1_key in record:
29
+ take_s1, take_s1_score = record[take_s1_key]
30
+ else:
31
+ take_s1 = max_alignment(s1, s2[1:], skip_character, record)
32
+ take_s1_score = len(take_s1.replace(skip_character, ""))
33
+ record[take_s1_key] = (take_s1, take_s1_score)
34
+
35
+ take_s2_key = (len(s1) - 1, len(s2))
36
+ if take_s2_key in record:
37
+ take_s2, take_s2_score = record[take_s2_key]
38
+ else:
39
+ take_s2 = max_alignment(s1[1:], s2, skip_character, record)
40
+ take_s2_score = len(take_s2.replace(skip_character, ""))
41
+ record[take_s2_key] = (take_s2, take_s2_score)
42
+
43
+ return take_s1 if take_s1_score > take_s2_score else skip_character + take_s2
44
+
45
+
46
+ class Wav2VecAlignment:
47
+ """
48
+ Uses wav2vec2 to perform audio<->text alignment.
49
+ """
50
+
51
+ def __init__(self, device="cuda"):
52
+ self.model = Wav2Vec2ForCTC.from_pretrained(
53
+ "jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli"
54
+ ).cpu()
55
+ self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
56
+ "facebook/wav2vec2-large-960h"
57
+ )
58
+ self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
59
+ "jbetker/tacotron-symbols"
60
+ )
61
+ self.device = device
62
+
63
+ def align(self, audio, expected_text, audio_sample_rate=24000):
64
+ orig_len = audio.shape[-1]
65
+
66
+ with torch.no_grad():
67
+ self.model = self.model.to(self.device)
68
+ audio = audio.to(self.device)
69
+ audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
70
+ clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
71
+ logits = self.model(clip_norm).logits
72
+ self.model = self.model.cpu()
73
+
74
+ logits = logits[0]
75
+ pred_string = self.tokenizer.decode(logits.argmax(-1).tolist())
76
+
77
+ fixed_expectation = max_alignment(expected_text.lower(), pred_string)
78
+ w2v_compression = orig_len // logits.shape[0]
79
+ expected_tokens = self.tokenizer.encode(fixed_expectation)
80
+ expected_chars = list(fixed_expectation)
81
+ if len(expected_tokens) == 1:
82
+ return [0] # The alignment is simple; there is only one token.
83
+ expected_tokens.pop(0) # The first token is a given.
84
+ expected_chars.pop(0)
85
+
86
+ alignments = [0]
87
+
88
+ def pop_till_you_win():
89
+ if len(expected_tokens) == 0:
90
+ return None
91
+ popped = expected_tokens.pop(0)
92
+ popped_char = expected_chars.pop(0)
93
+ while popped_char == "~":
94
+ alignments.append(-1)
95
+ if len(expected_tokens) == 0:
96
+ return None
97
+ popped = expected_tokens.pop(0)
98
+ popped_char = expected_chars.pop(0)
99
+ return popped
100
+
101
+ next_expected_token = pop_till_you_win()
102
+ for i, logit in enumerate(logits):
103
+ top = logit.argmax()
104
+ if next_expected_token == top:
105
+ alignments.append(i * w2v_compression)
106
+ if len(expected_tokens) > 0:
107
+ next_expected_token = pop_till_you_win()
108
+ else:
109
+ break
110
+
111
+ pop_till_you_win()
112
+ if not (len(expected_tokens) == 0 and len(alignments) == len(expected_text)):
113
+ torch.save([audio, expected_text], "alignment_debug.pth")
114
+ assert False, (
115
+ "Something went wrong with the alignment algorithm. I've dumped a file, 'alignment_debug.pth' to"
116
+ "your current working directory. Please report this along with the file so it can get fixed."
117
+ )
118
+
119
+ # Now fix up alignments. Anything with -1 should be interpolated.
120
+ alignments.append(
121
+ orig_len
122
+ ) # This'll get removed but makes the algorithm below more readable.
123
+ for i in range(len(alignments)):
124
+ if alignments[i] == -1:
125
+ for j in range(i + 1, len(alignments)):
126
+ if alignments[j] != -1:
127
+ next_found_token = j
128
+ break
129
+ for j in range(i, next_found_token):
130
+ gap = alignments[next_found_token] - alignments[i - 1]
131
+ alignments[j] = (j - i + 1) * gap // (
132
+ next_found_token - i + 1
133
+ ) + alignments[i - 1]
134
+
135
+ return alignments[:-1]
136
+
137
+ def redact(self, audio, expected_text, audio_sample_rate=24000):
138
+ if "[" not in expected_text:
139
+ return audio
140
+ splitted = expected_text.split("[")
141
+ fully_split = [splitted[0]]
142
+ for spl in splitted[1:]:
143
+ assert (
144
+ "]" in spl
145
+ ), 'Every "[" character must be paired with a "]" with no nesting.'
146
+ fully_split.extend(spl.split("]"))
147
+
148
+ # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
149
+ non_redacted_intervals = []
150
+ last_point = 0
151
+ for i in range(len(fully_split)):
152
+ if i % 2 == 0:
153
+ end_interval = max(0, last_point + len(fully_split[i]) - 1)
154
+ non_redacted_intervals.append((last_point, end_interval))
155
+ last_point += len(fully_split[i])
156
+
157
+ bare_text = "".join(fully_split)
158
+ alignments = self.align(audio, bare_text, audio_sample_rate)
159
+
160
+ output_audio = []
161
+ for nri in non_redacted_intervals:
162
+ start, stop = nri
163
+ output_audio.append(audio[:, alignments[start] : alignments[stop]])
164
+ return torch.cat(output_audio, dim=-1)
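A small sketch of what max_alignment produces (not part of the diff; the import path is an assumption):

from tortoise.utils.wav2vec_alignment import max_alignment

# Characters of the first string that cannot be matched, in order, against the
# second string are replaced with the skip character '~'.
print(max_alignment("hello world", "helo wrld"))  # -> hel~o w~rld

# Wav2VecAlignment.redact() builds on align() to cut spans wrapped in [brackets]
# out of the generated audio.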
tortoise/voices/william/1.wav ADDED
Binary file (266 kB).
 
tortoise/voices/william/2.wav ADDED
Binary file (631 kB).
 
tortoise/voices/william/3.wav ADDED
Binary file (682 kB).
 
tortoise/voices/william/4.wav ADDED
Binary file (471 kB).