arnavmehta7 commited on
Commit
8520a55
1 Parent(s): 673f68d
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ docs/assets/intro_vid.mp4 filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU AFFERO GENERAL PUBLIC LICENSE
2
+ Version 3, 19 November 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU Affero General Public License is a free, copyleft license for
11
+ software and other kinds of works, specifically designed to ensure
12
+ cooperation with the community in the case of network server software.
13
+
14
+ The licenses for most software and other practical works are designed
15
+ to take away your freedom to share and change the works. By contrast,
16
+ our General Public Licenses are intended to guarantee your freedom to
17
+ share and change all versions of a program--to make sure it remains free
18
+ software for all its users.
19
+
20
+ When we speak of free software, we are referring to freedom, not
21
+ price. Our General Public Licenses are designed to make sure that you
22
+ have the freedom to distribute copies of free software (and charge for
23
+ them if you wish), that you receive source code or can get it if you
24
+ want it, that you can change the software or use pieces of it in new
25
+ free programs, and that you know you can do these things.
26
+
27
+ Developers that use our General Public Licenses protect your rights
28
+ with two steps: (1) assert copyright on the software, and (2) offer
29
+ you this License which gives you legal permission to copy, distribute
30
+ and/or modify the software.
31
+
32
+ A secondary benefit of defending all users' freedom is that
33
+ improvements made in alternate versions of the program, if they
34
+ receive widespread use, become available for other developers to
35
+ incorporate. Many developers of free software are heartened and
36
+ encouraged by the resulting cooperation. However, in the case of
37
+ software used on network servers, this result may fail to come about.
38
+ The GNU General Public License permits making a modified version and
39
+ letting the public access it on a server without ever releasing its
40
+ source code to the public.
41
+
42
+ The GNU Affero General Public License is designed specifically to
43
+ ensure that, in such cases, the modified source code becomes available
44
+ to the community. It requires the operator of a network server to
45
+ provide the source code of the modified version running there to the
46
+ users of that server. Therefore, public use of a modified version, on
47
+ a publicly accessible server, gives the public access to the source
48
+ code of the modified version.
49
+
50
+ An older license, called the Affero General Public License and
51
+ published by Affero, was designed to accomplish similar goals. This is
52
+ a different license, not a version of the Affero GPL, but Affero has
53
+ released a new version of the Affero GPL which permits relicensing under
54
+ this license.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ TERMS AND CONDITIONS
60
+
61
+ 0. Definitions.
62
+
63
+ "This License" refers to version 3 of the GNU Affero General Public License.
64
+
65
+ "Copyright" also means copyright-like laws that apply to other kinds of
66
+ works, such as semiconductor masks.
67
+
68
+ "The Program" refers to any copyrightable work licensed under this
69
+ License. Each licensee is addressed as "you". "Licensees" and
70
+ "recipients" may be individuals or organizations.
71
+
72
+ To "modify" a work means to copy from or adapt all or part of the work
73
+ in a fashion requiring copyright permission, other than the making of an
74
+ exact copy. The resulting work is called a "modified version" of the
75
+ earlier work or a work "based on" the earlier work.
76
+
77
+ A "covered work" means either the unmodified Program or a work based
78
+ on the Program.
79
+
80
+ To "propagate" a work means to do anything with it that, without
81
+ permission, would make you directly or secondarily liable for
82
+ infringement under applicable copyright law, except executing it on a
83
+ computer or modifying a private copy. Propagation includes copying,
84
+ distribution (with or without modification), making available to the
85
+ public, and in some countries other activities as well.
86
+
87
+ To "convey" a work means any kind of propagation that enables other
88
+ parties to make or receive copies. Mere interaction with a user through
89
+ a computer network, with no transfer of a copy, is not conveying.
90
+
91
+ An interactive user interface displays "Appropriate Legal Notices"
92
+ to the extent that it includes a convenient and prominently visible
93
+ feature that (1) displays an appropriate copyright notice, and (2)
94
+ tells the user that there is no warranty for the work (except to the
95
+ extent that warranties are provided), that licensees may convey the
96
+ work under this License, and how to view a copy of this License. If
97
+ the interface presents a list of user commands or options, such as a
98
+ menu, a prominent item in the list meets this criterion.
99
+
100
+ 1. Source Code.
101
+
102
+ The "source code" for a work means the preferred form of the work
103
+ for making modifications to it. "Object code" means any non-source
104
+ form of a work.
105
+
106
+ A "Standard Interface" means an interface that either is an official
107
+ standard defined by a recognized standards body, or, in the case of
108
+ interfaces specified for a particular programming language, one that
109
+ is widely used among developers working in that language.
110
+
111
+ The "System Libraries" of an executable work include anything, other
112
+ than the work as a whole, that (a) is included in the normal form of
113
+ packaging a Major Component, but which is not part of that Major
114
+ Component, and (b) serves only to enable use of the work with that
115
+ Major Component, or to implement a Standard Interface for which an
116
+ implementation is available to the public in source code form. A
117
+ "Major Component", in this context, means a major essential component
118
+ (kernel, window system, and so on) of the specific operating system
119
+ (if any) on which the executable work runs, or a compiler used to
120
+ produce the work, or an object code interpreter used to run it.
121
+
122
+ The "Corresponding Source" for a work in object code form means all
123
+ the source code needed to generate, install, and (for an executable
124
+ work) run the object code and to modify the work, including scripts to
125
+ control those activities. However, it does not include the work's
126
+ System Libraries, or general-purpose tools or generally available free
127
+ programs which are used unmodified in performing those activities but
128
+ which are not part of the work. For example, Corresponding Source
129
+ includes interface definition files associated with source files for
130
+ the work, and the source code for shared libraries and dynamically
131
+ linked subprograms that the work is specifically designed to require,
132
+ such as by intimate data communication or control flow between those
133
+ subprograms and other parts of the work.
134
+
135
+ The Corresponding Source need not include anything that users
136
+ can regenerate automatically from other parts of the Corresponding
137
+ Source.
138
+
139
+ The Corresponding Source for a work in source code form is that
140
+ same work.
141
+
142
+ 2. Basic Permissions.
143
+
144
+ All rights granted under this License are granted for the term of
145
+ copyright on the Program, and are irrevocable provided the stated
146
+ conditions are met. This License explicitly affirms your unlimited
147
+ permission to run the unmodified Program. The output from running a
148
+ covered work is covered by this License only if the output, given its
149
+ content, constitutes a covered work. This License acknowledges your
150
+ rights of fair use or other equivalent, as provided by copyright law.
151
+
152
+ You may make, run and propagate covered works that you do not
153
+ convey, without conditions so long as your license otherwise remains
154
+ in force. You may convey covered works to others for the sole purpose
155
+ of having them make modifications exclusively for you, or provide you
156
+ with facilities for running those works, provided that you comply with
157
+ the terms of this License in conveying all material for which you do
158
+ not control copyright. Those thus making or running the covered works
159
+ for you must do so exclusively on your behalf, under your direction
160
+ and control, on terms that prohibit them from making any copies of
161
+ your copyrighted material outside their relationship with you.
162
+
163
+ Conveying under any other circumstances is permitted solely under
164
+ the conditions stated below. Sublicensing is not allowed; section 10
165
+ makes it unnecessary.
166
+
167
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168
+
169
+ No covered work shall be deemed part of an effective technological
170
+ measure under any applicable law fulfilling obligations under article
171
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172
+ similar laws prohibiting or restricting circumvention of such
173
+ measures.
174
+
175
+ When you convey a covered work, you waive any legal power to forbid
176
+ circumvention of technological measures to the extent such circumvention
177
+ is effected by exercising rights under this License with respect to
178
+ the covered work, and you disclaim any intention to limit operation or
179
+ modification of the work as a means of enforcing, against the work's
180
+ users, your or third parties' legal rights to forbid circumvention of
181
+ technological measures.
182
+
183
+ 4. Conveying Verbatim Copies.
184
+
185
+ You may convey verbatim copies of the Program's source code as you
186
+ receive it, in any medium, provided that you conspicuously and
187
+ appropriately publish on each copy an appropriate copyright notice;
188
+ keep intact all notices stating that this License and any
189
+ non-permissive terms added in accord with section 7 apply to the code;
190
+ keep intact all notices of the absence of any warranty; and give all
191
+ recipients a copy of this License along with the Program.
192
+
193
+ You may charge any price or no price for each copy that you convey,
194
+ and you may offer support or warranty protection for a fee.
195
+
196
+ 5. Conveying Modified Source Versions.
197
+
198
+ You may convey a work based on the Program, or the modifications to
199
+ produce it from the Program, in the form of source code under the
200
+ terms of section 4, provided that you also meet all of these conditions:
201
+
202
+ a) The work must carry prominent notices stating that you modified
203
+ it, and giving a relevant date.
204
+
205
+ b) The work must carry prominent notices stating that it is
206
+ released under this License and any conditions added under section
207
+ 7. This requirement modifies the requirement in section 4 to
208
+ "keep intact all notices".
209
+
210
+ c) You must license the entire work, as a whole, under this
211
+ License to anyone who comes into possession of a copy. This
212
+ License will therefore apply, along with any applicable section 7
213
+ additional terms, to the whole of the work, and all its parts,
214
+ regardless of how they are packaged. This License gives no
215
+ permission to license the work in any other way, but it does not
216
+ invalidate such permission if you have separately received it.
217
+
218
+ d) If the work has interactive user interfaces, each must display
219
+ Appropriate Legal Notices; however, if the Program has interactive
220
+ interfaces that do not display Appropriate Legal Notices, your
221
+ work need not make them do so.
222
+
223
+ A compilation of a covered work with other separate and independent
224
+ works, which are not by their nature extensions of the covered work,
225
+ and which are not combined with it such as to form a larger program,
226
+ in or on a volume of a storage or distribution medium, is called an
227
+ "aggregate" if the compilation and its resulting copyright are not
228
+ used to limit the access or legal rights of the compilation's users
229
+ beyond what the individual works permit. Inclusion of a covered work
230
+ in an aggregate does not cause this License to apply to the other
231
+ parts of the aggregate.
232
+
233
+ 6. Conveying Non-Source Forms.
234
+
235
+ You may convey a covered work in object code form under the terms
236
+ of sections 4 and 5, provided that you also convey the
237
+ machine-readable Corresponding Source under the terms of this License,
238
+ in one of these ways:
239
+
240
+ a) Convey the object code in, or embodied in, a physical product
241
+ (including a physical distribution medium), accompanied by the
242
+ Corresponding Source fixed on a durable physical medium
243
+ customarily used for software interchange.
244
+
245
+ b) Convey the object code in, or embodied in, a physical product
246
+ (including a physical distribution medium), accompanied by a
247
+ written offer, valid for at least three years and valid for as
248
+ long as you offer spare parts or customer support for that product
249
+ model, to give anyone who possesses the object code either (1) a
250
+ copy of the Corresponding Source for all the software in the
251
+ product that is covered by this License, on a durable physical
252
+ medium customarily used for software interchange, for a price no
253
+ more than your reasonable cost of physically performing this
254
+ conveying of source, or (2) access to copy the
255
+ Corresponding Source from a network server at no charge.
256
+
257
+ c) Convey individual copies of the object code with a copy of the
258
+ written offer to provide the Corresponding Source. This
259
+ alternative is allowed only occasionally and noncommercially, and
260
+ only if you received the object code with such an offer, in accord
261
+ with subsection 6b.
262
+
263
+ d) Convey the object code by offering access from a designated
264
+ place (gratis or for a charge), and offer equivalent access to the
265
+ Corresponding Source in the same way through the same place at no
266
+ further charge. You need not require recipients to copy the
267
+ Corresponding Source along with the object code. If the place to
268
+ copy the object code is a network server, the Corresponding Source
269
+ may be on a different server (operated by you or a third party)
270
+ that supports equivalent copying facilities, provided you maintain
271
+ clear directions next to the object code saying where to find the
272
+ Corresponding Source. Regardless of what server hosts the
273
+ Corresponding Source, you remain obligated to ensure that it is
274
+ available for as long as needed to satisfy these requirements.
275
+
276
+ e) Convey the object code using peer-to-peer transmission, provided
277
+ you inform other peers where the object code and Corresponding
278
+ Source of the work are being offered to the general public at no
279
+ charge under subsection 6d.
280
+
281
+ A separable portion of the object code, whose source code is excluded
282
+ from the Corresponding Source as a System Library, need not be
283
+ included in conveying the object code work.
284
+
285
+ A "User Product" is either (1) a "consumer product", which means any
286
+ tangible personal property which is normally used for personal, family,
287
+ or household purposes, or (2) anything designed or sold for incorporation
288
+ into a dwelling. In determining whether a product is a consumer product,
289
+ doubtful cases shall be resolved in favor of coverage. For a particular
290
+ product received by a particular user, "normally used" refers to a
291
+ typical or common use of that class of product, regardless of the status
292
+ of the particular user or of the way in which the particular user
293
+ actually uses, or expects or is expected to use, the product. A product
294
+ is a consumer product regardless of whether the product has substantial
295
+ commercial, industrial or non-consumer uses, unless such uses represent
296
+ the only significant mode of use of the product.
297
+
298
+ "Installation Information" for a User Product means any methods,
299
+ procedures, authorization keys, or other information required to install
300
+ and execute modified versions of a covered work in that User Product from
301
+ a modified version of its Corresponding Source. The information must
302
+ suffice to ensure that the continued functioning of the modified object
303
+ code is in no case prevented or interfered with solely because
304
+ modification has been made.
305
+
306
+ If you convey an object code work under this section in, or with, or
307
+ specifically for use in, a User Product, and the conveying occurs as
308
+ part of a transaction in which the right of possession and use of the
309
+ User Product is transferred to the recipient in perpetuity or for a
310
+ fixed term (regardless of how the transaction is characterized), the
311
+ Corresponding Source conveyed under this section must be accompanied
312
+ by the Installation Information. But this requirement does not apply
313
+ if neither you nor any third party retains the ability to install
314
+ modified object code on the User Product (for example, the work has
315
+ been installed in ROM).
316
+
317
+ The requirement to provide Installation Information does not include a
318
+ requirement to continue to provide support service, warranty, or updates
319
+ for a work that has been modified or installed by the recipient, or for
320
+ the User Product in which it has been modified or installed. Access to a
321
+ network may be denied when the modification itself materially and
322
+ adversely affects the operation of the network or violates the rules and
323
+ protocols for communication across the network.
324
+
325
+ Corresponding Source conveyed, and Installation Information provided,
326
+ in accord with this section must be in a format that is publicly
327
+ documented (and with an implementation available to the public in
328
+ source code form), and must require no special password or key for
329
+ unpacking, reading or copying.
330
+
331
+ 7. Additional Terms.
332
+
333
+ "Additional permissions" are terms that supplement the terms of this
334
+ License by making exceptions from one or more of its conditions.
335
+ Additional permissions that are applicable to the entire Program shall
336
+ be treated as though they were included in this License, to the extent
337
+ that they are valid under applicable law. If additional permissions
338
+ apply only to part of the Program, that part may be used separately
339
+ under those permissions, but the entire Program remains governed by
340
+ this License without regard to the additional permissions.
341
+
342
+ When you convey a copy of a covered work, you may at your option
343
+ remove any additional permissions from that copy, or from any part of
344
+ it. (Additional permissions may be written to require their own
345
+ removal in certain cases when you modify the work.) You may place
346
+ additional permissions on material, added by you to a covered work,
347
+ for which you have or can give appropriate copyright permission.
348
+
349
+ Notwithstanding any other provision of this License, for material you
350
+ add to a covered work, you may (if authorized by the copyright holders of
351
+ that material) supplement the terms of this License with terms:
352
+
353
+ a) Disclaiming warranty or limiting liability differently from the
354
+ terms of sections 15 and 16 of this License; or
355
+
356
+ b) Requiring preservation of specified reasonable legal notices or
357
+ author attributions in that material or in the Appropriate Legal
358
+ Notices displayed by works containing it; or
359
+
360
+ c) Prohibiting misrepresentation of the origin of that material, or
361
+ requiring that modified versions of such material be marked in
362
+ reasonable ways as different from the original version; or
363
+
364
+ d) Limiting the use for publicity purposes of names of licensors or
365
+ authors of the material; or
366
+
367
+ e) Declining to grant rights under trademark law for use of some
368
+ trade names, trademarks, or service marks; or
369
+
370
+ f) Requiring indemnification of licensors and authors of that
371
+ material by anyone who conveys the material (or modified versions of
372
+ it) with contractual assumptions of liability to the recipient, for
373
+ any liability that these contractual assumptions directly impose on
374
+ those licensors and authors.
375
+
376
+ All other non-permissive additional terms are considered "further
377
+ restrictions" within the meaning of section 10. If the Program as you
378
+ received it, or any part of it, contains a notice stating that it is
379
+ governed by this License along with a term that is a further
380
+ restriction, you may remove that term. If a license document contains
381
+ a further restriction but permits relicensing or conveying under this
382
+ License, you may add to a covered work material governed by the terms
383
+ of that license document, provided that the further restriction does
384
+ not survive such relicensing or conveying.
385
+
386
+ If you add terms to a covered work in accord with this section, you
387
+ must place, in the relevant source files, a statement of the
388
+ additional terms that apply to those files, or a notice indicating
389
+ where to find the applicable terms.
390
+
391
+ Additional terms, permissive or non-permissive, may be stated in the
392
+ form of a separately written license, or stated as exceptions;
393
+ the above requirements apply either way.
394
+
395
+ 8. Termination.
396
+
397
+ You may not propagate or modify a covered work except as expressly
398
+ provided under this License. Any attempt otherwise to propagate or
399
+ modify it is void, and will automatically terminate your rights under
400
+ this License (including any patent licenses granted under the third
401
+ paragraph of section 11).
402
+
403
+ However, if you cease all violation of this License, then your
404
+ license from a particular copyright holder is reinstated (a)
405
+ provisionally, unless and until the copyright holder explicitly and
406
+ finally terminates your license, and (b) permanently, if the copyright
407
+ holder fails to notify you of the violation by some reasonable means
408
+ prior to 60 days after the cessation.
409
+
410
+ Moreover, your license from a particular copyright holder is
411
+ reinstated permanently if the copyright holder notifies you of the
412
+ violation by some reasonable means, this is the first time you have
413
+ received notice of violation of this License (for any work) from that
414
+ copyright holder, and you cure the violation prior to 30 days after
415
+ your receipt of the notice.
416
+
417
+ Termination of your rights under this section does not terminate the
418
+ licenses of parties who have received copies or rights from you under
419
+ this License. If your rights have been terminated and not permanently
420
+ reinstated, you do not qualify to receive new licenses for the same
421
+ material under section 10.
422
+
423
+ 9. Acceptance Not Required for Having Copies.
424
+
425
+ You are not required to accept this License in order to receive or
426
+ run a copy of the Program. Ancillary propagation of a covered work
427
+ occurring solely as a consequence of using peer-to-peer transmission
428
+ to receive a copy likewise does not require acceptance. However,
429
+ nothing other than this License grants you permission to propagate or
430
+ modify any covered work. These actions infringe copyright if you do
431
+ not accept this License. Therefore, by modifying or propagating a
432
+ covered work, you indicate your acceptance of this License to do so.
433
+
434
+ 10. Automatic Licensing of Downstream Recipients.
435
+
436
+ Each time you convey a covered work, the recipient automatically
437
+ receives a license from the original licensors, to run, modify and
438
+ propagate that work, subject to this License. You are not responsible
439
+ for enforcing compliance by third parties with this License.
440
+
441
+ An "entity transaction" is a transaction transferring control of an
442
+ organization, or substantially all assets of one, or subdividing an
443
+ organization, or merging organizations. If propagation of a covered
444
+ work results from an entity transaction, each party to that
445
+ transaction who receives a copy of the work also receives whatever
446
+ licenses to the work the party's predecessor in interest had or could
447
+ give under the previous paragraph, plus a right to possession of the
448
+ Corresponding Source of the work from the predecessor in interest, if
449
+ the predecessor has it or can get it with reasonable efforts.
450
+
451
+ You may not impose any further restrictions on the exercise of the
452
+ rights granted or affirmed under this License. For example, you may
453
+ not impose a license fee, royalty, or other charge for exercise of
454
+ rights granted under this License, and you may not initiate litigation
455
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
456
+ any patent claim is infringed by making, using, selling, offering for
457
+ sale, or importing the Program or any portion of it.
458
+
459
+ 11. Patents.
460
+
461
+ A "contributor" is a copyright holder who authorizes use under this
462
+ License of the Program or a work on which the Program is based. The
463
+ work thus licensed is called the contributor's "contributor version".
464
+
465
+ A contributor's "essential patent claims" are all patent claims
466
+ owned or controlled by the contributor, whether already acquired or
467
+ hereafter acquired, that would be infringed by some manner, permitted
468
+ by this License, of making, using, or selling its contributor version,
469
+ but do not include claims that would be infringed only as a
470
+ consequence of further modification of the contributor version. For
471
+ purposes of this definition, "control" includes the right to grant
472
+ patent sublicenses in a manner consistent with the requirements of
473
+ this License.
474
+
475
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
476
+ patent license under the contributor's essential patent claims, to
477
+ make, use, sell, offer for sale, import and otherwise run, modify and
478
+ propagate the contents of its contributor version.
479
+
480
+ In the following three paragraphs, a "patent license" is any express
481
+ agreement or commitment, however denominated, not to enforce a patent
482
+ (such as an express permission to practice a patent or covenant not to
483
+ sue for patent infringement). To "grant" such a patent license to a
484
+ party means to make such an agreement or commitment not to enforce a
485
+ patent against the party.
486
+
487
+ If you convey a covered work, knowingly relying on a patent license,
488
+ and the Corresponding Source of the work is not available for anyone
489
+ to copy, free of charge and under the terms of this License, through a
490
+ publicly available network server or other readily accessible means,
491
+ then you must either (1) cause the Corresponding Source to be so
492
+ available, or (2) arrange to deprive yourself of the benefit of the
493
+ patent license for this particular work, or (3) arrange, in a manner
494
+ consistent with the requirements of this License, to extend the patent
495
+ license to downstream recipients. "Knowingly relying" means you have
496
+ actual knowledge that, but for the patent license, your conveying the
497
+ covered work in a country, or your recipient's use of the covered work
498
+ in a country, would infringe one or more identifiable patents in that
499
+ country that you have reason to believe are valid.
500
+
501
+ If, pursuant to or in connection with a single transaction or
502
+ arrangement, you convey, or propagate by procuring conveyance of, a
503
+ covered work, and grant a patent license to some of the parties
504
+ receiving the covered work authorizing them to use, propagate, modify
505
+ or convey a specific copy of the covered work, then the patent license
506
+ you grant is automatically extended to all recipients of the covered
507
+ work and works based on it.
508
+
509
+ A patent license is "discriminatory" if it does not include within
510
+ the scope of its coverage, prohibits the exercise of, or is
511
+ conditioned on the non-exercise of one or more of the rights that are
512
+ specifically granted under this License. You may not convey a covered
513
+ work if you are a party to an arrangement with a third party that is
514
+ in the business of distributing software, under which you make payment
515
+ to the third party based on the extent of your activity of conveying
516
+ the work, and under which the third party grants, to any of the
517
+ parties who would receive the covered work from you, a discriminatory
518
+ patent license (a) in connection with copies of the covered work
519
+ conveyed by you (or copies made from those copies), or (b) primarily
520
+ for and in connection with specific products or compilations that
521
+ contain the covered work, unless you entered into that arrangement,
522
+ or that patent license was granted, prior to 28 March 2007.
523
+
524
+ Nothing in this License shall be construed as excluding or limiting
525
+ any implied license or other defenses to infringement that may
526
+ otherwise be available to you under applicable patent law.
527
+
528
+ 12. No Surrender of Others' Freedom.
529
+
530
+ If conditions are imposed on you (whether by court order, agreement or
531
+ otherwise) that contradict the conditions of this License, they do not
532
+ excuse you from the conditions of this License. If you cannot convey a
533
+ covered work so as to satisfy simultaneously your obligations under this
534
+ License and any other pertinent obligations, then as a consequence you may
535
+ not convey it at all. For example, if you agree to terms that obligate you
536
+ to collect a royalty for further conveying from those to whom you convey
537
+ the Program, the only way you could satisfy both those terms and this
538
+ License would be to refrain entirely from conveying the Program.
539
+
540
+ 13. Remote Network Interaction; Use with the GNU General Public License.
541
+
542
+ Notwithstanding any other provision of this License, if you modify the
543
+ Program, your modified version must prominently offer all users
544
+ interacting with it remotely through a computer network (if your version
545
+ supports such interaction) an opportunity to receive the Corresponding
546
+ Source of your version by providing access to the Corresponding Source
547
+ from a network server at no charge, through some standard or customary
548
+ means of facilitating copying of software. This Corresponding Source
549
+ shall include the Corresponding Source for any work covered by version 3
550
+ of the GNU General Public License that is incorporated pursuant to the
551
+ following paragraph.
552
+
553
+ Notwithstanding any other provision of this License, you have
554
+ permission to link or combine any covered work with a work licensed
555
+ under version 3 of the GNU General Public License into a single
556
+ combined work, and to convey the resulting work. The terms of this
557
+ License will continue to apply to the part which is the covered work,
558
+ but the work with which it is combined will remain governed by version
559
+ 3 of the GNU General Public License.
560
+
561
+ 14. Revised Versions of this License.
562
+
563
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program,
660
+ if necessary. For more information on this, and how to apply and follow the GNU AGPL, see <https://www.gnu.org/licenses/>.
661
+
662
+ If you would like to obtain a copy of the software under a different license (e.g. Apache),
663
+ please send an email to Camb.AI at help@camb.ai indicating that
664
+ you would like a copy of the software under a different license.
README.md ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ![MARS5 Banner](assets/github-banner.png)
2
+
3
+ # MARS5: A novel speech model for insane prosody.
4
+
5
+ This is the repo for the MARS5 English speech model (TTS) from CAMB.AI.
6
+
7
+ The model follows a two-stage AR-NAR pipeline with a distinctively novel NAR component (see more info in the [docs](docs/architecture.md)).
8
+
9
+ With just 5 seconds of audio and a snippet of text, MARS5 can generate speech even for prosodically hard and diverse scenarios like sports commentary, anime and more. Check out our demo:
10
+
11
+
12
+
13
+
14
+ https://github.com/Camb-ai/MARS5-TTS/assets/23717819/3e191508-e03c-4ff9-9b02-d73ae0ebefdd
15
+
16
+
17
+
18
+
19
+ **Quick links**:
20
+ - [CAMB.AI website](https://camb.ai/) (access MARS5 in 140+ languages for TTS and dubbing)
21
+ - Technical docs: [in the docs folder](docs/architecture.md)
22
+ - Colab quickstart: <a target="_blank" href="https://colab.research.google.com/github/Camb-ai/mars5-tts/blob/master/mars5_demo.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
23
+ - Demo page with samples: [here](https://179c54d254f7.ngrok.app/)
24
+
25
+ ![Mars 5 simplified diagram](docs/assets/simplified_diagram.png)
26
+
27
+ **Figure**: the high-level architecture flow of Mars 5. Given text and a reference audio, coarse (L0) encodec speech features are obtained through an autoregressive transformer model. Then, the text, reference, and coarse features are refined in a multinomial DDPM model to produce the remaining encodec codebook values. The output of the DDPM is then vocoded to produce the final audio.
28
+
29
+ Because the model is trained on raw audio together with byte-pair-encoded text, it can be steered with things like punctuation and capitalization.
30
+ E.g. to add a pause, add a comma to that part in the transcript. Or, to emphasize a word, put it in capital letters in the transcript.
31
+ This enables a fairly natural way for guiding the prosody of the generated output.
32
+
33
+ Speaker identity is specified using an audio reference file between 2-12 seconds, with lengths around 6s giving optimal results.
34
+ Further, by providing the transcript of the reference, MARS5 enables one to do a '_deep clone_' which improves the quality of the cloning and output, at the cost of taking a bit longer to produce the audio.
35
+ For more details on this and other performance and model details, please see inside the [docs folder](docs/architecture.md).
36
+
37
+
38
+ ## Quickstart
39
+
40
+
41
+ We use `torch.hub` to make loading the model easy -- no cloning of the repo needed. The steps to perform inference are simple:
42
+
43
+ 1. **Install pip dependencies**: we have 3 inference dependencies only `torch`, `torchaudio`, `librosa`, `vocos`, and `encodec`. Python must be at version 3.10 or greater, and torch must be v2.0 or greater.
44
+
45
+ ```bash
46
+ pip install --upgrade torch torchaudio librosa vocos encodec
47
+ ```
48
+
49
+ 2. **Load models**: load the Mars 5 AR and NAR model from torch hub:
50
+
51
+ ```python
52
+ import torch, librosa
53
+
54
+ mars5, config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True)
55
+ # The `mars5` contains the AR and NAR model, as well as inference code.
56
+ # The `config_class` contains tunable inference config settings like temperature.
57
+ ```
58
+ 3. **Pick a reference** and optionally its transcript:
59
+
60
+ ```python
61
+ # load reference audio between 1-12 seconds.
62
+ wav, sr = librosa.load('<path to arbitrary 24kHz waveform>.wav',
63
+ sr=mars5.sr, mono=True)
64
+ wav = torch.from_numpy(wav)
65
+ ref_transcript = "<transcript of the reference audio>"
66
+ ```
67
+
68
+ The reference transcript is an optional piece of info you need if you wish to do a deep clone.
69
+ Mars5 supports 2 kinds of inference: a shallow, fast inference whereby you do not need the transcript of the reference (we call this a _shallow clone_), and a second slower, but typically higher quality way, which we call a _deep clone_.
70
+ To use the deep clone, you need the prompt transcript. See the [model docs](docs/architecture.md) for more info on this.
71
+
72
+ 4. **Perform the synthesis**:
73
+
74
+ ```python
75
+ # Pick whether you want a deep or shallow clone. Set to False if you don't know prompt transcript or want fast inference. Set to True if you know transcript and want highest quality.
76
+ deep_clone = True
77
+ # Below you can tune other inference settings, like top_k, temperature, top_p, etc...
78
+ cfg = config_class(deep_clone=deep_clone, rep_penalty_window=100,
79
+ top_k=100, temperature=0.7, freq_penalty=3)
80
+
81
+ ar_codes, output_audio = mars5.tts("The quick brown rat.", wav,
82
+ ref_transcript,
83
+ cfg=cfg)
84
+ # output_audio is (T,) shape float tensor corresponding to the 24kHz output audio.
85
+ ```
86
+
87
+ That's it! These default settings provide pretty good results, but feel free to tune the inference settings to optimize the output for your particular example. See the [`InferenceConfig`](inference.py) code or the demo notebook for info and docs on all the different inference settings.
88
+
89
+ _Some tips for best quality:_
90
+ - Make sure reference audio is clean and between 1 second and 12 seconds.
91
+ - Use deep clone and provide an accurate transcript for the reference.
92
+ - Use proper punctuation -- the model can be guided and made better or worse with proper use of punctuation and capitalization.
93
+
94
+
95
+ ## Model details
96
+
97
+ **Checkpoints**
98
+
99
+ The checkpoints for MARS5 are provided under the releases tab of this github repo. We provide two checkpoints:
100
+
101
+ - AR fp16 checkpoint [~750M parameters], along with config embedded in the checkpoint.
102
+ - NAR fp16 checkpoint [~450M parameters], along with config embedded in the checkpoint.
103
+ - The byte-pair encoding tokenizer used for the L0 encodec codes and the English text is embedded in each checkpoint under the `'vocab'` key, and follows roughly the same format of a saved minbpe tokenizer.
104
+
105
+ **Hardware requirements**:
106
+
107
+ You must be able to store at least 750M+450M params on GPU, and do inference with 750M of active parameters. In general, at least **20GB of GPU VRAM** is needed to run the model on GPU (we plan to further optimize this in the future).
108
+
109
+ If you do not have the necessary hardware requirements and just want to use MARS5 in your applications, you can use it via our API: see [docs.camb.ai](https://docs.camb.ai/). If you need some more credits to test it for your use case, feel free to reach out to `help@camb.ai` for help.
110
+
111
+ ## Roadmap
112
+
113
+ Mars 5 is not perfect at the moment, and we are working on a few efforts to improve its quality, stability, and performance.
114
+ Rough areas we are looking to improve, and welcome any contributions:
115
+
116
+ - Improving inference stability and consistency
117
+ - Speed/performance optimizations
118
+ - Improving reference audio selection when given long references.
119
+ - Benchmark performance numbers for Mars 5 on standard speech datasets.
120
+
121
+ If you would like to contribute any improvement to MARS, please feel free to contribute (guidelines below).
122
+
123
+ ## Contributions
124
+
125
+ We welcome any contributions to improving the model. As you may find when experimenting, it can produce really great results, it can still be further improved to create excellent outputs _consistently_.
126
+
127
+ **Contribution format**:
128
+
129
+ The preferred way to contribute to our repo is to fork the [master repository](https://github.com/Camb-ai/mars5-tts) on GitHub:
130
+
131
+ 1. Fork the repo on github
132
+ 2. Clone the repo, set upstream as this repo: `git remote add upstream git@github.com:Camb-ai/mars5-tts.git`
133
+ 3. Make to a new local branch and make your changes, commit changes.
134
+ 4. Push changes to new upstream branch: `git push --set-upstream origin <NAME-NEW-BRANCH>`
135
+ 5. On github, go to your fork and click 'Pull request' to begin the PR process. Please make sure to include a description of what you did/fixed.
136
+
137
+ ## License
138
+
139
+ We are open-sourcing MARS in English under GNU AGPL 3.0, but you can request to use it under a different license by emailing help@camb.ai
140
+
141
+ ## Join our team
142
+
143
+ We're an ambitious team, globally distributed, with a singular aim of making everyone's voice count. At CAMB.AI, we're a research team of Interspeech-published, Carnegie Mellon, ex-Siri engineers and we're looking for you to join our team.
144
+
145
+ We're actively hiring; please drop us an email at ack@camb.ai if you're interested. Visit our [careers page](https://www.camb.ai/careers) for more info.
146
+
147
+ ## Acknowledgements
148
+
149
+ Parts of code for this project are adapted from the following repositories -- please make sure to check them out! Thank you to the authors of:
150
+
151
+ - AWS: For providing much needed compute resources (NVIDIA H100s) to enable training of the model.
152
+ - TransFusion: [https://github.com/RF5/transfusion-asr](https://github.com/RF5/transfusion-asr)
153
+ - Multinomial diffusion: [https://github.com/ehoogeboom/multinomial_diffusion](https://github.com/ehoogeboom/multinomial_diffusion)
154
+ - Mistral-src: [https://github.com/mistralai/mistral-src](https://github.com/mistralai/mistral-src)
155
+ - minbpe: [https://github.com/karpathy/minbpe](https://github.com/karpathy/minbpe)
156
+ - gemelo-ai's encodec Vocos: [https://github.com/gemelo-ai/vocos](https://github.com/gemelo-ai/vocos)
157
+ - librosa for their `.trim()` code: [https://librosa.org/doc/main/generated/librosa.effects.trim.html](https://librosa.org/doc/main/generated/librosa.effects.trim.html)
assets/demo-preview.png ADDED
assets/github-banner.png ADDED
docs/architecture.md ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mars 5 technical details
2
+
3
+ While we do not have the time for a proper full writeup of the details of Mars5, its design, training, and implementation, we at least try give a more detailed overview here of how Mars5 works.
4
+
5
+
6
+ ## hubconf object/api
7
+
8
+
9
+ After loading the model with `torch.hub.load`, two objects are returned, a Mars5TTS, and the dataclass of the inference config to use when calling the `mars5.tts()` method.
10
+ Concretely, the main methods of the mars5 object are:
11
+
12
+ ```python
13
+
14
+ # The init function, called automatically when you initialize the
15
+ # model from torch.hub.load(). If you want, you can pass in your
16
+ # own custom checkpoints here to initalize the model with your
17
+ # own model, tokenizer, etc...
18
+ def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
19
+ # ... initialization code ...
20
+
21
+ # Main text-to-speech function, converting text and a reference
22
+ # audio to speech.
23
+ def tts(self, text: str, ref_audio: Tensor, ref_transcript: str | None,
24
+ cfg: InferenceConfig) -> Tensor:
25
+ """ Perform TTS for `text`, given a reference audio `ref_audio` (of shape [sequence_length,], sampled at 24kHz)
26
+ which has an associated `ref_transcript`. Perform inference using the inference
27
+ config given by `cfg`, which controls the temperature, top_p, etc...
28
+ Returns:
29
+ - `ar_codes`: (seq_len,) long tensor of discrete coarse code outputs from the AR model.
30
+ - `out_wav`: (T,) float output audio tensor sampled at 24kHz.
31
+ """
32
+
33
+ # Utility function to vocode encodec tokens, if one wishes
34
+ # to hear the raw AR model ouput by vocoding the `ar_codes`
35
+ # returned above.
36
+ def vocode(self, tokens: Tensor) -> Tensor:
37
+ """ Vocodes tokens of shape (seq_len, n_q) """
38
+ ```
39
+
40
+
41
+ ## Model design
42
+
43
+ Mars 5 follows a two-stage AR-NAR design according to the diagram on the main page.
44
+
45
+ #### AR component
46
+
47
+ The AR model follows a Mistral-style encoder-only transformer model to predict Encodec L0 codes (the lowest/most coarse level quantization codes).
48
+ Overall, the AR and NAR model is going to predict all 8 codebook entries of the Encodec 6kbps codec.
49
+ The AR model design is given below:
50
+
51
+ ![Mars 5 AR architecture](/docs/assets/mars5_AR_arch.png)
52
+
53
+ **Figure**: autoregressive component of Mars 5. During training, the initial 6kbps encodec tokens of the speech are fed through a small encoder-only transformer, producing a single output vector corresponding to an implicit speaker embedding.
54
+ This vector is concatenated with learnt embeddings corresponding to the text tokens, and L0 speech tokens, after byte-pair encoding tokenization.
55
+
56
+
57
+ The AR model is trained using the standard next-token prediction task of language models with a cross-entropy loss with the next token, given a smaller weight to text tokens.
58
+ During inference, we iteratively sample from the transformer to produce the desiged L0 codes.
59
+ When we use a _shallow clone_, then the reference audio is fed into the transcript to make the implicit speaker embedding used in the input sequence.
60
+ When we use a _deep clone_, the above is done, but we also concatenate the reference transcript with the desired text, and the reference audio tokens with the input sequence before we start sampling the output.
61
+ In pseudocode:
62
+
63
+ ```
64
+ speaker_embedding <- speaker_conditioning_transformer(ref audio)
65
+ if deep_clone:
66
+ prompt = concatenate( speaker embedding, reference text, target text, reference L0 speech codes )
67
+ else:
68
+ prompt = concatenate( speaker embedding, target text )
69
+
70
+ ar output <- autoregressively sample from prompt
71
+ ```
72
+
73
+ While a deep clone provides a more accurate cloning of the reference speaker identity and prosody, it requires knowledge of the reference transcript and takes longer to do inference.
74
+
75
+ #### NAR component
76
+
77
+ After the AR model has predicted the L0 encodec codes, we need a way to predict the remaining 7 codebooks of the 6kbps Encodec codec.
78
+ This is what the NAR model is trained to do, using a multinomial diffusion framework.
79
+ Concretely, the diffusion process is a discrete DDPM, whereby at each timestep in the diffusion process, it takes in a sequence of `(batch size, sequence length, n_codebooks)` and produces an output categorical distribution over each codebook, i.e. an output of shape `(batch size, sequence length, n_codebooks, 1024)`, since each encodec codebook has 1024 possible values.
80
+ The architecture of the model looks as follows:
81
+
82
+
83
+ ![Mars 5 NAR architecture](/docs/assets/mars5_NAR_arch.png)
84
+
85
+ **Figure**: Mars 5 non-autoregressive component. It follows an encoder-decoder transformer architecture, whereby the encoder computes an implicit speaker embedding like the AR model, and concatenates that along with the target to form an input sequence to a transformer encoder. The transformer decoder predicts the distribution of all 8 encodec codebook tokens given a partly noised input at some diffusion timestep `t`.
86
+
87
+
88
+ The encoder and decoder transformers are simple `nn.Transformer` variants with sinusoidal positional embeddings and SwiGLU activations.
89
+ A multinomial diffusion manager controls the forward and reference diffusion processes during inference and training according to a cosine diffusion schedule.
90
+ Diffusion is performed independently of the sequence length or codebook index.
91
+
92
+ During training and inference, the L0 codebooks of the input at timestep $t$ are overridden (i.e. not noised in the forward diffusion process) with either the ground truth L0 codes (during training) or the AR model's predictions (during inference).
93
+ Like the AR model, the NAR model can perform inference in either a _shallow clone_ way or a _deep clone_ way.
94
+ And, like the AR model, the difference between the two is, with a _deep clone_, we concatenate the reference text to the input text sequence, and the reference speech codes (the full values for all 8 codebooks) to the decoder input sequence $x$.
95
+ During inference, we then treat the portion of $x$ corresponding to the reference codec codes, and all the AR L0 codes, as 'fixed' and effectively perform diffusion inpainting for the remaining missing codec codes.
96
+ The figure below explains what the input to the decoder looks like for a deep clone:
97
+
98
+ ![NAR decoder input for deep clone](/docs/assets/NAR_inpainting_diagram.png)
99
+
100
+ This allows us to use diffusion inpainting techniques like [RePaint](https://arxiv.org/abs/2201.09865) to improve the quality of the output at the cost of more inference time.
101
+ We've implemented this in the the diffusion config used in the NAR inference code (see it [here](/mars5/diffuser.py)), and you can simply increase the `jump_len` and `jump_n_sample` to greater than 1 to use RePaint inpainting to improve NAR performance.
102
+
103
+
docs/assets/NAR_inpainting_diagram.png ADDED
docs/assets/example_ref.wav ADDED
Binary file (137 kB). View file
 
docs/assets/intro_vid.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cabbb40186fd5599282b4ada76643b1d1b34c513af1977861513f9d2f1220ad6
3
+ size 2105962
docs/assets/mars5_AR_arch.png ADDED
docs/assets/mars5_NAR_arch.png ADDED
docs/assets/simplified_diagram.png ADDED
hubconf.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dependencies = ['torch', 'torchaudio', 'numpy', 'vocos']
2
+
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ from inference import Mars5TTS, InferenceConfig
8
+
9
+ ar_url = "https://github.com/Camb-ai/mars5-tts/releases/download/v0.1-checkpoints/mars5_en_checkpoints_ar-1680000.pt"
10
+ nar_url = "https://github.com/Camb-ai/mars5-tts/releases/download/v0.1-checkpoints/mars5_en_checkpoints_nar-1260000.pt"
11
+
12
+ def mars5_english(pretrained=True, progress=True, device=None, ar_path=None, nar_path=None) -> Mars5TTS:
13
+ """ Load mars5 english model on `device`, optionally show `progress`. """
14
+ if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+ logging.info(f"Using device: {device}")
16
+ if pretrained == False: raise AssertionError('Only pretrained model currently supported.')
17
+ logging.info("Loading AR checkpoint...")
18
+ if ar_path is None:
19
+ ar_ckpt = torch.hub.load_state_dict_from_url(
20
+ ar_url, progress=progress, check_hash=False, map_location='cpu'
21
+ )
22
+ else: ar_ckpt = torch.load(str(ar_path), map_location='cpu')
23
+
24
+ logging.info("Loading NAR checkpoint...")
25
+ if nar_path is None:
26
+ nar_ckpt = torch.hub.load_state_dict_from_url(
27
+ nar_url, progress=progress, check_hash=False, map_location='cpu'
28
+ )
29
+ else: nar_ckpt = torch.load(str(nar_path), map_location='cpu')
30
+ logging.info("Initializing modules...")
31
+ mars5 = Mars5TTS(ar_ckpt, nar_ckpt, device=device)
32
+ return mars5, InferenceConfig
33
+
inference.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import logging
6
+ import json
7
+ from typing import Optional
8
+ from pathlib import Path
9
+ from dataclasses import dataclass
10
+ import os
11
+
12
+ from mars5.model import CodecLM, ResidualTransformer
13
+ from vocos import Vocos
14
+ from encodec import EncodecModel
15
+ from mars5.diffuser import MultinomialDiffusion, DSH, perform_simple_inference
16
+ from mars5.minbpe.regex import RegexTokenizer, GPT4_SPLIT_PATTERN
17
+ from mars5.minbpe.codebook import CodebookTokenizer
18
+ from mars5.ar_generate import ar_generate
19
+ from mars5.utils import nuke_weight_norm
20
+ from mars5.trim import trim
21
+ import tempfile
22
+ import logging
23
+
24
+
25
+ @dataclass
26
+ class InferenceConfig():
27
+ """ The defaults configuration variables for TTS inference. """
28
+
29
+ ## >>>> AR CONFIG
30
+ temperature: float = 0.7
31
+ top_k: int = 200 # 0 disables it
32
+ top_p: float = 0.2
33
+ typical_p: float = 1.0
34
+ freq_penalty: float = 3
35
+ presence_penalty: float = 0.4
36
+ rep_penalty_window: int = 80 # how far in the past to consider when penalizing repetitions. Equates to 5s
37
+
38
+ eos_penalty_decay: float = 0.5 # how much to penalize <eos>
39
+ eos_penalty_factor: float = 1 # overal penalty weighting
40
+ eos_estimated_gen_length_factor: float = 1.0 # multiple of len(text_phones) to assume an approximate output length is
41
+
42
+ ## >>>> NAR CONFIG
43
+ # defaults, that can be overridden with user specified inputs
44
+ timesteps: int = 200
45
+ x_0_temp: float = 0.7
46
+ q0_override_steps: int = 20 # number of diffusion steps where NAR L0 predictions overrides AR L0 predictions.
47
+ nar_guidance_w: float = 3
48
+
49
+ max_prompt_dur: float = 12 # maximum length prompt is allowed, in seconds.
50
+
51
+ # Maximum AR codes to generate in 1 inference.
52
+ # Default of -1 leaves it same as training time max AR tokens.
53
+ # Typical values up to ~2x training time can be tolerated,
54
+ # with ~1.5x trianing time tokens having still mostly ok performance.
55
+ generate_max_len_override: int = -1
56
+
57
+ # Whether to deep clone from the reference.
58
+ # Pros: improves intelligibility and speaker cloning performance.
59
+ # Cons: requires reference transcript, and inference takes a bit longer.
60
+ deep_clone: bool = True
61
+
62
+ use_kv_cache: bool = True
63
+ trim_db: float = 27
64
+ beam_width: int = 1 # only beam width of 1 is currently supported
65
+ ref_audio_pad: float = 0
66
+
67
+
68
+ class Mars5TTS(nn.Module):
69
+
70
+ def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
71
+ super().__init__()
72
+
73
+ if device is None:
74
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
75
+ self.device = torch.device(device)
76
+
77
+ self.codec = EncodecModel.encodec_model_24khz().to(device).eval()
78
+ self.codec.set_target_bandwidth(6.0)
79
+
80
+ # save and load text tokenize
81
+ self.texttok = RegexTokenizer(GPT4_SPLIT_PATTERN)
82
+ tfn = tempfile.mkstemp(suffix='texttok.model')[1]
83
+ Path(tfn).write_text(ar_ckpt['vocab']['texttok.model'])
84
+ self.texttok.load(tfn)
85
+ os.remove(tfn)
86
+ # save and load speech tokenizer
87
+ sfn = tempfile.mkstemp(suffix='speechtok.model')[1]
88
+ self.speechtok = CodebookTokenizer(GPT4_SPLIT_PATTERN)
89
+ Path(sfn).write_text(ar_ckpt['vocab']['speechtok.model'])
90
+ self.speechtok.load(sfn)
91
+ os.remove(sfn)
92
+ # keep track of tokenization things.
93
+ self.n_vocab = len(self.texttok.vocab) + len(self.speechtok.vocab)
94
+ self.n_text_vocab = len(self.texttok.vocab) + 1
95
+ self.diffusion_n_classes: int = 1025 # 1 for padding idx
96
+ # load AR model
97
+ self.codeclm = CodecLM(n_vocab=self.n_vocab, dim=1536, dim_ff_scale=7/3)
98
+ self.codeclm.load_state_dict(ar_ckpt['model'])
99
+ self.codeclm = self.codeclm.to(self.device).eval()
100
+ # load NAR model
101
+ self.codecnar = ResidualTransformer(n_text_vocab=self.n_text_vocab, n_quant=self.diffusion_n_classes,
102
+ p_cond_drop=0, dropout=0)
103
+ self.codecnar.load_state_dict(nar_ckpt['model'])
104
+ self.codecnar = self.codecnar.to(self.device).eval()
105
+ self.default_T = 200
106
+
107
+ self.sr = 24000
108
+ self.latent_sr = 75
109
+
110
+ # load vocoder
111
+ self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device).eval()
112
+ nuke_weight_norm(self.codec)
113
+ nuke_weight_norm(self.vocos)
114
+
115
+ @torch.inference_mode
116
+ def vocode(self, tokens: Tensor) -> Tensor:
117
+ """ Vocodes tokens of shape (seq_len, n_q) """
118
+ tokens = tokens.T.to(self.device)
119
+ features = self.vocos.codes_to_features(tokens)
120
+ # A cool hidden feature of vocos vocoding:
121
+ # setting the bandwidth below to 1 (corresponding to 3 kbps)
122
+ # actually still works on 6kbps input tokens, but *smooths* the output
123
+ # audio a bit, which can help improve quality if its a bit noisy.
124
+ # Hence we use [1] and not [2] below.
125
+ bandwidth_id = torch.tensor([1], device=self.device) # 6 kbps
126
+ wav_diffusion = self.vocos.decode(features, bandwidth_id=bandwidth_id)
127
+ return wav_diffusion.cpu().squeeze()[None]
128
+
129
+ @torch.inference_mode
130
+ def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
131
+ cfg: Optional[InferenceConfig] = InferenceConfig()) -> Tensor:
132
+ """ Perform TTS for `text`, given a reference audio `ref_audio` (of shape [sequence_length,], sampled at 24kHz)
133
+ which has an associated `ref_transcript`. Perform inference using the inference
134
+ config given by `cfg`, which controls the temperature, top_p, etc...
135
+ Returns:
136
+ - `ar_codes`: (seq_len,) long tensor of discrete coarse code outputs from the AR model.
137
+ - `out_wav`: (T,) float output audio tensor sampled at 24kHz.
138
+ """
139
+
140
+ if cfg.deep_clone and ref_transcript is None:
141
+ raise AssertionError(
142
+ ("Inference config deep clone is set to true, but reference transcript not specified! "
143
+ "Please specify the transcript of the prompt, or set deep_clone=False in the inference `cfg` argument."
144
+ ))
145
+ ref_dur = ref_audio.shape[-1]/self.sr
146
+ if ref_dur > cfg.max_prompt_dur:
147
+ logging.warning((f"Reference audio duration is {ref_dur:.2f} > max suggested ref audio. "
148
+ f"Expect quality degradations. We recommend you trim prompt to be shorter than max prompt length."))
149
+
150
+ # get text codes.
151
+ text_tokens = self.texttok.encode("<|startoftext|>"+text.strip()+"<|endoftext|>",
152
+ allowed_special='all')
153
+
154
+ text_tokens_full = self.texttok.encode("<|startoftext|>"+ ref_transcript + ' ' + str(text).strip()+"<|endoftext|>",
155
+ allowed_special='all')
156
+
157
+ if ref_audio.dim() == 1: ref_audio = ref_audio[None]
158
+ if ref_audio.shape[0] != 1: ref_audio = ref_audio.mean(dim=0, keepdim=True)
159
+ ref_audio = F.pad(ref_audio, (int(self.sr*cfg.ref_audio_pad), 0))
160
+ # get reference audio codec tokens
161
+ prompt_codec = self.codec.encode(ref_audio[None].to(self.device))[0][0] # (bs, n_q, seq_len)
162
+
163
+ n_speech_inp = 0
164
+ n_start_skip = 0
165
+ q0_str = ' '.join([str(t) for t in prompt_codec[0, 0].tolist()])
166
+ # Note, in the below, we do NOT want to encode the <eos> token as a part of it, since we will be continuing it!!!
167
+ speech_tokens = self.speechtok.encode(q0_str.strip()) # + "<|endofspeech|>", allowed_special='all')
168
+ spk_ref_codec = prompt_codec[0, :, :].T # (seq_len, n_q)
169
+
170
+ raw_prompt_acoustic_len = len(prompt_codec[0,0].squeeze())
171
+ offset_speech_codes = [p+len(self.texttok.vocab) for p in speech_tokens]
172
+ if not cfg.deep_clone:
173
+ # shallow clone, so
174
+ # 1. clip existing speech codes to be empty (n_speech_inp = 0)
175
+ offset_speech_codes = offset_speech_codes[:n_speech_inp]
176
+ else:
177
+ # Deep clone, so
178
+ # 1. set text to be text of prompt + target text
179
+ text_tokens = text_tokens_full
180
+ # 2. update n_speech_inp to be length of prompt, so we only display from ths `n_speech_inp` onwards in the final output.
181
+ n_speech_inp = len(offset_speech_codes)
182
+ prompt = torch.tensor(text_tokens + offset_speech_codes, dtype=torch.long, device=self.device)
183
+ first_codec_idx = prompt.shape[-1] - n_speech_inp + 1
184
+
185
+ # ---> perform AR code generation
186
+
187
+ logging.debug(f"Raw acoustic prompt length: {raw_prompt_acoustic_len}")
188
+
189
+ ar_codes = ar_generate(self.texttok, self.speechtok, self.codeclm,
190
+ prompt, spk_ref_codec, first_codec_idx,
191
+ max_len=cfg.generate_max_len_override if cfg.generate_max_len_override > 1 else 2000,
192
+ temperature=cfg.temperature, topk=cfg.top_k, top_p=cfg.top_p, typical_p=cfg.typical_p,
193
+ alpha_frequency=cfg.freq_penalty, alpha_presence=cfg.presence_penalty, penalty_window=cfg.rep_penalty_window,
194
+ eos_penalty_decay=cfg.eos_penalty_decay, eos_penalty_factor=cfg.eos_penalty_factor,
195
+ beam_width=cfg.beam_width, beam_length_penalty=1,
196
+ n_phones_gen=round(cfg.eos_estimated_gen_length_factor*len(text)),
197
+ vocode=False, use_kv_cache=cfg.use_kv_cache)
198
+
199
+ # Parse AR output
200
+ output_tokens = ar_codes - len(self.texttok.vocab)
201
+ output_tokens = output_tokens.clamp(min=0).squeeze()[first_codec_idx:].cpu().tolist()
202
+ gen_codes_decoded = self.speechtok.decode_int(output_tokens)
203
+ gen_codes_decoded = torch.tensor([s for s in gen_codes_decoded if type(s) == int], dtype=torch.long, device=self.device)
204
+
205
+ c_text = torch.tensor(text_tokens, dtype=torch.long, device=self.device)[None]
206
+ c_codes = prompt_codec.permute(0, 2, 1)
207
+ c_texts_lengths = torch.tensor([len(text_tokens)], dtype=torch.long, device=self.device)
208
+ c_codes_lengths = torch.tensor([c_codes.shape[1],], dtype=torch.long, device=self.device)
209
+
210
+ _x = gen_codes_decoded[None, n_start_skip:, None].repeat(1, 1, 8) # (seq_len) -> (1, seq_len, 8)
211
+ x_padding_mask = torch.zeros((1, _x.shape[1]), dtype=torch.bool, device=_x.device)
212
+
213
+ # ---> perform DDPM NAR inference
214
+
215
+ T = self.default_T
216
+ diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=T, device=self.device)
217
+
218
+ dsh_cfg = DSH(last_greedy=True, x_0_temp=cfg.x_0_temp,
219
+ guidance_w=cfg.nar_guidance_w,
220
+ deep_clone=cfg.deep_clone, jump_len=1, jump_n_sample=1,
221
+ q0_override_steps=cfg.q0_override_steps,
222
+ enable_kevin_scaled_inference=True, # see TransFusion ASR for explanation of this
223
+ progress=False)
224
+
225
+ final_output = perform_simple_inference(self.codecnar,(
226
+ c_text, c_codes, c_texts_lengths, c_codes_lengths, _x, x_padding_mask
227
+ ), diff, diff.num_timesteps, torch.float16, dsh=dsh_cfg, retain_quant0=True) # (bs, seq_len, n_quant)
228
+
229
+ skip_front = raw_prompt_acoustic_len if cfg.deep_clone else 0
230
+ final_output = final_output[0, skip_front:].to(self.device) # (seq_len, n_quant)
231
+
232
+ # vocode final output and trim silences
233
+ final_audio = self.vocode(final_output).squeeze()
234
+ final_audio, _ = trim(final_audio.cpu(), top_db=cfg.trim_db)
235
+
236
+ return gen_codes_decoded, final_audio
mars5/ar_generate.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torchaudio
4
+ import copy
5
+ from torch import Tensor, nn
6
+ import logging
7
+ from .model import length_to_mask
8
+ from .samplers import (apply_typical_p, early_eos_penalty,
9
+ top_k_top_p_filtering, freq_rep_penalty)
10
+ from .nn_future import RotatingBufferCache
11
+ from .minbpe.codebook import CodebookTokenizer
12
+ from .minbpe.regex import RegexTokenizer
13
+
14
+
15
+ @torch.inference_mode()
16
+ def ar_generate(texttok: RegexTokenizer, speechtok: CodebookTokenizer,
17
+ codeclm: nn.Module, xx: Tensor, ss_gen: Tensor, first_codex_idx: int,
18
+ max_len: int = 1500, fp16: bool = True, temperature: float = 1.0, topk: int = None,
19
+ top_p=1.0, alpha_frequency=0, alpha_presence=0, penalty_window=100,
20
+ typical_p=1.0, eos_penalty_factor=1.0, eos_penalty_decay=0, n_phones_gen=None, vocode=True,
21
+ beam_width: int = 1, beam_length_penalty=2, use_kv_cache: bool = True) -> tuple[Tensor, Tensor]:
22
+ """ Use the `codeclm` language model to autoregressively generate a completion of `xx` (seq_len), where the first `first_codex_idx`-1
23
+ indices correspond to the input phones. The output generation is limited to at most `max_len` (measured as num latent codes).
24
+ Returns both output first quantizer codes and synthesized audio using `codec`. Use decoding with `beam_width` to keep
25
+ track of top `beam_width` outcomes, selecting the top one among them.
26
+
27
+ - Optionally vocode if `vocode` (default True).
28
+ - See `InferenceConfig` for other inference docs.
29
+ """
30
+ assert xx.dim() == 1, "Only batch size of 1 is currently supported."
31
+ assert beam_width == 1, "Only beam size of 1 is currently supported."
32
+ # internally our batch size will be the beam width
33
+ bs = beam_width
34
+ x_inp = xx[None].repeat(bs, 1) # (bs, seq_len)
35
+ ss_gen = ss_gen[None].repeat(bs, 1, 1)
36
+ # We must subtract 1 in the line below so that we match the train-time conditions of having a
37
+ # False padding value for the <bos> token position. This is needed so that we correctly use the
38
+ # _acoustic_ and not the linguistic language embedding for the <bos> token.
39
+ offsets = torch.tensor([first_codex_idx - 1 for _ in range(bs)], dtype=torch.long, device=xx.device)
40
+ valid_logit_idx_start = len(texttok.vocab) # vocab['s2i']['quant0-0000']
41
+ valid_logit_idx_end = len(texttok.vocab) + len(speechtok.vocab) + 1 # vocab['s2i']['quant1-0000']
42
+ # Make mask that is True where we have valid outputs, False otherwise (where we have text outputs).
43
+ # logit_mask = torch.zeros(n_vocab, dtype=bool, device=x_inp.device)
44
+ # logit_mask[valid_logit_idx_start:valid_logit_idx_end] = True
45
+ # logit_mask[vocab['s2i']['<eos>']] = True
46
+ cum_logprobs = torch.zeros(bs, dtype=torch.float, device=x_inp.device)
47
+ eos_idx = len(texttok.vocab) + speechtok.special_tokens['<|endofspeech|>']
48
+ n_vocab = len(texttok.vocab) + len(speechtok.vocab)
49
+
50
+ logging.info(f"Starting beam decoding with beam_width={beam_width}")
51
+
52
+ prev_ids = [[] for _ in range(bs)]
53
+
54
+ cache = None
55
+ if use_kv_cache:
56
+ # Initialise kv cache
57
+ cache_window = min(codeclm.ar.args.sliding_window, x_inp.shape[-1] + max_len)
58
+ cache = RotatingBufferCache(codeclm.ar.args.n_layers, bs, cache_window, codeclm.ar.args.n_kv_heads, codeclm.ar.args.head_dim)
59
+ cache.to(device=x_inp.device, dtype=torch.float16)
60
+
61
+ counter = 0
62
+ while x_inp.shape[-1] < max_len:
63
+ counter += 1
64
+ gen_length = torch.tensor([x_inp.shape[-1] for _ in range(bs)], dtype=torch.long, device=xx.device)
65
+ padding_mask = length_to_mask(gen_length, offsets)
66
+
67
+ with torch.autocast('cuda', enabled=fp16):
68
+ logits: Tensor = codeclm(x_inp, padding_mask, spk_reference=ss_gen, cache=cache, counter=counter)
69
+ logits = logits.float()
70
+
71
+ logits = logits[:, -1] # select last index, now (bs, logit_dim)
72
+
73
+ # <---------------------- logit filtering ---------------------->
74
+ filtered_logits = logits.clone()
75
+
76
+ # apply repetition penalty before logit mask if any item in the beam has more than 1 prior token.
77
+ if len(prev_ids[0]) > 1:
78
+ filtered_logits = freq_rep_penalty(filtered_logits, previous=torch.tensor(prev_ids, dtype=torch.long),
79
+ alpha_frequency=alpha_frequency, alpha_presence=alpha_presence,
80
+ penalty_window=penalty_window)
81
+
82
+ filtered_logits[..., :valid_logit_idx_start-1] = float('-inf')
83
+ filtered_logits[..., valid_logit_idx_end:] = float('-inf')
84
+
85
+ if n_phones_gen is not None:
86
+ # apply eos penalty
87
+ filtered_logits = early_eos_penalty(filtered_logits, len(prev_ids[0]), n_phones_gen,
88
+ eos_penalty_decay, eos_penalty_factor,
89
+ eos_index=eos_idx)
90
+
91
+ filtered_logits = filtered_logits / temperature
92
+ filtered_logits = top_k_top_p_filtering(filtered_logits, top_k=topk, top_p=top_p)
93
+ filtered_logits = apply_typical_p(filtered_logits, mass=typical_p)
94
+
95
+ # mask out anything that isn't first quantizer output codes
96
+ filtered_logits[..., :valid_logit_idx_start-1] = float('-inf')
97
+ filtered_logits[..., valid_logit_idx_end:] = float('-inf')
98
+ logits = filtered_logits
99
+
100
+ # <---------------------- next frame prediction --------------------->
101
+
102
+ logprobs = logits.log_softmax(dim=-1)
103
+
104
+ # update assignments: if any beam ended in <eos> last step, it MUST also end in <eos> this step.
105
+ # so, below we multiply the logits with a True/False mask, setting to
106
+ for j in range(bs):
107
+ if x_inp[j, -1] == eos_idx:
108
+ # do not add any additional probability to it, keeping it the same for all vocab idxs
109
+ logprobs[j] = float('-inf') # zero probability of anything non-eos after 1 eos
110
+ logprobs[j, eos_idx] = 0 # probability=1 of <eos> after <eos>
111
+
112
+ candidate_cum_logprobs = cum_logprobs[:, None] + logprobs # (bs, 1) + (bs, vocab) -> (bs, vocab)
113
+
114
+ logp_flat = logprobs.flatten()
115
+ candidates = torch.multinomial(logp_flat.exp(), num_samples=beam_width, replacement=False) # (bs,)
116
+ # Ravel it up:
117
+ beam_idxs = candidates // n_vocab # (bs,)
118
+ tok_inds_in_each_beam = candidates % n_vocab # (bs,)
119
+
120
+ # check for breaks
121
+ if torch.all(tok_inds_in_each_beam == eos_idx):
122
+ # apply length penalty:
123
+ non_eos_toks = (x_inp != eos_idx).sum(dim=-1) # (bs,) number of non eos toks
124
+ gen_length = non_eos_toks - first_codex_idx
125
+ penalties = (gen_length**beam_length_penalty)
126
+ penalized_cum_tok_logp = candidate_cum_logprobs / penalties[:, None]
127
+
128
+ eos_avg_logps = penalized_cum_tok_logp[:, eos_idx]
129
+ best_beam_idx = eos_avg_logps.argmax()
130
+ best_avg_logp = eos_avg_logps[best_beam_idx]
131
+ best_beam = x_inp[best_beam_idx]
132
+ logging.info((f"best beam = {best_beam_idx} @ penalized_cum_tok_logp = {best_avg_logp.item():.3f} |\n num toks: {non_eos_toks.cpu().tolist()}. "
133
+ f"Candidates: {eos_avg_logps.cpu()} |\n non-eos toks: {non_eos_toks.cpu().tolist()} |\n penalties: {penalties.cpu().tolist()} | "
134
+ f"raw cumulative probs: {candidate_cum_logprobs[:, eos_idx].cpu().tolist()}"))
135
+ break
136
+
137
+ # update beam histories:
138
+ x_inp = x_inp[beam_idxs]
139
+ # update next token
140
+ next_sample = tok_inds_in_each_beam
141
+ # update cum logprob
142
+ cum_logprobs = cum_logprobs[beam_idxs] + logprobs[beam_idxs, tok_inds_in_each_beam]
143
+ # update prior inds to point to correct beam
144
+ prev_ids = [copy.deepcopy(prev_ids[beam_idx.item()]) for beam_idx in beam_idxs]
145
+ # add new tokens to previous ids
146
+ for j in range(bs):
147
+ prev_ids[j].append(tok_inds_in_each_beam[j].item())
148
+
149
+ logging.debug("L%d | next sample: %s | beam: %s | cum_logp: %s", len(x_inp[0]), next_sample.cpu().tolist(), beam_idxs.cpu().tolist(), cum_logprobs.cpu())
150
+
151
+ # update cache with beam indexes
152
+ if cache is not None:
153
+ cache.cache_k = cache.cache_k[:, beam_idxs]
154
+ cache.cache_v = cache.cache_v[:, beam_idxs]
155
+
156
+ # add 1 None below to make (bs,) -> (bs, 1) so we can concat along seq len dim.
157
+ x_inp = torch.cat([x_inp, next_sample[:, None]], dim=-1)
158
+
159
+
160
+ if x_inp.shape[-1] >= max_len - 1:
161
+ logging.warning(f"[autoregressive generation] output length = {x_inp.shape[-1]} -- inference likely failed or input too long!")
162
+ best_beam = x_inp[0]
163
+
164
+ if not vocode: return best_beam # (seq_len,)
165
+ else: raise AssertionError()
mars5/diffuser.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Discrete multinomial diffusion code adapted from https://github.com/RF5/transfusion-asr,
3
+ which in turn is adapted from https://github.com/ehoogeboom/multinomial_diffusion.
4
+
5
+ Please see the original repo (https://github.com/ehoogeboom/multinomial_diffusion) and paper for full
6
+ details on how multinomial diffusion works -- thanks to the original authors!
7
+ """
8
+
9
+ import torch
10
+ from torch import Tensor
11
+ from torch.functional import F
12
+ import numpy as np
13
+ from dataclasses import dataclass
14
+ from typing import Union
15
+
16
+ # -------------- Multinomial utility functions -----------
17
+
18
+ MIN_LOG_ARG = 1e-7 # originally was 1e-40
19
+
20
+ def log_1_min_a(a): return torch.log((1 - a.exp()).clamp_(min=1e-30))
21
+
22
+ def log_add_exp(a, b):
23
+ maximum = torch.max(a, b)
24
+ return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))
25
+
26
+ def extract(a: Tensor, t, x_shape):
27
+ """ Given 1D vector of alpha/alpha_cum/betas, get index at `t` of shape (bs,), and then
28
+ broadcast it to number of dims in `x_shape`.
29
+ """
30
+ b, *_ = t.shape
31
+ out = a.gather(-1, t)
32
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
33
+
34
+ def index_to_log_onehot(x, num_classes, dim=-1, dtype=torch.float32):
35
+ """ Convert indices `x` (bs, ...) to approx one-hot log-probs of shape (bs, ..., num_classes) """
36
+ assert x.max().item() < num_classes, \
37
+ f'Error: {x.max().item()} >= {num_classes}'
38
+ x_onehot = F.one_hot(x, num_classes)
39
+ if dim == 1:
40
+ permute_order = (0, -1) + tuple(range(1, len(x.size())))
41
+ x_onehot = x_onehot.permute(permute_order)
42
+ else:
43
+ pass
44
+
45
+ log_x = torch.log(x_onehot.to(dtype).clamp(min=MIN_LOG_ARG)) # so min(log_x) will be -30
46
+
47
+ return log_x
48
+
49
+ def sum_except_batch(x: Tensor, num_dims=1) -> Tensor:
50
+ '''
51
+ Sums all dimensions except the first.
52
+ Args:
53
+ x: Tensor, shape (batch_size, ...)
54
+ num_dims: int, number of batch dims (default=1)
55
+ Returns:
56
+ x_sum: Tensor, shape (batch_size,)
57
+ '''
58
+ return x.reshape(*x.shape[:num_dims], -1).sum(-1)
59
+
60
+ # -------------- Multinomial diffusion class -------------
61
+
62
+ class MultinomialDiffusion():
63
+ def __init__(self, num_classes, timesteps=100, diffusion_s=0.008,
64
+ loss_type='vb_stochastic', parametrization='x0',
65
+ dtype=torch.float32,
66
+ device='cpu'):
67
+ super(MultinomialDiffusion, self).__init__()
68
+ assert loss_type in ('vb_stochastic',)
69
+ assert parametrization in ('x0', 'direct')
70
+
71
+ self.num_classes = num_classes
72
+ self.loss_type = loss_type
73
+ self.num_timesteps = timesteps
74
+ self.parametrization = parametrization
75
+
76
+ alphas = self.cosine_beta_schedule(timesteps, diffusion_s)
77
+
78
+ alphas = alphas.to(torch.float64)
79
+ log_alpha = alphas.log()
80
+ log_cumprod_alpha = torch.cumsum(log_alpha, dim=-1)
81
+
82
+ log_1_min_alpha = log_1_min_a(log_alpha) # = log(betas)
83
+
84
+ log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha) # = log(1- \bar{a})
85
+ a = log_add_exp(log_alpha, log_1_min_alpha) # log(1-beta + beta) = log(1) = 0
86
+
87
+ assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
88
+ assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
89
+ assert (torch.cumsum(log_alpha, dim=-1) - log_cumprod_alpha).abs().sum().item() < 1.e-5
90
+
91
+ # Convert to float32 and register buffers.
92
+ self.log_alpha = log_alpha.to(dtype).to(device)
93
+ self.log_1_min_alpha = log_1_min_alpha.to(dtype).to(device)
94
+ self.log_cumprod_alpha = log_cumprod_alpha.to(dtype).to(device)
95
+ self.log_1_min_cumprod_alpha = log_1_min_cumprod_alpha.to(dtype).to(device)
96
+
97
+ @staticmethod
98
+ def cosine_beta_schedule(timesteps, s=0.008) -> Tensor:
99
+ """
100
+ cosine schedule as proposed in https://arxiv.org/abs/2102.09672 .
101
+ Returns alpha parameters, NOT Beta
102
+ """
103
+ steps = timesteps + 1
104
+ x = torch.linspace(0, timesteps, steps)
105
+ alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
106
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
107
+ alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
108
+ alphas = torch.clamp(alphas, 0.001, 1.0)
109
+ return torch.sqrt(alphas)
110
+
111
+ def multinomial_kl(self, log_prob1: Tensor, log_prob2: Tensor, dim=-1) -> Tensor:
112
+ """ Get KL divergence between two categorical distributions specified with `log_prob1` and `log_prob2`.
113
+ Assumed probability dim is `dim` (i.e. log_prob1.exp().sum(dim=`dim`) should be tensor of ones)
114
+ """
115
+ kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=dim)
116
+ return kl
117
+
118
+ def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor:
119
+ """ Compute q(x_t | x_{t-1}) = C(x_t | alpha_t * x_{t-1} + (1-alpha_t)/K in the log-domain
120
+ given `log_x_t` as log one-hot encoding of x_t.
121
+
122
+ Recall due to symmetry property we can compute
123
+ this value using x_t instead of x_{t-1} (se appendix A of https://arxiv.org/pdf/2102.05379.pdf)
124
+ """
125
+ dt = log_x_t.dtype
126
+ log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt)
127
+ log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt)
128
+
129
+ # alpha_t * E[xt] + (1 - alpha_t) 1 / K
130
+ log_probs = log_add_exp(
131
+ log_x_t + log_alpha_t,
132
+ log_1_min_alpha_t - np.log(self.num_classes)
133
+ )
134
+ return log_probs
135
+
136
+ def q_pred_one_timestep_scaled(self, log_x_t: Tensor, t: Tensor, c: int, jump_len: int) -> Tensor:
137
+ """ Compute q(x_t | x_{t-1}) = C(x_t | alpha_t * x_{t-1} + (1-alpha_t)/K in the log-domain
138
+ given `log_x_t` as log one-hot encoding of x_t.
139
+
140
+ Recall due to symmetry property we can compute
141
+ this value using x_t instead of x_{t-1} (se appendix A of https://arxiv.org/pdf/2102.05379.pdf)
142
+ """
143
+ dt = log_x_t.dtype
144
+ log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt)
145
+ log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt)
146
+
147
+ # Magic
148
+ xax = torch.arange(0,log_x_t.shape[1],1).to(log_x_t.device)
149
+ aa=log_x_t.shape[1]*(c/jump_len)
150
+ sig = 1/(1+torch.exp(-(xax-aa+20)/8))
151
+ log_alpha_t = (torch.log(1/sig)[None,:,None] + log_alpha_t).clamp(-torch.inf, 0)
152
+ log_1_min_alpha_t = torch.log(sig)[None,:,None] + log_1_min_alpha_t
153
+
154
+ # alpha_t * E[xt] + (1 - alpha_t) 1 / K
155
+ log_probs = log_add_exp(
156
+ log_x_t + log_alpha_t,
157
+ log_1_min_alpha_t - np.log(self.num_classes)
158
+ )
159
+ return log_probs
160
+
161
+ def q_pred(self, log_x_start: Tensor, t) -> Tensor:
162
+ """ Compute q(x_t | x_0) = C(x_t | bar{alpha}_t * x_0 + (1 - bar{alpha}_t)/K ) in log domain,
163
+ given `log_x_start` of log probs of x_0.
164
+ """
165
+ dt = log_x_start.dtype
166
+ log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape).to(dt)
167
+ log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape).to(dt)
168
+
169
+ log_probs = log_add_exp(
170
+ log_x_start + log_cumprod_alpha_t,
171
+ log_1_min_cumprod_alpha - np.log(self.num_classes)
172
+ )
173
+
174
+ return log_probs
175
+
176
+ def q_posterior(self, log_x_start, log_x_t, t):
177
+ """ Compute `q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)`
178
+ where q(xt | xt-1, x0) = q(xt | xt-1).
179
+ """
180
+ # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
181
+ # where q(xt | xt-1, x0) = q(xt | xt-1).
182
+
183
+ t_minus_1 = t - 1
184
+ # Remove negative values, will not be used anyway for final decoder
185
+ t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
186
+ log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1) # log( q(x_{t-1} | x_0) )
187
+ # if t == 0, then log( q(x_0 | x_0) ) = log( one_hot(x_0) ), not even random at that point.
188
+ # so, where t == 0
189
+ num_axes = (1,) * (len(log_x_start.size()) - 1)
190
+ t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start) # broadcast to non-batch axes
191
+ log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)
192
+ # where it is zero, replace
193
+ # with log one-hot encoding of x0.
194
+
195
+ # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
196
+ # Not very easy to see why this is true. But it is :)
197
+ # log_EV_qxtmin_x0 ~ q(x_{t-1} | x_0)
198
+ # q_pred_one_timestep(log_x_t, t) ~ q(x_t | x_{t-1}) (which due to symmetry can be computed using x_t)
199
+ unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t) # numerator of bayes
200
+
201
+ # approximate denominator with just a normalizing sum.
202
+ log_EV_xtmin_given_xt_given_xstart = \
203
+ unnormed_logprobs \
204
+ - torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True)
205
+
206
+ return log_EV_xtmin_given_xt_given_xstart
207
+
208
+ def p_pred(self, log_x_t, t, log_x0_pred):
209
+ """ Predict `p(x_{t-1} | x_t)` using `q(xt-1 | xt, hat{x0})`, where `hat{x0}` is given by
210
+ log probabilities from model as `log_x0_pred` (bs, ...., K) and x_t is given by
211
+ `log_x_t` of shape `(bs, ..., K)`
212
+ """
213
+ # log_x_recon = self.predict_start(log_x, t=t) # model itself predicts x_0
214
+ # log_x0_pred
215
+ log_model_pred = self.q_posterior(
216
+ log_x_start=log_x0_pred, log_x_t=log_x_t, t=t)
217
+ return log_model_pred
218
+
219
+ def log_sample_categorical(self, logprobs: Tensor, dim=-1) -> Tensor:
220
+ """ Sample from categorical `logprobs` (bs, ..., probs), where position of probs is specified
221
+ by `dim`.
222
+
223
+ Returns sampled long indices of shape `(bs, ...)`
224
+ """
225
+ uniform = torch.rand_like(logprobs)
226
+ gumbel_noise = -torch.log( (-torch.log(uniform.clamp_(min=MIN_LOG_ARG)) ).clamp_(min=MIN_LOG_ARG))
227
+ sample = (gumbel_noise + logprobs).argmax(dim=dim)
228
+ return sample
229
+
230
+ def q_sample(self, log_x_start, t):
231
+ """ Draw `x_t` ~ q(x_t | x_0) . `log_x_start` is of shape `(bs, ..., K)`, returns result of same shape """
232
+ log_EV_qxt_x0 = self.q_pred(log_x_start, t)
233
+ sample = self.log_sample_categorical(log_EV_qxt_x0)
234
+ # log_sample = index_to_log_onehot(sample, self.num_classes)
235
+
236
+ return sample #log_sample
237
+
238
+ def compute_Lt(self, log_x_start: Tensor, log_x_t: Tensor, log_x0_pred: Tensor, t,
239
+ detach_mean=False, include_kl_prior=True):
240
+ """ Get loss given one-hot log x_0, one-hot log x_t, t, and model prediction `log_x0_pred`.
241
+ Parameters:
242
+ - `log_x_start`: ground-truth input x0, converted to log one-hot (bs, ..., K)
243
+ - `log_x_t`: sampled noisy input at `x_t`, converted to log one-hot (bs, ..., K)
244
+ - `t`: diffusion timestep (bs,)
245
+ - `log_x0_pred`: model prediction of log probabilities of x0, i.e. hat{x0}.
246
+ - `include_kl_prior`: add last two terms to model loss (does not change optimization problem).
247
+ """
248
+ dtype = log_x_start.dtype
249
+ log_true_prob = self.q_posterior(
250
+ log_x_start=log_x_start, log_x_t=log_x_t, t=t)
251
+
252
+ log_model_prob = self.p_pred(log_x_t=log_x_t, t=t, log_x0_pred=log_x0_pred)
253
+
254
+ if detach_mean:
255
+ log_model_prob = log_model_prob.detach()
256
+
257
+ kl = self.multinomial_kl(log_true_prob, log_model_prob)
258
+ kl = sum_except_batch(kl)
259
+
260
+ # Add L_0, -log(p(x_0 | x_1))
261
+ decoder_nll = - (log_x_start.exp() * log_model_prob).sum(dim=-1)
262
+ decoder_nll = sum_except_batch(decoder_nll)
263
+
264
+ mask = (t == torch.zeros_like(t)).to(dtype)
265
+ loss = mask * decoder_nll + (1. - mask) * kl # only add L0 if t == 0.
266
+
267
+ if include_kl_prior:
268
+ pt = torch.ones_like(t, dtype=dtype)
269
+ kl_prior = self.kl_prior(log_x_start)
270
+ loss = (kl) + kl_prior
271
+
272
+ return loss
273
+
274
+ def kl_prior(self, log_x_start: Tensor) -> Tensor:
275
+ """ This function computes -H_{q}(x_T | x_0)+H_{p}(x_T), which
276
+ by some math (see wiki for KL div relation to conditional entropy).
277
+ So KL(q(x_T | x_0) || 1/K) = -H_{q}(x_T | x_0)+H_{p}(x_T) for categorical distribution.
278
+
279
+ Given `log_x_start` (bs, ..., probs), return KL prior of shape (bs,)
280
+ """
281
+ b = log_x_start.size(0)
282
+ device = log_x_start.device
283
+ ones = torch.ones(b, device=device, dtype=torch.long)
284
+
285
+ log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones) # q(x_T | x_0)
286
+ log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob)) # log(1/K), broadcast to q(x_T|x_0) shape
287
+
288
+ kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
289
+ return sum_except_batch(kl_prior)
290
+
291
+
292
+ def index2logit(x: Tensor, vocab_size: int, dtype=torch.float32):
293
+ x = F.one_hot(x, num_classes=vocab_size).to(dtype)
294
+ x = x * (vocab_size/(vocab_size - 1)) - 1/(vocab_size - 1)
295
+ return x
296
+
297
+
298
+ # ------------------------------
299
+ # Functions adapted from the full
300
+
301
+
302
+ @dataclass
303
+ class DSH():
304
+ # Diffusion Sampling Hyperparameters [DSH] (Section 4)
305
+ jump_len: int = 1 # j in RePaint paper [default 10] (Section 4.1)
306
+ jump_n_sample: int = 1 # r in RePaint paper [default 10] (Section 4.1)
307
+ last_greedy: bool = False # whether to not sample at t=0, but take argmax prediction. [default False]
308
+ x_0_temp: float = 1.0 # reweight temp for model prediction of x0
309
+ guidance_w: float = 1.0 # classifier free guidance weight [default 1.5] (Section 4.3)
310
+ enable_kevin_scaled_inference: bool = True # sequentially progressive diffusion [default True] (Section 4.2)
311
+ T_override: Union[None, int] = None # allow variable transcription sizes during inference (Section 4.4)
312
+
313
+ deep_clone: bool = False # whether to do deep clone.
314
+ q0_override_steps: int = 0 # number of steps that we allow overriding the input quant level 0 inputs.
315
+ progress: bool = False # whether to show progress bar
316
+
317
+
318
+ def get_schedule(t_T, jump_len=10, jump_n_sample=10):
319
+ jumps = {}
320
+ for j in range(0, t_T - jump_len, jump_len):
321
+ jumps[j] = jump_n_sample - 1
322
+ t = t_T
323
+ ts = []
324
+ while t >= 1:
325
+ t = t-1
326
+ ts.append(t)
327
+ if jumps.get(t, 0) > 0:
328
+ jumps[t] = jumps[t] - 1
329
+ for _ in range(jump_len):
330
+ t = t + 1
331
+ ts.append(t)
332
+ ts.append(-1)
333
+ return ts
334
+
335
+
336
+ def forward_diffusion(diff: MultinomialDiffusion, dtype, x, t, c=None, dsh=DSH):
337
+ """Simple forward diffusion process p"""
338
+ log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=dtype)
339
+ if c is not None: x = diff.q_pred_one_timestep_scaled(log_x_t, t, c, dsh.jump_len)
340
+ else: x = diff.q_pred_one_timestep(log_x_t, t)
341
+ x = diff.log_sample_categorical(x)
342
+ return x
343
+
344
+
345
+ def reverse_diffusion(diff: MultinomialDiffusion, model, batch, x_known=None, m=None,
346
+ last_greedy=False, temperature=1.0, alphas=None, ensemble_size=1, dsh=DSH):
347
+ """Reverse diffusion process q: predict x_{t-1} given x, t, x_known, m. Optionally do not sample model output
348
+ for t=0, but rather use the greedy argmax with `last_greedy`.
349
+ """
350
+ x = batch[4]
351
+ t = batch[-1]
352
+ if x_known is None: x_known = torch.zeros_like(x)
353
+ if m is None: m = torch.zeros_like(x)
354
+
355
+ # Equation 8b
356
+ # for b in batch:
357
+ # print(f"{b.shape}: {b}")
358
+ x_0_pred = model(*batch) # (bs, seq_len, logit_dim, n_quant)
359
+ x_0_pred = x_0_pred.permute(0, 1, 3, 2) # (bs, seq_len, n_quant, dim)
360
+
361
+ if dsh.guidance_w != 1:
362
+ uncond_x_0_pred = model(*(c.clone() if c is not None else None for c in batch), drop_cond=True)
363
+ uncond_x_0_pred = uncond_x_0_pred.permute(0, 1, 3, 2)
364
+ x_0_pred = dsh.guidance_w*x_0_pred + (1-dsh.guidance_w)*uncond_x_0_pred
365
+
366
+ x_0_pred = x_0_pred / temperature
367
+ log_x_0_pred = F.log_softmax(x_0_pred, dim=-1)
368
+ log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=x_0_pred.dtype)
369
+
370
+ # print("PRE: ", log_x_t.shape, t.shape, log_x_0_pred.shape)
371
+ log_model_pred = diff.p_pred(log_x_t, t, log_x_0_pred) # p(x_{t-1} | x_{t})
372
+
373
+ a_t = alphas[t[0]] if alphas is not None else 0
374
+ mat = torch.eye(ensemble_size, device=x.device)*(1-a_t)
375
+ mat += 1/ensemble_size * a_t
376
+ mat = torch.block_diag(*([mat]*(x.shape[0]//ensemble_size)))
377
+ log_model_pred = ( (mat[..., None, None] ).log().to(x.dtype) + log_model_pred[None])
378
+ log_model_pred = torch.logsumexp(log_model_pred, dim=1)
379
+
380
+ if (t==0).all() and last_greedy: # Do not sample at t=0
381
+ x_tm1_unknown = log_model_pred.argmax(dim=-1)
382
+ else:
383
+ x_tm1_unknown = diff.log_sample_categorical(log_model_pred)
384
+
385
+ # Equation 8a
386
+ x_known_log = index_to_log_onehot(x_known, diff.num_classes, dtype=x_0_pred.dtype)
387
+ if (t==0).all(): # Do not sample at t=0
388
+ x_tm1_known = x_known
389
+ else:
390
+ x_tm1_known = diff.q_sample(x_known_log, t)
391
+
392
+ # Equation 8c
393
+ x_tm1 = x_tm1_known * m.long() + x_tm1_unknown * (1 - m.long())
394
+ return x_tm1, x_0_pred
395
+
396
+
397
+
398
+ @torch.inference_mode()
399
+ def perform_simple_inference(model: torch.nn.Module, batch: tuple, diff: MultinomialDiffusion, T, dtype=torch.float16,
400
+ retain_quant0: bool = True, dsh=DSH):
401
+ """ If `retain_quant0`, then do not sample quant0 in each forward or reverse diffusion step. """
402
+
403
+ # (bs=1, N), (bs, seq_len2, 8), (bs,)
404
+ c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask = batch
405
+
406
+ device = c_text.device
407
+ bs = c_text.shape[0]
408
+ x_quant0 = x[..., 0].clone() # (bs, seq_len) 0th quant level
409
+ x = torch.randint(0, diff.num_classes, x.shape, dtype=x.dtype, device=device)
410
+ # CRITICAL LINE: override quantization level 0 with provided quant0 level.
411
+ x[..., 0] = x_quant0
412
+
413
+ # RePaint paper resample scheduling
414
+ times = get_schedule(T, jump_n_sample=dsh.jump_n_sample, jump_len=dsh.jump_len)
415
+
416
+ x_known = torch.zeros_like(x)
417
+ x_known[..., 0] = x[..., 0] # override L0 codes
418
+ m = torch.zeros_like(x).bool()
419
+ # (bs, seq_len, 8)
420
+ m[..., 0] = True
421
+
422
+ offset = 0
423
+ if dsh.deep_clone:
424
+ print(f"Note: using deep clone. Assuming input `c_phones` is concatenated prompt and output phones.",
425
+ "Also assuming no padded indices in `c_codes`.")
426
+ prompt = c_codes
427
+ x = torch.cat((prompt, x), dim=1) # (bs=1, sl1 + sl2, 8)
428
+ x_known = torch.cat((prompt, x_known), dim=1)
429
+ x_padding_mask = torch.cat((
430
+ torch.zeros(x_padding_mask.shape[0], c_codes_lengths[0], dtype=torch.bool, device=x_padding_mask.device),
431
+ x_padding_mask), dim=-1
432
+ )
433
+ # (bs=1, :up to prompt duration, all 8 codebooks) = True/masked.
434
+ m = torch.cat((torch.ones_like(prompt), m), dim=1)
435
+ x_quant0 = torch.cat((prompt[..., 0], x_quant0), dim=-1)
436
+ offset = c_codes_lengths[0]
437
+
438
+ print(f"New x: {x.shape} | new x_known: {x_known.shape} . Base prompt: {prompt.shape}. New padding mask: {x_padding_mask.shape} | m shape: {m.shape}")
439
+
440
+ c = 0 # sequentially progressive diffusion offset (Section 4.2)
441
+
442
+ # ensemble bs (not in paper)
443
+ alphas = torch.linspace(1, 0, T).to(device)
444
+
445
+ pb = zip(times[:-1], times[1:])
446
+ if dsh.progress:
447
+ from fastprogress import progress_bar
448
+ pb = progress_bar(pb, total=len(times)-1)
449
+
450
+ # See RePaint paper algorithm
451
+ for t_last, t_cur in pb:
452
+
453
+ t = torch.ones((bs,), dtype=torch.long, device=x.device) * (t_last)
454
+ if t_cur < t_last:
455
+ if c > dsh.jump_n_sample:
456
+ c = 0
457
+ c += 1/dsh.jump_len
458
+
459
+ # Reverse diffusion: q
460
+ cbatch = (c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask, t)
461
+ x, x_0_pred = reverse_diffusion(diff, model, cbatch, x_known, m, temperature=dsh.x_0_temp, alphas=alphas, ensemble_size=1, dsh=dsh)
462
+ else:
463
+ # Forward diffusion: p
464
+ if dsh.enable_kevin_scaled_inference: x = forward_diffusion(diff, dtype, x, t, c=c, dsh=dsh)
465
+ else: x = forward_diffusion(diff, dtype, x, t, c=None, dsh=dsh)
466
+
467
+ if retain_quant0 and dsh.q0_override_steps < t_last:
468
+ x[..., 0] = x_quant0
469
+
470
+ # crop offset:
471
+ x = x[:, offset:]
472
+ return x
mars5/minbpe/base.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Contains the base Tokenizer class and a few common helper functions.
3
+ The base class also contains the (common) save/load functionality.
4
+ It would be possible to be a lot more strict about the interface and
5
+ e.g. isolating all regex/pattern parts to the RegexTokenizer, but
6
+ some concessions are made for simplicity.
7
+ """
8
+ import unicodedata
9
+
10
+ # -----------------------------------------------------------------------------
11
+ # a few helper functions useful for both BasicTokenizer and RegexTokenizer
12
+
13
+ def get_stats(ids, counts=None):
14
+ """
15
+ Given a list of integers, return a dictionary of counts of consecutive pairs
16
+ Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
17
+ Optionally allows to update an existing dictionary of counts
18
+ """
19
+ counts = {} if counts is None else counts
20
+ for pair in zip(ids, ids[1:]): # iterate consecutive elements
21
+ counts[pair] = counts.get(pair, 0) + 1
22
+ return counts
23
+
24
+
25
+ def merge(ids, pair, idx):
26
+ """
27
+ In the list of integers (ids), replace all consecutive occurrences
28
+ of pair with the new integer token idx
29
+ Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
30
+ """
31
+ newids = []
32
+ i = 0
33
+ while i < len(ids):
34
+ # if not at the very last position AND the pair matches, replace it
35
+ if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
36
+ newids.append(idx)
37
+ i += 2
38
+ else:
39
+ newids.append(ids[i])
40
+ i += 1
41
+ return newids
42
+
43
+ # first two helper functions...
44
+ def replace_control_characters(s: str) -> str:
45
+ # we don't want to print control characters
46
+ # which distort the output (e.g. \n or much worse)
47
+ # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
48
+ # http://www.unicode.org/reports/tr44/#GC_Values_Table
49
+ chars = []
50
+ for ch in s:
51
+ if unicodedata.category(ch)[0] != "C":
52
+ chars.append(ch) # this character is ok
53
+ else:
54
+ chars.append(f"\\u{ord(ch):04x}") # escape
55
+ return "".join(chars)
56
+
57
+ def render_token(t: bytes) -> str:
58
+ # pretty print a token, escaping control characters
59
+ s = t.decode('utf-8', errors='replace')
60
+ s = replace_control_characters(s)
61
+ return s
62
+
63
+ # -----------------------------------------------------------------------------
64
+ # the base Tokenizer class
65
+
66
+ class Tokenizer:
67
+ """Base class for Tokenizers"""
68
+
69
+ def __init__(self):
70
+ # default: vocab size of 256 (all bytes), no merges, no patterns
71
+ self.merges = {} # (int, int) -> int
72
+ self.pattern = "" # str
73
+ self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
74
+ self.vocab = self._build_vocab() # int -> bytes
75
+
76
+ def train(self, text, vocab_size, verbose=False):
77
+ # Tokenizer can train a vocabulary of size vocab_size from text
78
+ raise NotImplementedError
79
+
80
+ def encode(self, text):
81
+ # Tokenizer can encode a string into a list of integers
82
+ raise NotImplementedError
83
+
84
+ def decode(self, ids):
85
+ # Tokenizer can decode a list of integers into a string
86
+ raise NotImplementedError
87
+
88
+ def _build_vocab(self):
89
+ # vocab is simply and deterministically derived from merges
90
+ vocab = {idx: bytes([idx]) for idx in range(256)}
91
+ for (p0, p1), idx in self.merges.items():
92
+ vocab[idx] = vocab[p0] + vocab[p1]
93
+ for special, idx in self.special_tokens.items():
94
+ vocab[idx] = special.encode("utf-8")
95
+ return vocab
96
+
97
+ def save(self, file_prefix):
98
+ """
99
+ Saves two files: file_prefix.vocab and file_prefix.model
100
+ This is inspired (but not equivalent to!) sentencepiece's model saving:
101
+ - model file is the critical one, intended for load()
102
+ - vocab file is just a pretty printed version for human inspection only
103
+ """
104
+ # write the model: to be used in load() later
105
+ model_file = file_prefix + ".model"
106
+ with open(model_file, 'w') as f:
107
+ # write the version, pattern and merges, that's all that's needed
108
+ f.write("minbpe v1\n")
109
+ f.write(f"{self.pattern}\n")
110
+ # write the special tokens, first the number of them, then each one
111
+ f.write(f"{len(self.special_tokens)}\n")
112
+ for special, idx in self.special_tokens.items():
113
+ f.write(f"{special} {idx}\n")
114
+ # the merges dict
115
+ for idx1, idx2 in self.merges:
116
+ f.write(f"{idx1} {idx2}\n")
117
+ # write the vocab: for the human to look at
118
+ vocab_file = file_prefix + ".vocab"
119
+ inverted_merges = {idx: pair for pair, idx in self.merges.items()}
120
+ with open(vocab_file, "w", encoding="utf-8") as f:
121
+ for idx, token in self.vocab.items():
122
+ # note: many tokens may be partial utf-8 sequences
123
+ # and cannot be decoded into valid strings. Here we're using
124
+ # errors='replace' to replace them with the replacement char �.
125
+ # this also means that we couldn't possibly use .vocab in load()
126
+ # because decoding in this way is a lossy operation!
127
+ s = render_token(token)
128
+ # find the children of this token, if any
129
+ if idx in inverted_merges:
130
+ # if this token has children, render it nicely as a merge
131
+ idx0, idx1 = inverted_merges[idx]
132
+ s0 = render_token(self.vocab[idx0])
133
+ s1 = render_token(self.vocab[idx1])
134
+ f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
135
+ else:
136
+ # otherwise this is leaf token, just print it
137
+ # (this should just be the first 256 tokens, the bytes)
138
+ f.write(f"[{s}] {idx}\n")
139
+
140
+ def load(self, model_file):
141
+ """Inverse of save() but only for the model file"""
142
+ model_file = str(model_file)
143
+ assert model_file.endswith(".model")
144
+ # read the model file
145
+ merges = {}
146
+ special_tokens = {}
147
+ idx = 256
148
+ with open(model_file, 'r', encoding="utf-8") as f:
149
+ # read the version
150
+ version = f.readline().strip()
151
+ assert version == "minbpe v1"
152
+ # read the pattern
153
+ self.pattern = f.readline().strip()
154
+ # read the special tokens
155
+ num_special = int(f.readline().strip())
156
+ for _ in range(num_special):
157
+ special, special_idx = f.readline().strip().split()
158
+ special_tokens[special] = int(special_idx)
159
+ # read the merges
160
+ for line in f:
161
+ idx1, idx2 = map(int, line.split())
162
+ merges[(idx1, idx2)] = idx
163
+ idx += 1
164
+ self.merges = merges
165
+ self.special_tokens = special_tokens
166
+ self.vocab = self._build_vocab()
mars5/minbpe/codebook.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Minimal (byte-level) Byte Pair Encoding tokenizer.
3
+
4
+ Unlike RegexTokenizer:
5
+ - Operates on integer codes from an encodec codebook.
6
+ """
7
+
8
+ import regex as re
9
+ from .base import Tokenizer, get_stats, merge
10
+
11
+
12
+ class CodebookTokenizer(Tokenizer):
13
+
14
+ def __init__(self, pattern=None, codebook_size=1024):
15
+ """
16
+ - pattern: optional string to override the default (GPT-4 split pattern)
17
+ - special_tokens: str -> int dictionary of special tokens
18
+ example: {'<|endoftext|>': 100257}
19
+ """
20
+ self.merges = {} # (int, int) -> int
21
+ self.pattern = pattern
22
+ self.compiled_pattern = re.compile(self.pattern)
23
+ self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
24
+ self.inverse_special_tokens = {}
25
+ self.codebook_size = codebook_size
26
+ self.vocab = self._build_vocab() # int -> bytes
27
+
28
+ def train(self, text, vocab_size, verbose=False):
29
+ assert vocab_size >= self.codebook_size
30
+ num_merges = vocab_size - self.codebook_size
31
+
32
+ # split the text up into text chunks
33
+ # text is a continuous signal, there is no splitting it up.
34
+ text_chunks = [text,] # re.findall(self.compiled_pattern, text)
35
+
36
+ # input text preprocessing
37
+ ids = [[int(idx) for idx in ch.split(' ')] for ch in text_chunks]
38
+
39
+ # iteratively merge the most common pairs to create new tokens
40
+ merges = {} # (int, int) -> int
41
+ # vocab = {idx: bytes([idx]) for idx in range(self.codebook_size)} # idx -> bytes
42
+ vocab = {idx: f" {idx:04d}".encode('utf-8') for idx in range(self.codebook_size)} # idx -> bytes
43
+
44
+ for i in range(num_merges):
45
+ # count the number of times every consecutive pair appears
46
+ stats = {}
47
+ for chunk_ids in ids:
48
+ # passing in stats will update it in place, adding up counts
49
+ get_stats(chunk_ids, stats)
50
+ # find the pair with the highest count
51
+ pair = max(stats, key=stats.get)
52
+ # mint a new token: assign it the next available id
53
+ idx = self.codebook_size + i
54
+ # replace all occurrences of pair in ids with idx
55
+ ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
56
+ # save the merge
57
+ merges[pair] = idx
58
+ vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
59
+ # prints
60
+ if verbose:
61
+ print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
62
+
63
+ # save class variables
64
+ self.merges = merges # used in encode()
65
+ self.vocab = vocab # used in decode()
66
+
67
+ def register_special_tokens(self, special_tokens):
68
+ # special_tokens is a dictionary of str -> int
69
+ # example: {"<|endoftext|>": 100257}
70
+ self.special_tokens = special_tokens
71
+ self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
72
+
73
+ def decode(self, ids):
74
+ # given ids (list of integers), return Python string
75
+ part_bytes = []
76
+ for idx in ids:
77
+ if idx in self.vocab:
78
+ part_bytes.append(self.vocab[idx])
79
+ elif idx in self.inverse_special_tokens:
80
+ part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
81
+ else:
82
+ raise ValueError(f"invalid token id: {idx}")
83
+ text_bytes = b"".join(part_bytes)
84
+ text = text_bytes.decode("utf-8", errors="replace")
85
+ return text
86
+
87
+ def decode_int(self, ids) -> list[int]:
88
+ ret: str = self.decode(ids)
89
+ for s in self.special_tokens:
90
+ ret = ret.replace(s, ' ' + s + ' ')
91
+ ret = ret.strip()
92
+ ret = [int(t) if t[0].isnumeric() else t for t in ret.split(' ') if len(t) > 0]
93
+ return ret
94
+
95
+ def _encode_chunk(self, text_bytes):
96
+ # return the token ids
97
+ # let's begin. first, convert all bytes to integers in range 0..255
98
+ ids = list(text_bytes)
99
+ while len(ids) >= 2:
100
+ # find the pair with the lowest merge index
101
+ stats = get_stats(ids)
102
+ pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
103
+ # subtle: if there are no more merges available, the key will
104
+ # result in an inf for every single pair, and the min will be
105
+ # just the first pair in the list, arbitrarily
106
+ # we can detect this terminating case by a membership check
107
+ if pair not in self.merges:
108
+ break # nothing else can be merged anymore
109
+ # otherwise let's merge the best pair (lowest merge index)
110
+ idx = self.merges[pair]
111
+ ids = merge(ids, pair, idx)
112
+ return ids
113
+
114
+ def encode_ordinary(self, text):
115
+ """Encoding that ignores any special tokens."""
116
+ # split text into chunks of text by categories defined in regex pattern
117
+ text_chunks = [text,] #re.findall(self.compiled_pattern, text)
118
+ # all chunks of text are encoded separately, then results are joined
119
+ ids = []
120
+ for chunk in text_chunks:
121
+ # chunk_bytes = chunk.encode("utf-8") # raw bytes
122
+ chunk_ids = [int(idx) for idx in chunk.split(' ')]
123
+ chunk_ids = self._encode_chunk(chunk_ids)
124
+ ids.extend(chunk_ids)
125
+ return ids
126
+
127
+ def encode(self, text, allowed_special="none_raise"):
128
+ """
129
+ Unlike encode_ordinary, this function handles special tokens.
130
+ allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
131
+ if none_raise, then an error is raised if any special token is encountered in text
132
+ this is the default tiktoken behavior right now as well
133
+ any other behavior is either annoying, or a major footgun
134
+ """
135
+ # decode the user desire w.r.t. handling of special tokens
136
+ special = None
137
+ if allowed_special == "all":
138
+ special = self.special_tokens
139
+ elif allowed_special == "none":
140
+ special = {}
141
+ elif allowed_special == "none_raise":
142
+ special = {}
143
+ assert all(token not in text for token in self.special_tokens)
144
+ elif isinstance(allowed_special, set):
145
+ special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
146
+ else:
147
+ raise ValueError(f"allowed_special={allowed_special} not understood")
148
+ if not special:
149
+ # shortcut: if no special tokens, just use the ordinary encoding
150
+ return self.encode_ordinary(text)
151
+ # otherwise, we have to be careful with potential special tokens in text
152
+ # we handle special tokens by splitting the text
153
+ # based on the occurrence of any exact match with any of the special tokens
154
+ # we can use re.split for this. note that surrounding the pattern with ()
155
+ # makes it into a capturing group, so the special tokens will be included
156
+ special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
157
+ special_chunks = re.split(special_pattern, text)
158
+ # now all the special characters are separated from the rest of the text
159
+ # all chunks of text are encoded separately, then results are joined
160
+ ids = []
161
+ for part in special_chunks:
162
+ part = part.strip()
163
+ if len(part) == 0: continue
164
+ if part in special:
165
+ # this is a special token, encode it separately as a special case
166
+ ids.append(special[part])
167
+ else:
168
+ # this is an ordinary sequence, encode it normally
169
+ ids.extend(self.encode_ordinary(part))
170
+ return ids
171
+
172
+
173
+ def load(self, model_file):
174
+ """Inverse of save() but only for the model file"""
175
+ model_file = str(model_file)
176
+ assert model_file.endswith(".model")
177
+ # read the model file
178
+ merges = {}
179
+ special_tokens = {}
180
+ idx = self.codebook_size
181
+ with open(model_file, 'r', encoding="utf-8") as f:
182
+ # read the version
183
+ version = f.readline().strip()
184
+ assert version == "minbpe v1"
185
+ # read the pattern
186
+ self.pattern = f.readline().strip()
187
+ # read the special tokens
188
+ num_special = int(f.readline().strip())
189
+ for _ in range(num_special):
190
+ special, special_idx = f.readline().strip().split()
191
+ special_tokens[special] = int(special_idx)
192
+ # read the merges
193
+ for line in f:
194
+ # print(line)
195
+ idx1, idx2 = map(int, line.split())
196
+ merges[(idx1, idx2)] = idx
197
+ idx += 1
198
+ self.merges = merges
199
+ self.special_tokens = special_tokens
200
+ self.vocab = self._build_vocab()
201
+
202
+
203
+ def _build_vocab(self):
204
+ # vocab is simply and deterministically derived from merges
205
+ vocab = {idx: f" {idx:04d}".encode('utf-8') for idx in range(self.codebook_size)}
206
+ for (p0, p1), idx in self.merges.items():
207
+ vocab[idx] = vocab[p0] + vocab[p1]
208
+ for special, idx in self.special_tokens.items():
209
+ vocab[idx] = special.encode("utf-8")
210
+ return vocab
mars5/minbpe/regex.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Minimal (byte-level) Byte Pair Encoding tokenizer.
3
+
4
+ Algorithmically follows along the GPT tokenizer:
5
+ https://github.com/openai/gpt-2/blob/master/src/encoder.py
6
+
7
+ Unlike BasicTokenizer:
8
+ - RegexTokenizer handles an optional regex splitting pattern.
9
+ - RegexTokenizer handles optional special tokens.
10
+ """
11
+
12
+ import regex as re
13
+ from .base import Tokenizer, get_stats, merge
14
+
15
+
16
+ # the main GPT text split patterns, see
17
+ # https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
18
+ GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
19
+ GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
20
+
21
+
22
+ class RegexTokenizer(Tokenizer):
23
+
24
+ def __init__(self, pattern=None):
25
+ """
26
+ - pattern: optional string to override the default (GPT-4 split pattern)
27
+ - special_tokens: str -> int dictionary of special tokens
28
+ example: {'<|endoftext|>': 100257}
29
+ """
30
+ super().__init__()
31
+ self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
32
+ self.compiled_pattern = re.compile(self.pattern)
33
+ self.special_tokens = {}
34
+ self.inverse_special_tokens = {}
35
+
36
+ def train(self, text, vocab_size, verbose=False):
37
+ assert vocab_size >= 256
38
+ num_merges = vocab_size - 256
39
+
40
+ # split the text up into text chunks
41
+ text_chunks = re.findall(self.compiled_pattern, text)
42
+
43
+ # input text preprocessing
44
+ ids = [list(ch.encode("utf-8")) for ch in text_chunks]
45
+
46
+ # iteratively merge the most common pairs to create new tokens
47
+ merges = {} # (int, int) -> int
48
+ vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
49
+ for i in range(num_merges):
50
+ # count the number of times every consecutive pair appears
51
+ stats = {}
52
+ for chunk_ids in ids:
53
+ # passing in stats will update it in place, adding up counts
54
+ get_stats(chunk_ids, stats)
55
+ # find the pair with the highest count
56
+ pair = max(stats, key=stats.get)
57
+ # mint a new token: assign it the next available id
58
+ idx = 256 + i
59
+ # replace all occurrences of pair in ids with idx
60
+ ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
61
+ # save the merge
62
+ merges[pair] = idx
63
+ vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
64
+ # prints
65
+ if verbose:
66
+ print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
67
+
68
+ # save class variables
69
+ self.merges = merges # used in encode()
70
+ self.vocab = vocab # used in decode()
71
+
72
+ def register_special_tokens(self, special_tokens):
73
+ # special_tokens is a dictionary of str -> int
74
+ # example: {"<|endoftext|>": 100257}
75
+ self.special_tokens = special_tokens
76
+ self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
77
+
78
+ def decode(self, ids):
79
+ # given ids (list of integers), return Python string
80
+ part_bytes = []
81
+ for idx in ids:
82
+ if idx in self.vocab:
83
+ part_bytes.append(self.vocab[idx])
84
+ elif idx in self.inverse_special_tokens:
85
+ part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
86
+ else:
87
+ raise ValueError(f"invalid token id: {idx}")
88
+ text_bytes = b"".join(part_bytes)
89
+ text = text_bytes.decode("utf-8", errors="replace")
90
+ return text
91
+
92
+ def _encode_chunk(self, text_bytes):
93
+ # return the token ids
94
+ # let's begin. first, convert all bytes to integers in range 0..255
95
+ ids = list(text_bytes)
96
+ while len(ids) >= 2:
97
+ # find the pair with the lowest merge index
98
+ stats = get_stats(ids)
99
+ pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
100
+ # subtle: if there are no more merges available, the key will
101
+ # result in an inf for every single pair, and the min will be
102
+ # just the first pair in the list, arbitrarily
103
+ # we can detect this terminating case by a membership check
104
+ if pair not in self.merges:
105
+ break # nothing else can be merged anymore
106
+ # otherwise let's merge the best pair (lowest merge index)
107
+ idx = self.merges[pair]
108
+ ids = merge(ids, pair, idx)
109
+ return ids
110
+
111
+ def encode_ordinary(self, text):
112
+ """Encoding that ignores any special tokens."""
113
+ # split text into chunks of text by categories defined in regex pattern
114
+ text_chunks = re.findall(self.compiled_pattern, text)
115
+ # all chunks of text are encoded separately, then results are joined
116
+ ids = []
117
+ for chunk in text_chunks:
118
+ chunk_bytes = chunk.encode("utf-8") # raw bytes
119
+ chunk_ids = self._encode_chunk(chunk_bytes)
120
+ ids.extend(chunk_ids)
121
+ return ids
122
+
123
+ def encode(self, text, allowed_special="none_raise"):
124
+ """
125
+ Unlike encode_ordinary, this function handles special tokens.
126
+ allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
127
+ if none_raise, then an error is raised if any special token is encountered in text
128
+ this is the default tiktoken behavior right now as well
129
+ any other behavior is either annoying, or a major footgun
130
+ """
131
+ # decode the user desire w.r.t. handling of special tokens
132
+ special = None
133
+ if allowed_special == "all":
134
+ special = self.special_tokens
135
+ elif allowed_special == "none":
136
+ special = {}
137
+ elif allowed_special == "none_raise":
138
+ special = {}
139
+ assert all(token not in text for token in self.special_tokens)
140
+ elif isinstance(allowed_special, set):
141
+ special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
142
+ else:
143
+ raise ValueError(f"allowed_special={allowed_special} not understood")
144
+ if not special:
145
+ # shortcut: if no special tokens, just use the ordinary encoding
146
+ return self.encode_ordinary(text)
147
+ # otherwise, we have to be careful with potential special tokens in text
148
+ # we handle special tokens by splitting the text
149
+ # based on the occurrence of any exact match with any of the special tokens
150
+ # we can use re.split for this. note that surrounding the pattern with ()
151
+ # makes it into a capturing group, so the special tokens will be included
152
+ special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
153
+ special_chunks = re.split(special_pattern, text)
154
+ # now all the special characters are separated from the rest of the text
155
+ # all chunks of text are encoded separately, then results are joined
156
+ ids = []
157
+ for part in special_chunks:
158
+ if part in special:
159
+ # this is a special token, encode it separately as a special case
160
+ ids.append(special[part])
161
+ else:
162
+ # this is an ordinary sequence, encode it normally
163
+ ids.extend(self.encode_ordinary(part))
164
+ return ids
mars5/model.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+
9
+ from .nn_future import (FNNSwiGLU, MistralTransformer, ModelArgs,
10
+ RotatingBufferCache, SinePositionalEmbedding)
11
+ from .utils import construct_padding_mask, length_to_mask
12
+
13
+ LAYERNORM_EPS = 4e-5
14
+
15
+ # ------------------------
16
+ # Code adapted from OpenAI guided diffusion repo
17
+
18
+ def timestep_embedding(timesteps, dim, max_period=10000, dtype=torch.float32):
19
+ """
20
+ Create sinusoidal timestep embeddings.
21
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
22
+ These may be fractional.
23
+ :param dim: the dimension of the output.
24
+ :param max_period: controls the minimum frequency of the embeddings.
25
+ :return: an [N x dim] Tensor of positional embeddings.
26
+ """
27
+ half = dim // 2
28
+ freqs = torch.exp(
29
+ -math.log(max_period) * torch.arange(start=0, end=half) / half
30
+ ).to(device=timesteps.device)
31
+ args = timesteps[:, None].float() * freqs[None]
32
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
33
+ if dim % 2:
34
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
35
+ return embedding
36
+
37
+
38
+ # --------------------------------
39
+ # autoregressive codec language model
40
+
41
+
42
+ class CodecLM(nn.Module):
43
+
44
+ def __init__(self, n_vocab, dim=1536, nhead=24, n_layers=26, n_spk_layers=2, dim_ff_scale=None, sliding_window=3000) -> None:
45
+ super().__init__()
46
+
47
+ if dim_ff_scale is None: hidden_dim = int(dim*4*(3/4))
48
+ else: hidden_dim = int(dim*dim_ff_scale)
49
+
50
+ self.cfg = ModelArgs(n_vocab, dim=dim, n_layers=n_layers, n_heads=nhead, n_kv_heads=nhead, hidden_dim=hidden_dim, sliding_window=sliding_window)
51
+ self.ar = MistralTransformer(self.cfg)
52
+
53
+ self.embed = nn.Embedding(n_vocab, dim)
54
+
55
+ # --- spk embedding network
56
+ dim_ff = int(dim*4*(3/4))
57
+ self.pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
58
+ self.ref_chunked_emb = ChunkedEmbedding(1024 + 1, 8, dim) # add 1 for pad idx
59
+ self.spk_identity_emb = nn.Embedding(1, dim)
60
+ # define custom decoder
61
+ encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff,
62
+ activation=FNNSwiGLU(dim, dim_ff), dropout=0,
63
+ batch_first=True, norm_first=True, layer_norm_eps=LAYERNORM_EPS)
64
+ encoder_layer.linear1 = nn.Identity()
65
+ self.spk_encoder = nn.TransformerEncoder(encoder_layer, n_spk_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS))
66
+ # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerDecoder
67
+ for l in self.spk_encoder.layers: l.activation = FNNSwiGLU(dim, dim_ff)
68
+
69
+
70
+ @torch.inference_mode
71
+ def get_spk_embedding(self, spk_reference, c_codes_lengths=None) -> Tensor:
72
+ """ Gets speaker reference embeddings using `spk_reference` codes of shape (bs, seq_len, n_codebooks). """
73
+ bs = spk_reference.shape[0]
74
+ if bs != 1:
75
+ raise AssertionError(f"Speaker embedding extraction only implemented using for bs=1 currently.")
76
+ spk_seq = self.ref_chunked_emb(spk_reference) # (bs, sl, dim)
77
+ spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim)
78
+
79
+ spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim)
80
+ # add pos encoding
81
+ spk_seq = self.pos_embedding(spk_seq)
82
+ # codebook goes from indices 0->1023, padding is idx 1024 (the 1025th entry)
83
+ src_key_padding_mask = construct_padding_mask(spk_reference[:, :, 0], 1024)
84
+ src_key_padding_mask = torch.cat((
85
+ # append a zero here since we DO want to attend to initial position.
86
+ torch.zeros(src_key_padding_mask.shape[0], 1, dtype=bool, device=src_key_padding_mask.device),
87
+ src_key_padding_mask
88
+ ),
89
+ dim=1)
90
+ # pass through transformer
91
+ res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim).
92
+ return res.squeeze(1)
93
+
94
+
95
+ def forward(self, x: Tensor, x_padding_mask: Optional[Tensor] = None, spk_reference: Optional[Tensor] = None,
96
+ cache: Optional[RotatingBufferCache] = None, counter: int = 0) -> Tensor:
97
+ """ Inputs:
98
+ - `x`: (bs, seq_len, vocab_size)
99
+ - `x_padding_mask`: (bs, seq_len) mask for each input, True for positions to *ignore*, False otherwise.
100
+ Note that since this is an autoregressive model, this doesn't actually matter for infernece, so it is ignored at inference.
101
+ - `spk_reference`: (bs, seq_len, n_codebooks) corresponding to the speaker reference to clone from.
102
+ - `cache` and `counter`: used for kv caching, optional.
103
+
104
+ Returns `x` of same shape (bs, seq_len, dim)
105
+ """
106
+ x = self.embed(x)
107
+
108
+ # --- speaker reference/embedding
109
+ if spk_reference is not None:
110
+ # compute ref
111
+ bs = spk_reference.shape[0]
112
+ spk_seq = self.ref_chunked_emb(spk_reference) # (bs, sl, dim)
113
+ spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim)
114
+
115
+ spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim)
116
+ # add pos encoding
117
+ spk_seq = self.pos_embedding(spk_seq)
118
+ # codebook goes from indices 0->1023, padding is idx 1024 (the 1025th entry)
119
+ src_key_padding_mask = construct_padding_mask(spk_reference[:, :, 0], 1024)
120
+ src_key_padding_mask = torch.cat((
121
+ # append a zero here since we DO want to attend to initial position.
122
+ torch.zeros(src_key_padding_mask.shape[0], 1, dtype=bool, device=src_key_padding_mask.device),
123
+ src_key_padding_mask
124
+ ),
125
+ dim=1)
126
+ # pass through transformer
127
+ res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim).
128
+
129
+ x = torch.cat([res, x], dim=1)
130
+
131
+ positions = torch.arange(0, x.shape[1], device=x.device, dtype=torch.long)
132
+ if cache is not None and counter != 1:
133
+ # using only the last token to predict the next one
134
+ x = x[:,-1,:].unsqueeze(1)
135
+ positions = positions[-1:]
136
+
137
+ x = self.ar(x, positions, cache) # (bs, seq_len, vocab)
138
+ if spk_reference is not None and (cache is None or counter == 1):
139
+ x = x[:, 1:] # strip out the first output token corresponding to the speaker embedding token.
140
+
141
+ return x
142
+
143
+
144
+ # -------------------------
145
+ # residual discrete diffusion model
146
+
147
+ class ChunkedEmbedding(nn.Module):
148
+
149
+ def __init__(self, codebook_size: int, n_quantizer: int, dim: int) -> None:
150
+ super().__init__()
151
+ assert dim % n_quantizer == 0, f"ChunkedEmbedding output dim ({dim}) must be divisible by n_quant {n_quantizer}"
152
+ self.embs = nn.ModuleList([nn.Embedding(codebook_size, dim//n_quantizer) for _ in range(n_quantizer)])
153
+
154
+ def forward(self, x: Tensor) -> Tensor:
155
+ """ Embeds each codebook index in `x` (bs, seq_len, n_quantizer) to an embedding vector, concatenating results.
156
+ Returns output of shape (bs, seq_len, dim)
157
+ """
158
+ y = torch.cat([self.embs[i](x[..., i]) for i in range(x.shape[-1])], dim=-1)
159
+ return y
160
+
161
+
162
+
163
+ class ResidualTransformer(nn.Module):
164
+
165
+ def __init__(self, n_text_vocab, n_quant=1024, dim=1024, nhead=16,
166
+ enc_layers=8, dec_layers=16, n_spk_layers=3,
167
+ c_quant_levels=8, pred_quant_levels=8,
168
+ t_emb_dim=1024, norm_first=True, p_cond_drop=0.1, dropout=0) -> None:
169
+ super().__init__()
170
+
171
+ self.cond_pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
172
+ self.pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
173
+
174
+ # *4 from heuristic, *2/3 from swiglu, since there are 3 linear matrices not 2.
175
+ # so we must keep # params the same.
176
+ dim_ff = int(dim*4*(3/4))
177
+
178
+ # define custom encoder
179
+ encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff,
180
+ activation=FNNSwiGLU(dim, dim_ff), dropout=dropout,
181
+ batch_first=True, norm_first=norm_first, layer_norm_eps=LAYERNORM_EPS)
182
+ encoder_layer.linear1 = nn.Identity()
183
+ encoder = nn.TransformerEncoder(encoder_layer, enc_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS) if norm_first else None)
184
+
185
+ # define custom decoder
186
+ decoder_layer = nn.TransformerDecoderLayer(dim, nhead, dim_ff,
187
+ activation=FNNSwiGLU(dim, dim_ff), dropout=dropout,
188
+ batch_first=True, norm_first=norm_first, layer_norm_eps=LAYERNORM_EPS)
189
+ decoder_layer.linear1 = nn.Identity()
190
+ decoder = nn.TransformerDecoder(decoder_layer, dec_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS) if norm_first else None)
191
+
192
+ # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerDecoder
193
+ for l in decoder.layers: l.activation = FNNSwiGLU(dim, dim_ff)
194
+
195
+ self.tfm = nn.Transformer(dim, nhead, dim_feedforward=dim_ff, batch_first=True,
196
+ norm_first=norm_first,
197
+ num_encoder_layers=enc_layers,
198
+ num_decoder_layers=dec_layers,
199
+ custom_encoder=encoder,
200
+ custom_decoder=decoder,
201
+ layer_norm_eps=LAYERNORM_EPS,
202
+ dropout=dropout
203
+ )
204
+ # Timestep embedding network
205
+ self.t_emb_dim = t_emb_dim
206
+ self.timestep_encoder_emb = nn.Sequential(
207
+ nn.Linear(t_emb_dim, dim),
208
+ nn.SiLU(),
209
+ nn.Linear(dim, dim)
210
+ )
211
+ self.timestep_decoder_emb = nn.Sequential(
212
+ nn.Linear(t_emb_dim, dim),
213
+ nn.SiLU(),
214
+ nn.Linear(dim, dim)
215
+ )
216
+
217
+ self.text_embed = nn.Embedding(n_text_vocab, dim)
218
+
219
+ ## ----> reference / conditioning encoder:
220
+ self.ref_embedder = ChunkedEmbedding(n_quant, c_quant_levels, dim)
221
+ self.ref_pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
222
+ self.spk_identity_emb = nn.Embedding(1, dim)
223
+ spk_encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff,
224
+ activation=FNNSwiGLU(dim, dim_ff), dropout=dropout,
225
+ batch_first=True, norm_first=True, layer_norm_eps=LAYERNORM_EPS)
226
+ spk_encoder_layer.linear1 = nn.Identity()
227
+ self.spk_encoder = nn.TransformerEncoder(spk_encoder_layer, n_spk_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS))
228
+ # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerDecoder
229
+ for l in self.spk_encoder.layers: l.activation = FNNSwiGLU(dim, dim_ff)
230
+ # ----> end speaker encoder network
231
+
232
+ # self.residual_encoder = nn.Embedding(n_quant, dim) # only encode first quantization level of decoder input.
233
+ self.residual_encoder = ChunkedEmbedding(n_quant, c_quant_levels, dim)
234
+
235
+ self.residual_decoder = nn.ModuleList([
236
+ nn.Sequential(
237
+ nn.LayerNorm(dim),
238
+ nn.Linear(dim, n_quant)
239
+ ) for i in range(pred_quant_levels)
240
+ ])
241
+ self.n_quantizer = pred_quant_levels
242
+ self.p_cond_drop = p_cond_drop
243
+
244
+
245
+ @torch.inference_mode
246
+ def get_spk_embedding(self, c_codes, c_codes_length) -> Tensor:
247
+ """ Obtain speaker embedding vectors using `c_codes` from reference encodec sequences, and `c_codes_length` of lengths for each sequence """
248
+ bs = c_codes.shape[0]
249
+ spk_seq = self.ref_embedder(c_codes) # (bs, sl, dim)
250
+ spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim)
251
+ spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim)
252
+ # add pos encoding
253
+ spk_seq = self.ref_pos_embedding(spk_seq)
254
+
255
+ # add 1 to c_codes_length to account for the fact that we concatenate the spk_ref_emb to it.
256
+ src_key_padding_mask = length_to_mask(c_codes_length+1, torch.zeros_like(c_codes_length), max_len=spk_seq.shape[1])
257
+ src_key_padding_mask = src_key_padding_mask.to(dtype=torch.bool, device=spk_seq.device)
258
+
259
+ # pass through transformer
260
+ res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim).
261
+ return res.squeeze(1)
262
+
263
+
264
+ def forward(self, c_text: Tensor, c_codes: Tensor, c_texts_length: Tensor, c_codes_length: Tensor,
265
+ x: Tensor, x_padding_mask: Tensor, t: Tensor, drop_cond=False):
266
+ """ Input:
267
+ - `c_text`: (bs, seq_len1) the prompt text (BPE encoded)
268
+ - `c_codes`: (bs, seq_len2, n_quant) the full tokenized codes of the reference speech
269
+ - `c_texts_length`: (bs, ) the length of the codes in the text prompt
270
+ - `c_codes_length`: (bs, ) the length of the prompt acoustic token codes in `c_codes`.
271
+ - `x`: (bs, seq_len3) L0 residual codes
272
+ - `x`: (bs, seq_len3, n_quant) L0 residual codes
273
+ - `x_padding_mask`: (bs, seq_len3) masking for residual codes
274
+ - `t`: (bs) timestep
275
+ - `drop_cond`: bool, whether or not to forcibly drop the conditioning information.
276
+ Returns:
277
+ - outs: (bs, seq_len, n_quantizer, codebook_size)
278
+ """
279
+
280
+ c_text = self.text_embed(c_text) # (bs, seq_len1, dim)
281
+
282
+ ## ----> reference / conditioning encoder:
283
+ bs = c_codes.shape[0]
284
+
285
+
286
+ if self.training:
287
+ zero_cond_inds = torch.rand_like(t, dtype=c_text.dtype) < self.p_cond_drop
288
+ else:
289
+ # never randomly zero when in eval mode
290
+ zero_cond_inds = torch.zeros_like(t, dtype=torch.bool)
291
+ if drop_cond:
292
+ # force drop conditioning
293
+ zero_cond_inds = torch.ones_like(t, dtype=torch.bool)
294
+
295
+ c_codes_length[zero_cond_inds] = 0
296
+ c_codes[zero_cond_inds] = 1024
297
+
298
+ spk_seq = self.ref_embedder(c_codes) # (bs, sl, dim)
299
+ spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim)
300
+ spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim)
301
+ # add pos encoding
302
+ spk_seq = self.ref_pos_embedding(spk_seq)
303
+
304
+ # add 1 to c_codes_length to account for the fact that we concatenate the spk_ref_emb to it.
305
+ src_key_padding_mask = length_to_mask(c_codes_length+1, torch.zeros_like(c_codes_length), max_len=spk_seq.shape[1])
306
+ src_key_padding_mask = src_key_padding_mask.to(dtype=torch.bool, device=spk_seq.device)
307
+
308
+ # pass through transformer
309
+ res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim).
310
+ c_codes = res # (bs, 1, dim)
311
+ c_codes_lengths_extract = torch.ones_like(c_codes_length) # manually override all the code lengths to equal 1, since we only have 1 spk embedding.
312
+ ## ----> end reference / conditioning encoder:
313
+
314
+ ## ----> timestep embeddings and parsing
315
+ t_emb = timestep_embedding(t, self.t_emb_dim, dtype=c_text.dtype)
316
+ t_emb_encoder = self.timestep_encoder_emb(t_emb) # (bs, t_dim)
317
+ t_emb_decoder = self.timestep_decoder_emb(t_emb)
318
+
319
+ ## ----> concatenating text/phone inputs and implicit speaker embedding.
320
+ c_phones_unpacked = nn.utils.rnn.unpad_sequence(c_text, c_texts_length.cpu(), batch_first=True)
321
+ c_codes_unpacked = nn.utils.rnn.unpad_sequence(c_codes, c_codes_lengths_extract.cpu(), batch_first=True)
322
+ # >>> Concat [speaker codes, text codes]
323
+ assert all(b.shape[0] == 1 for b in c_codes_unpacked)
324
+ c_joined = [torch.cat((b, a), dim=0) for a, b in zip(c_phones_unpacked, c_codes_unpacked)]
325
+
326
+ c = nn.utils.rnn.pad_sequence(c_joined, batch_first=True)
327
+ c_joined_lengths = torch.tensor([p.shape[0] for p in c_joined], device=c.device, dtype=torch.long)
328
+ c_padding_mask = length_to_mask(c_joined_lengths, torch.zeros_like(c_joined_lengths))
329
+ c = self.cond_pos_embedding(c)
330
+
331
+ ## Format input:
332
+ x = self.residual_encoder(x) # (bs, seq_len3, dim)
333
+
334
+ x = self.pos_embedding(x)
335
+
336
+ x = x + t_emb_decoder[:, None]
337
+ c = c + t_emb_encoder[:, None]
338
+ ## Perform prediction:
339
+ output = self.tfm(c, x, src_key_padding_mask=c_padding_mask,
340
+ tgt_key_padding_mask=x_padding_mask,
341
+ memory_key_padding_mask=c_padding_mask) # (bs, seq_len, dim)
342
+ outs = torch.stack([self.residual_decoder[i](output) for i in range(self.n_quantizer)], dim=-1) # (bs, seq_len, logit_dim, n_quant)
343
+ return outs
344
+
mars5/nn_future.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+
10
+ # --------------------------
11
+ # activation functions
12
+
13
+ class FNNSwiGLU(nn.Module):
14
+
15
+ def __init__(self, dim, dim_ff) -> None:
16
+ super().__init__()
17
+
18
+ # we will receive in xW
19
+ self.V = nn.Linear(dim, dim_ff, bias=False)
20
+ self.W = nn.Linear(dim, dim_ff, bias=False)
21
+
22
+
23
+ def forward(self, x: Tensor) -> Tensor:
24
+ """ Compute SwiGLU output of x, the output of the first linear layer. i.e.
25
+ FFNSwiGLU(x, W, V, W2) = (Swish1(xW) ⊗ xV )W2.
26
+ NOTE: the transformer linear1 layer must be overwritten to identity. This layer only applies
27
+ the Swish(xW) * xV. The W2 multiplication is done in the main transformer layer
28
+ """
29
+ return F.silu(self.W(x)) * self.V(x)
30
+
31
+
32
+ # ---------------------------------
33
+ # padding and position layers
34
+
35
+ class SinePositionalEmbedding(nn.Module):
36
+ def __init__(
37
+ self,
38
+ dim_model: int,
39
+ dropout: float = 0.0,
40
+ scale: bool = False,
41
+ alpha: bool = False,
42
+ ):
43
+ super().__init__()
44
+ self.dim_model = dim_model
45
+ self.x_scale = math.sqrt(dim_model) if scale else 1.0
46
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
47
+ self.dropout = torch.nn.Dropout(p=dropout)
48
+
49
+ self.reverse = False
50
+ self.pe = None
51
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
52
+
53
+ def extend_pe(self, x):
54
+ """Reset the positional encodings."""
55
+ if self.pe is not None:
56
+ if self.pe.size(1) >= x.size(1):
57
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
58
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
59
+ return
60
+ pe = torch.zeros(x.size(1), self.dim_model)
61
+ if self.reverse:
62
+ position = torch.arange(
63
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
64
+ ).unsqueeze(1)
65
+ else:
66
+ position = torch.arange(
67
+ 0, x.size(1), dtype=torch.float32
68
+ ).unsqueeze(1)
69
+ div_term = torch.exp(
70
+ torch.arange(0, self.dim_model, 2, dtype=torch.float32)
71
+ * -(math.log(10000.0) / self.dim_model)
72
+ )
73
+ pe[:, 0::2] = torch.sin(position * div_term)
74
+ pe[:, 1::2] = torch.cos(position * div_term)
75
+ pe = pe.unsqueeze(0)
76
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
77
+
78
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
79
+ """ Assumes x of shape (bs, seq_len, dim) """
80
+ self.extend_pe(x)
81
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
82
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
83
+ return self.dropout(output)
84
+
85
+
86
+ # --------------------------------
87
+ # kv cache blocks
88
+
89
+ class CacheView:
90
+ def __init__(self, cache_k: torch.Tensor, cache_v: torch.Tensor):
91
+ self.cache_k = cache_k
92
+ self.cache_v = cache_v
93
+
94
+ @property
95
+ def sliding_window(self):
96
+ return self.cache_k.shape[1]
97
+
98
+ class RotatingBufferCache:
99
+ """
100
+ This is an example that implements a less naive rotating buffer cache, allowing for variable length sequences.
101
+ Allocated cache is rectangular which is wasteful (see PagedAttention for better mechanisms)
102
+ """
103
+ def __init__(self, n_layers: int, max_batch_size: int, sliding_window: int, n_kv_heads: int, head_dim: int):
104
+
105
+ self.sliding_window = sliding_window
106
+ self.n_kv_heads = n_kv_heads
107
+ self.head_dim = head_dim
108
+
109
+ self.cache_k = torch.empty((
110
+ n_layers,
111
+ max_batch_size,
112
+ sliding_window,
113
+ n_kv_heads,
114
+ head_dim
115
+ ))
116
+ self.cache_v = torch.empty((
117
+ n_layers,
118
+ max_batch_size,
119
+ sliding_window,
120
+ n_kv_heads,
121
+ head_dim
122
+ ))
123
+
124
+ def get_view(self, layer_id: int) -> CacheView:
125
+ return CacheView(self.cache_k[layer_id], self.cache_v[layer_id])
126
+
127
+ @property
128
+ def device(self):
129
+ return self.cache_k.device
130
+
131
+ def to(self, device: torch.device, dtype: torch.dtype):
132
+ self.cache_k = self.cache_k.to(device=device, dtype=dtype)
133
+ self.cache_v = self.cache_v.to(device=device, dtype=dtype)
134
+ return self
135
+
136
+
137
+ # --------------------------------
138
+ # Mistral transformer blocks
139
+ # Code for the follow blocks are adapted from
140
+ # https://github.com/mistralai/mistral-src
141
+ # Thank you Mistral team!
142
+
143
+ @dataclass
144
+ class ModelArgs:
145
+ vocab_size: int
146
+
147
+ dim: int = 1152 # default for mars3 and before: 1024
148
+ n_layers: int = 24
149
+ head_dim: int = 64 # = dim/n_heads
150
+ hidden_dim: int = 3584
151
+ n_heads: int = 16
152
+ n_kv_heads: int = 16 # default: 8
153
+ sliding_window: int = 1792
154
+ norm_eps: float = 1e-5
155
+
156
+ max_batch_size: int = 256
157
+
158
+
159
+ def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int):
160
+ if repeats == 1: return keys, values
161
+ keys = torch.repeat_interleave(keys, repeats=repeats, dim=2)
162
+ values = torch.repeat_interleave(values, repeats=repeats, dim=2)
163
+ return keys, values
164
+
165
+
166
+ def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
167
+ """
168
+ freqs_cis: complex - (seq_len, head_dim / 2)
169
+ x: complex - (bsz, seq_len, head_dim / 2)
170
+ """
171
+ ndim = x.ndim
172
+ assert 1 < ndim
173
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
174
+ freqs_cis.shape,
175
+ (x.shape[1], x.shape[-1]),
176
+ )
177
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
178
+ return freqs_cis.view(*shape)
179
+
180
+
181
+ def apply_rotary_emb(
182
+ xq: torch.Tensor,
183
+ xk: torch.Tensor,
184
+ freqs_cis: torch.Tensor,
185
+ ) -> tuple[torch.Tensor, torch.Tensor]:
186
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
187
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
188
+ freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
189
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
190
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
191
+ return xq_out.type_as(xq), xk_out.type_as(xk)
192
+
193
+
194
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
195
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
196
+ t = torch.arange(end, device=freqs.device) # type: ignore
197
+ freqs = torch.outer(t, freqs).float() # type: ignore
198
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
199
+
200
+
201
+ class Attention(nn.Module):
202
+ def __init__(self, args: ModelArgs):
203
+ super().__init__()
204
+ self.args = args
205
+
206
+ self.n_heads: int = args.n_heads
207
+ self.n_kv_heads: int = args.n_kv_heads
208
+
209
+ self.repeats = self.n_heads // self.n_kv_heads
210
+ self.sliding_window = self.args.sliding_window
211
+
212
+ self.scale = self.args.head_dim**-0.5
213
+
214
+ self.wq = nn.Linear(
215
+ args.dim,
216
+ args.n_heads * args.head_dim,
217
+ bias=False
218
+ )
219
+ self.wk = nn.Linear(
220
+ args.dim,
221
+ args.n_kv_heads * args.head_dim,
222
+ bias=False
223
+ )
224
+ self.wv = nn.Linear(
225
+ args.dim,
226
+ args.n_kv_heads * args.head_dim,
227
+ bias=False
228
+ )
229
+ self.wo = nn.Linear(
230
+ args.n_heads * args.head_dim,
231
+ args.dim,
232
+ bias=False
233
+ )
234
+
235
+ def forward(
236
+ self, x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor, mask: Optional[torch.Tensor], cache: Optional[CacheView]
237
+ ) -> torch.Tensor:
238
+
239
+ bsz, seqlen, _ = x.shape
240
+
241
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
242
+ xq = xq.view(bsz, seqlen, self.n_heads, self.args.head_dim)
243
+ xk = xk.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim)
244
+ xv = xv.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim)
245
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
246
+
247
+ # The cache is a rotating buffer
248
+ if cache is not None:
249
+ scatter_pos = (positions[-self.sliding_window:] % self.sliding_window)[None, :, None, None]
250
+ scatter_pos = scatter_pos.repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
251
+ cache.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk[:, -self.sliding_window:])
252
+ cache.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv[:, -self.sliding_window:])
253
+
254
+ if positions.shape[0] > 1:
255
+ # prefill
256
+ key, value = repeat_kv(xk, xv, self.repeats)
257
+ else:
258
+ cur_pos = positions[-1].item() + 1
259
+ key, value = repeat_kv(cache.cache_k[:bsz, :cur_pos, ...], cache.cache_v[:bsz, :cur_pos, ...], self.repeats)
260
+
261
+ # print(f"Internal: {xq.shape}, key: {key.shape}, mask: {mask.shape} | {mask.dtype} | xq: {xq.dtype} | mask: {mask} ")
262
+ # if mask is not None:
263
+ # mask = mask[None, None, ...].expand(bsz, self.n_heads, -1, -1)
264
+ # mask = mask.to(key.dtype)
265
+
266
+ query = xq.transpose(1, 2)
267
+ key = key.transpose(1, 2)
268
+ value = value.transpose(1, 2)
269
+ # # scores : [bsz, n_heads, seqlen | 1, seqlen]
270
+ # scores = torch.matmul(query, key.transpose(2, 3)) * self.scale
271
+
272
+ output = F.scaled_dot_product_attention(query, key, value, mask) # (bs, n_local_heads, slen, head_dim)
273
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
274
+ return self.wo(output)
275
+
276
+
277
+ class FeedForward(nn.Module):
278
+ def __init__(self, args: ModelArgs):
279
+ super().__init__()
280
+
281
+ self.w1 = nn.Linear(
282
+ args.dim,
283
+ args.hidden_dim,
284
+ bias=False
285
+ )
286
+ self.w2 = nn.Linear(
287
+ args.hidden_dim,
288
+ args.dim,
289
+ bias=False
290
+ )
291
+ self.w3 = nn.Linear(
292
+ args.dim,
293
+ args.hidden_dim,
294
+ bias=False
295
+ )
296
+
297
+ def forward(self, x) -> torch.Tensor:
298
+ return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
299
+
300
+
301
+ class RMSNorm(torch.nn.Module):
302
+ def __init__(self, dim: int, eps: float = 1e-6):
303
+ super().__init__()
304
+ self.eps = eps
305
+ self.weight = nn.Parameter(torch.ones(dim))
306
+
307
+ def _norm(self, x):
308
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
309
+
310
+ def forward(self, x):
311
+ output = self._norm(x.float()).type_as(x)
312
+ return output * self.weight
313
+
314
+
315
+ class TransformerBlock(nn.Module):
316
+ def __init__(self, args: ModelArgs):
317
+ super().__init__()
318
+ self.n_heads = args.n_heads
319
+ self.dim = args.dim
320
+ self.attention = Attention(args)
321
+ self.feed_forward = FeedForward(args=args)
322
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
323
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
324
+ self.args = args
325
+
326
+ def forward(
327
+ self, x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor, mask: Optional[torch.Tensor], cache: Optional[CacheView]
328
+ ) -> torch.Tensor:
329
+ r = self.attention.forward(self.attention_norm(x), freqs_cis, positions, mask, cache)
330
+ h = x + r
331
+ r = self.feed_forward.forward(self.ffn_norm(h))
332
+ out = h + r
333
+ return out
334
+
335
+
336
+ class MistralTransformer(nn.Module):
337
+ def __init__(self, args: ModelArgs):
338
+ super().__init__()
339
+ self.args = args
340
+ self.vocab_size = args.vocab_size
341
+ self.n_layers = args.n_layers
342
+ assert self.vocab_size > 0
343
+
344
+ # self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
345
+
346
+ self.layers = torch.nn.ModuleList(
347
+ [TransformerBlock(args=args) for _ in range(args.n_layers)]
348
+ )
349
+
350
+ self.norm = RMSNorm(args.dim, eps=args.norm_eps)
351
+
352
+ self.output = nn.Linear(
353
+ args.dim,
354
+ args.vocab_size,
355
+ bias=False
356
+ )
357
+
358
+ # self.freqs_cis
359
+ self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000)
360
+
361
+ @property
362
+ def dtype(self) -> torch.dtype:
363
+ return self.tok_embeddings.weight.dtype
364
+
365
+ @property
366
+ def device(self) -> torch.device:
367
+ return self.tok_embeddings.weight.device
368
+
369
+ def forward(
370
+ self,
371
+ input_ids: torch.Tensor,
372
+ positions: torch.Tensor,
373
+ cache: Optional[RotatingBufferCache]
374
+ ):
375
+ h = input_ids
376
+ if self.freqs_cis.device != h.device:
377
+ self.freqs_cis = self.freqs_cis.to(h.device)
378
+ freqs_cis = self.freqs_cis[positions]
379
+
380
+ mask: Optional[torch.Tensor] = None
381
+ if input_ids.shape[1] > 1:
382
+ seqlen = input_ids.shape[1]
383
+ tensor = torch.full(
384
+ (seqlen, seqlen),
385
+ dtype=h.dtype,
386
+ fill_value=1,
387
+ device=h.device,
388
+ )
389
+ mask = torch.tril(tensor, diagonal=0).to(h.dtype)
390
+ # make the mask banded to account for sliding window
391
+ mask = torch.triu(mask, diagonal=-self.args.sliding_window)
392
+ mask = torch.log(mask)
393
+
394
+ for layer_id, layer in enumerate(self.layers):
395
+ cache_view = None if cache is None else cache.get_view(layer_id)
396
+ h = layer(h, freqs_cis, positions, mask, cache_view)
397
+
398
+ return self.output(self.norm(h))
399
+
400
+
mars5/samplers.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code for modifying categorical distributions to improve quality of sampling.
3
+
4
+ Adapted from:
5
+ - https://github.com/e-c-k-e-r/vall-e/blob/master/vall_e/samplers.py
6
+ - Mirosoft UniLM
7
+ - Matthew Baas's typical sampling code.
8
+ - https://github.com/LostRuins/koboldcpp
9
+ """
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import logging
16
+
17
+ from torch import Tensor, nn
18
+
19
+
20
+ def freq_rep_penalty(logits: Tensor, previous: Tensor, alpha_frequency: float, alpha_presence: float, penalty_window: int = 100) -> Tensor:
21
+ """ Apply frequency and presence penalty according to openai's formuation.
22
+ Concretely: given `logits` (bs, vocab_size) and `previous` (bs, seq_len,)
23
+
24
+ Modified to support batched inference.
25
+
26
+ See: https://platform.openai.com/docs/guides/text-generation/parameter-details
27
+ """
28
+ bs = logits.shape[0]
29
+ previous = previous[..., -penalty_window:]
30
+ c = torch.zeros_like(logits, device=logits.device, dtype=torch.long) # (1, vocab_size)
31
+ for i in range(bs):
32
+ vals, cnts = previous[i].unique(return_counts=True)
33
+ c[i, vals] = cnts.to(c.device)
34
+
35
+ logits = logits - c * alpha_frequency - (c > 0).to(logits.dtype) * alpha_presence
36
+ return logits
37
+
38
+
39
+ def early_eos_penalty(logits: Tensor, n_generated: int, estimated_gen_length: int, decay: float, factor: float = 1, eos_index: int = 0) -> Tensor:
40
+ """ Penalize the `eos_index` of `logits` (bs, vocab_size) up to `estimated_gen_length`,
41
+ whereby we reduce the logit value by `factor`*(expected_length - current_length)^decay,
42
+ `n_generated` is the current number of generated samples. `decay` anneals the penalty relative to the distance.
43
+
44
+ Good values for decay are between 0 and 1. 0 = hard always apply penalty of 1, 1 = linearly scale penalty relative to distance.
45
+ Setting factor = 0 disabled penatly. Increasing factor increases penalty.
46
+ """
47
+ if n_generated > estimated_gen_length: return logits
48
+ penalty = max(estimated_gen_length - n_generated, 1)
49
+
50
+ bigger = logits[:, eos_index] > 0
51
+
52
+ modifier = factor*(penalty ** decay)
53
+ # logits[bigger, eos_index] /= modifier
54
+ # logits[~bigger, eos_index] *= modifier
55
+ logits[:, eos_index] -= modifier
56
+ return logits
57
+
58
+
59
+ # Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 /
60
+ # https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
61
+ def top_k_top_p_filtering( logits: Tensor, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ) -> Tensor:
62
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
63
+ Args:
64
+ logits: logits distribution shape (batch size, vocabulary size)
65
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
66
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
67
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
68
+ Make sure we keep at least min_tokens per batch example in the output
69
+ """
70
+ if top_k > 0:
71
+ top_k = min(max(top_k, min_tokens), logits.size(-1)) # Safety check
72
+ # Remove all tokens with a probability less than the last token of the top-k
73
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
74
+ logits[indices_to_remove] = filter_value
75
+
76
+ if top_p < 1.0:
77
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
78
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
79
+
80
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
81
+ sorted_indices_to_remove = cumulative_probs > top_p
82
+ if min_tokens > 1:
83
+ # Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
84
+ sorted_indices_to_remove[..., :min_tokens] = 0
85
+ # Shift the indices to the right to keep also the first token above the threshold
86
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
87
+ sorted_indices_to_remove[..., 0] = 0
88
+
89
+ # scatter sorted tensors to original indexing
90
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
91
+ logits[indices_to_remove] = filter_value
92
+
93
+ return logits
94
+
95
+
96
+ def apply_typical_p(logprobs: Tensor, mass: float) -> Tensor:
97
+ """ Warp categorical logprobs associated with `x` to be in line with `mass`. Last dimension is the bin dimension.
98
+ `mass` corresponds to `tau` in the paper.
99
+ """
100
+ if mass > 0.999: return logprobs
101
+ # see: https://arxiv.org/abs/2202.00666
102
+ # calculate entropy
103
+ # normalized = logprobs #torch.nn.functional.log_softmax(scores, dim=-1)
104
+ normalized = torch.nn.functional.log_softmax(logprobs, dim=-1)
105
+ p = torch.exp(normalized)
106
+ ent = -(normalized * p).nansum(-1, keepdim=True)
107
+
108
+ # shift and sort
109
+ shifted_scores = torch.abs((-normalized) - ent)
110
+ sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
111
+ sorted_logits = logprobs.gather(-1, sorted_indices)
112
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
113
+
114
+ # Remove tokens with cumulative mass above the threshold
115
+ last_ind = (cumulative_probs < mass).sum(dim=1)
116
+ last_ind[last_ind < 0] = 0
117
+ sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
118
+
119
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
120
+
121
+ scores = logprobs.masked_fill(indices_to_remove, -float('Inf'))
122
+ return scores
mars5/trim.py ADDED
@@ -0,0 +1,741 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Custom port of librosa trim code, to remove numba dependency.
2
+ This allows us to use librosa.trim effect without the librosa or numba dependancy.
3
+
4
+ All code below adapted from librosa open source github:
5
+ """
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import warnings
11
+
12
+
13
+ def amplitude_to_db(S, ref=1.0, amin=1e-5, top_db=80.0):
14
+ """Convert an amplitude spectrogram to dB-scaled spectrogram.
15
+
16
+ This is equivalent to ``power_to_db(S**2)``, but is provided for convenience.
17
+
18
+ Parameters
19
+ ----------
20
+ S : np.ndarray
21
+ input amplitude
22
+
23
+ ref : scalar or callable
24
+ If scalar, the amplitude ``abs(S)`` is scaled relative to ``ref``:
25
+ ``20 * log10(S / ref)``.
26
+ Zeros in the output correspond to positions where ``S == ref``.
27
+
28
+ If callable, the reference value is computed as ``ref(S)``.
29
+
30
+ amin : float > 0 [scalar]
31
+ minimum threshold for ``S`` and ``ref``
32
+
33
+ top_db : float >= 0 [scalar]
34
+ threshold the output at ``top_db`` below the peak:
35
+ ``max(20 * log10(S)) - top_db``
36
+
37
+
38
+ Returns
39
+ -------
40
+ S_db : np.ndarray
41
+ ``S`` measured in dB
42
+
43
+ See Also
44
+ --------
45
+ power_to_db, db_to_amplitude
46
+
47
+ Notes
48
+ -----
49
+ This function caches at level 30.
50
+ """
51
+
52
+ # S = np.asarray(S)
53
+ S = torch.asarray(S)
54
+
55
+
56
+ magnitude = S.abs()
57
+
58
+ if callable(ref):
59
+ # User supplied a function to calculate reference power
60
+ ref_value = ref(magnitude)
61
+ else:
62
+ ref_value = torch.abs(ref)
63
+
64
+ power = torch.square(magnitude, out=magnitude)
65
+
66
+ return power_to_db(power, ref=ref_value ** 2, amin=amin ** 2, top_db=top_db)
67
+
68
+
69
+ def _signal_to_frame_nonsilent(
70
+ y, frame_length=2048, hop_length=512, top_db=60, ref=torch.max
71
+ ):
72
+ """Frame-wise non-silent indicator for audio input.
73
+
74
+ This is a helper function for `trim` and `split`.
75
+
76
+ Parameters
77
+ ----------
78
+ y : np.ndarray, shape=(n,) or (2,n)
79
+ Audio signal, mono or stereo
80
+
81
+ frame_length : int > 0
82
+ The number of samples per frame
83
+
84
+ hop_length : int > 0
85
+ The number of samples between frames
86
+
87
+ top_db : number > 0
88
+ The threshold (in decibels) below reference to consider as
89
+ silence
90
+
91
+ ref : callable or float
92
+ The reference power
93
+
94
+ Returns
95
+ -------
96
+ non_silent : np.ndarray, shape=(m,), dtype=bool
97
+ Indicator of non-silent frames
98
+ """
99
+ # Convert to mono
100
+ if y.ndim > 1:
101
+ y_mono = torch.mean(y, dim=0)
102
+ else: y_mono = y
103
+
104
+ # Compute the MSE for the signal
105
+ mse = rms(y=y_mono, frame_length=frame_length, hop_length=hop_length) ** 2
106
+
107
+ return power_to_db(mse.squeeze(), ref=ref, top_db=None) > -top_db
108
+
109
+
110
+ def trim(y, top_db=60, ref=torch.max, frame_length=2048, hop_length=512):
111
+ """Trim leading and trailing silence from an audio signal.
112
+
113
+ Parameters
114
+ ----------
115
+ y : np.ndarray, shape=(n,) or (2,n)
116
+ Audio signal, can be mono or stereo
117
+
118
+ top_db : number > 0
119
+ The threshold (in decibels) below reference to consider as
120
+ silence
121
+
122
+ ref : number or callable
123
+ The reference power. By default, it uses `np.max` and compares
124
+ to the peak power in the signal.
125
+
126
+ frame_length : int > 0
127
+ The number of samples per analysis frame
128
+
129
+ hop_length : int > 0
130
+ The number of samples between analysis frames
131
+
132
+ Returns
133
+ -------
134
+ y_trimmed : np.ndarray, shape=(m,) or (2, m)
135
+ The trimmed signal
136
+
137
+ index : np.ndarray, shape=(2,)
138
+ the interval of ``y`` corresponding to the non-silent region:
139
+ ``y_trimmed = y[index[0]:index[1]]`` (for mono) or
140
+ ``y_trimmed = y[:, index[0]:index[1]]`` (for stereo).
141
+
142
+
143
+ Examples
144
+ --------
145
+ >>> # Load some audio
146
+ >>> y, sr = librosa.load(librosa.ex('choice'))
147
+ >>> # Trim the beginning and ending silence
148
+ >>> yt, index = librosa.effects.trim(y)
149
+ >>> # Print the durations
150
+ >>> print(librosa.get_duration(y), librosa.get_duration(yt))
151
+ 25.025986394557822 25.007891156462584
152
+ """
153
+
154
+ non_silent = _signal_to_frame_nonsilent(
155
+ y, frame_length=frame_length, hop_length=hop_length, ref=ref, top_db=top_db
156
+ )
157
+
158
+ # nonzero = np.flatnonzero(non_silent)
159
+ nonzero = torch.nonzero(torch.ravel(non_silent)).squeeze()#[0]
160
+
161
+ if nonzero.numel() > 0:
162
+ # Compute the start and end positions
163
+ # End position goes one frame past the last non-zero
164
+ start = int(frames_to_samples(nonzero[0], hop_length))
165
+ end = min(y.shape[-1], int(frames_to_samples(nonzero[-1] + 1, hop_length)))
166
+ else:
167
+ # The signal only contains zeros
168
+ start, end = 0, 0
169
+
170
+ # Build the mono/stereo index
171
+ full_index = [slice(None)] * y.ndim
172
+ full_index[-1] = slice(start, end)
173
+
174
+ # print(non_silent)
175
+ # print(non_silent.shape, nonzero.shape)
176
+
177
+ return y[tuple(full_index)], torch.asarray([start, end])
178
+
179
+
180
+ def rms(
181
+ y=None, S=None, frame_length=2048, hop_length=512, center=True, pad_mode="reflect"
182
+ ):
183
+ """Compute root-mean-square (RMS) value for each frame, either from the
184
+ audio samples ``y`` or from a spectrogram ``S``.
185
+
186
+ Computing the RMS value from audio samples is faster as it doesn't require
187
+ a STFT calculation. However, using a spectrogram will give a more accurate
188
+ representation of energy over time because its frames can be windowed,
189
+ thus prefer using ``S`` if it's already available.
190
+
191
+
192
+ Parameters
193
+ ----------
194
+ y : np.ndarray [shape=(n,)] or None
195
+ (optional) audio time series. Required if ``S`` is not input.
196
+
197
+ S : np.ndarray [shape=(d, t)] or None
198
+ (optional) spectrogram magnitude. Required if ``y`` is not input.
199
+
200
+ frame_length : int > 0 [scalar]
201
+ length of analysis frame (in samples) for energy calculation
202
+
203
+ hop_length : int > 0 [scalar]
204
+ hop length for STFT. See `librosa.stft` for details.
205
+
206
+ center : bool
207
+ If `True` and operating on time-domain input (``y``), pad the signal
208
+ by ``frame_length//2`` on either side.
209
+
210
+ If operating on spectrogram input, this has no effect.
211
+
212
+ pad_mode : str
213
+ Padding mode for centered analysis. See `numpy.pad` for valid
214
+ values.
215
+
216
+ Returns
217
+ -------
218
+ rms : np.ndarray [shape=(1, t)]
219
+ RMS value for each frame
220
+
221
+
222
+ Examples
223
+ --------
224
+ >>> y, sr = librosa.load(librosa.ex('trumpet'))
225
+ >>> librosa.feature.rms(y=y)
226
+ array([[1.248e-01, 1.259e-01, ..., 1.845e-05, 1.796e-05]],
227
+ dtype=float32)
228
+
229
+ Or from spectrogram input
230
+
231
+ >>> S, phase = librosa.magphase(librosa.stft(y))
232
+ >>> rms = librosa.feature.rms(S=S)
233
+
234
+ >>> import matplotlib.pyplot as plt
235
+ >>> fig, ax = plt.subplots(nrows=2, sharex=True)
236
+ >>> times = librosa.times_like(rms)
237
+ >>> ax[0].semilogy(times, rms[0], label='RMS Energy')
238
+ >>> ax[0].set(xticks=[])
239
+ >>> ax[0].legend()
240
+ >>> ax[0].label_outer()
241
+ >>> librosa.display.specshow(librosa.amplitude_to_db(S, ref=np.max),
242
+ ... y_axis='log', x_axis='time', ax=ax[1])
243
+ >>> ax[1].set(title='log Power spectrogram')
244
+
245
+ Use a STFT window of constant ones and no frame centering to get consistent
246
+ results with the RMS computed from the audio samples ``y``
247
+
248
+ >>> S = librosa.magphase(librosa.stft(y, window=np.ones, center=False))[0]
249
+ >>> librosa.feature.rms(S=S)
250
+ >>> plt.show()
251
+
252
+ """
253
+ if y is not None:
254
+ if y.dim() > 1:
255
+ y = torch.mean(y, dim=0)
256
+
257
+ if center:
258
+ y = F.pad(y[None, None], (int(frame_length//2), int(frame_length//2)), mode=pad_mode)[0, 0]
259
+ # y = np.pad(y, int(frame_length // 2), mode=pad_mode)
260
+
261
+ x = frame(y, frame_length=frame_length, hop_length=hop_length)
262
+ # print(y.shape, x.shape, x)
263
+ # Calculate power
264
+ power = torch.mean(x.abs() ** 2, dim=0, keepdim=True)
265
+ elif S is not None:
266
+ # Check the frame length
267
+ if S.shape[0] != frame_length // 2 + 1:
268
+ raise AssertionError(
269
+ "Since S.shape[0] is {}, "
270
+ "frame_length is expected to be {} or {}; "
271
+ "found {}".format(
272
+ S.shape[0], S.shape[0] * 2 - 2, S.shape[0] * 2 - 1, frame_length
273
+ )
274
+ )
275
+
276
+ # power spectrogram
277
+ x = torch.abs(S) ** 2
278
+
279
+ # Adjust the DC and sr/2 component
280
+ x[0] *= 0.5
281
+ if frame_length % 2 == 0:
282
+ x[-1] *= 0.5
283
+
284
+ # Calculate power
285
+ power = 2 * torch.sum(x, dim=0, keepdim=True) / frame_length ** 2
286
+ else:
287
+ raise AssertionError("Either `y` or `S` must be input.")
288
+
289
+ return torch.sqrt(power)
290
+
291
+
292
+ def frame(x, frame_length, hop_length, axis=-1):
293
+ """Slice a data array into (overlapping) frames.
294
+
295
+ This implementation uses low-level stride manipulation to avoid
296
+ making a copy of the data. The resulting frame representation
297
+ is a new view of the same input data.
298
+
299
+ However, if the input data is not contiguous in memory, a warning
300
+ will be issued and the output will be a full copy, rather than
301
+ a view of the input data.
302
+
303
+ For example, a one-dimensional input ``x = [0, 1, 2, 3, 4, 5, 6]``
304
+ can be framed with frame length 3 and hop length 2 in two ways.
305
+ The first (``axis=-1``), results in the array ``x_frames``::
306
+
307
+ [[0, 2, 4],
308
+ [1, 3, 5],
309
+ [2, 4, 6]]
310
+
311
+ where each column ``x_frames[:, i]`` contains a contiguous slice of
312
+ the input ``x[i * hop_length : i * hop_length + frame_length]``.
313
+
314
+ The second way (``axis=0``) results in the array ``x_frames``::
315
+
316
+ [[0, 1, 2],
317
+ [2, 3, 4],
318
+ [4, 5, 6]]
319
+
320
+ where each row ``x_frames[i]`` contains a contiguous slice of the input.
321
+
322
+ This generalizes to higher dimensional inputs, as shown in the examples below.
323
+ In general, the framing operation increments by 1 the number of dimensions,
324
+ adding a new "frame axis" either to the end of the array (``axis=-1``)
325
+ or the beginning of the array (``axis=0``).
326
+
327
+
328
+ Parameters
329
+ ----------
330
+ x : np.ndarray
331
+ Array to frame
332
+
333
+ frame_length : int > 0 [scalar]
334
+ Length of the frame
335
+
336
+ hop_length : int > 0 [scalar]
337
+ Number of steps to advance between frames
338
+
339
+ axis : 0 or -1
340
+ The axis along which to frame.
341
+
342
+ If ``axis=-1`` (the default), then ``x`` is framed along its last dimension.
343
+ ``x`` must be "F-contiguous" in this case.
344
+
345
+ If ``axis=0``, then ``x`` is framed along its first dimension.
346
+ ``x`` must be "C-contiguous" in this case.
347
+
348
+ Returns
349
+ -------
350
+ x_frames : np.ndarray [shape=(..., frame_length, N_FRAMES) or (N_FRAMES, frame_length, ...)]
351
+ A framed view of ``x``, for example with ``axis=-1`` (framing on the last dimension)::
352
+
353
+ x_frames[..., j] == x[..., j * hop_length : j * hop_length + frame_length]
354
+
355
+ If ``axis=0`` (framing on the first dimension), then::
356
+
357
+ x_frames[j] = x[j * hop_length : j * hop_length + frame_length]
358
+
359
+ Raises
360
+ ------
361
+ ParameterError
362
+ If ``x`` is not an `np.ndarray`.
363
+
364
+ If ``x.shape[axis] < frame_length``, there is not enough data to fill one frame.
365
+
366
+ If ``hop_length < 1``, frames cannot advance.
367
+
368
+ If ``axis`` is not 0 or -1. Framing is only supported along the first or last axis.
369
+
370
+
371
+ See Also
372
+ --------
373
+ numpy.asfortranarray : Convert data to F-contiguous representation
374
+ numpy.ascontiguousarray : Convert data to C-contiguous representation
375
+ numpy.ndarray.flags : information about the memory layout of a numpy `ndarray`.
376
+
377
+ Examples
378
+ --------
379
+ Extract 2048-sample frames from monophonic signal with a hop of 64 samples per frame
380
+
381
+ >>> y, sr = librosa.load(librosa.ex('trumpet'))
382
+ >>> frames = librosa.util.frame(y, frame_length=2048, hop_length=64)
383
+ >>> frames
384
+ array([[-1.407e-03, -2.604e-02, ..., -1.795e-05, -8.108e-06],
385
+ [-4.461e-04, -3.721e-02, ..., -1.573e-05, -1.652e-05],
386
+ ...,
387
+ [ 7.960e-02, -2.335e-01, ..., -6.815e-06, 1.266e-05],
388
+ [ 9.568e-02, -1.252e-01, ..., 7.397e-06, -1.921e-05]],
389
+ dtype=float32)
390
+ >>> y.shape
391
+ (117601,)
392
+
393
+ >>> frames.shape
394
+ (2048, 1806)
395
+
396
+ Or frame along the first axis instead of the last:
397
+
398
+ >>> frames = librosa.util.frame(y, frame_length=2048, hop_length=64, axis=0)
399
+ >>> frames.shape
400
+ (1806, 2048)
401
+
402
+ Frame a stereo signal:
403
+
404
+ >>> y, sr = librosa.load(librosa.ex('trumpet', hq=True), mono=False)
405
+ >>> y.shape
406
+ (2, 117601)
407
+ >>> frames = librosa.util.frame(y, frame_length=2048, hop_length=64)
408
+ (2, 2048, 1806)
409
+
410
+ Carve an STFT into fixed-length patches of 32 frames with 50% overlap
411
+
412
+ >>> y, sr = librosa.load(librosa.ex('trumpet'))
413
+ >>> S = np.abs(librosa.stft(y))
414
+ >>> S.shape
415
+ (1025, 230)
416
+ >>> S_patch = librosa.util.frame(S, frame_length=32, hop_length=16)
417
+ >>> S_patch.shape
418
+ (1025, 32, 13)
419
+ >>> # The first patch contains the first 32 frames of S
420
+ >>> np.allclose(S_patch[:, :, 0], S[:, :32])
421
+ True
422
+ >>> # The second patch contains frames 16 to 16+32=48, and so on
423
+ >>> np.allclose(S_patch[:, :, 1], S[:, 16:48])
424
+ True
425
+ """
426
+
427
+ # if not isinstance(x, np.ndarray):
428
+ # raise AssertionError(
429
+ # "Input must be of type numpy.ndarray, " "given type(x)={}".format(type(x))
430
+ # )
431
+ x: torch.Tensor = x
432
+
433
+ if x.shape[axis] < frame_length:
434
+ raise AssertionError(
435
+ "Input is too short (n={:d})"
436
+ " for frame_length={:d}".format(x.shape[axis], frame_length)
437
+ )
438
+
439
+ if hop_length < 1:
440
+ raise AssertionError("Invalid hop_length: {:d}".format(hop_length))
441
+
442
+ if axis == -1 and not x.is_contiguous():
443
+ warnings.warn(
444
+ "librosa.util.frame called with axis={} "
445
+ "on a non-contiguous input. This will result in a copy.".format(axis)
446
+ )
447
+ x = x.contiguous()
448
+ elif axis == 0 and not x.is_contiguous():
449
+ warnings.warn(
450
+ "librosa.util.frame called with axis={} "
451
+ "on a non-contiguous input. This will result in a copy.".format(axis)
452
+ )
453
+ x = x.contiguous()
454
+
455
+ n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
456
+ strides = torch.asarray(x.numpy().strides)
457
+ # print(strides, x)
458
+ new_stride = torch.prod(strides[strides > 0] // x.itemsize) * x.itemsize
459
+
460
+ if axis == -1:
461
+ shape = list(x.shape)[:-1] + [frame_length, n_frames]
462
+ strides = list(strides) + [hop_length * new_stride]
463
+
464
+ elif axis == 0:
465
+ shape = [n_frames, frame_length] + list(x.shape)[1:]
466
+ strides = [hop_length * new_stride] + list(strides)
467
+
468
+ else:
469
+ raise AssertionError("Frame axis={} must be either 0 or -1".format(axis))
470
+
471
+ return torch.from_numpy(as_strided(x, shape=shape, strides=strides))
472
+ # return x.as_strided(size=shape, stride=strides)
473
+
474
+
475
+
476
+ class DummyArray:
477
+ """Dummy object that just exists to hang __array_interface__ dictionaries
478
+ and possibly keep alive a reference to a base array.
479
+ """
480
+
481
+ def __init__(self, interface, base=None):
482
+ self.__array_interface__ = interface
483
+ self.base = base
484
+
485
+
486
+
487
+ def as_strided(x, shape=None, strides=None, subok=False, writeable=True):
488
+ """
489
+ Create a view into the array with the given shape and strides.
490
+
491
+ .. warning:: This function has to be used with extreme care, see notes.
492
+
493
+ Parameters
494
+ ----------
495
+ x : ndarray
496
+ Array to create a new.
497
+ shape : sequence of int, optional
498
+ The shape of the new array. Defaults to ``x.shape``.
499
+ strides : sequence of int, optional
500
+ The strides of the new array. Defaults to ``x.strides``.
501
+ subok : bool, optional
502
+ .. versionadded:: 1.10
503
+
504
+ If True, subclasses are preserved.
505
+ writeable : bool, optional
506
+ .. versionadded:: 1.12
507
+
508
+ If set to False, the returned array will always be readonly.
509
+ Otherwise it will be writable if the original array was. It
510
+ is advisable to set this to False if possible (see Notes).
511
+
512
+ Returns
513
+ -------
514
+ view : ndarray
515
+
516
+ See also
517
+ --------
518
+ broadcast_to : broadcast an array to a given shape.
519
+ reshape : reshape an array.
520
+ lib.stride_tricks.sliding_window_view :
521
+ userfriendly and safe function for the creation of sliding window views.
522
+
523
+ Notes
524
+ -----
525
+ ``as_strided`` creates a view into the array given the exact strides
526
+ and shape. This means it manipulates the internal data structure of
527
+ ndarray and, if done incorrectly, the array elements can point to
528
+ invalid memory and can corrupt results or crash your program.
529
+ It is advisable to always use the original ``x.strides`` when
530
+ calculating new strides to avoid reliance on a contiguous memory
531
+ layout.
532
+
533
+ Furthermore, arrays created with this function often contain self
534
+ overlapping memory, so that two elements are identical.
535
+ Vectorized write operations on such arrays will typically be
536
+ unpredictable. They may even give different results for small, large,
537
+ or transposed arrays.
538
+ Since writing to these arrays has to be tested and done with great
539
+ care, you may want to use ``writeable=False`` to avoid accidental write
540
+ operations.
541
+
542
+ For these reasons it is advisable to avoid ``as_strided`` when
543
+ possible.
544
+ """
545
+ # first convert input to array, possibly keeping subclass
546
+ x = np.array(x, copy=False, subok=subok)
547
+ interface = dict(x.__array_interface__)
548
+ if shape is not None:
549
+ interface['shape'] = tuple(shape)
550
+ if strides is not None:
551
+ interface['strides'] = tuple(strides)
552
+
553
+ array = np.asarray(DummyArray(interface, base=x))
554
+ # The route via `__interface__` does not preserve structured
555
+ # dtypes. Since dtype should remain unchanged, we set it explicitly.
556
+ array.dtype = x.dtype
557
+
558
+ view = _maybe_view_as_subclass(x, array)
559
+
560
+ if view.flags.writeable and not writeable:
561
+ view.flags.writeable = False
562
+
563
+ return view
564
+
565
+
566
+ def _maybe_view_as_subclass(original_array, new_array):
567
+ if type(original_array) is not type(new_array):
568
+ # if input was an ndarray subclass and subclasses were OK,
569
+ # then view the result as that subclass.
570
+ new_array = new_array.view(type=type(original_array))
571
+ # Since we have done something akin to a view from original_array, we
572
+ # should let the subclass finalize (if it has it implemented, i.e., is
573
+ # not None).
574
+ if new_array.__array_finalize__:
575
+ new_array.__array_finalize__(original_array)
576
+ return new_array
577
+
578
+
579
+ def power_to_db(S, ref=1.0, amin=1e-10, top_db=80.0):
580
+ """Convert a power spectrogram (amplitude squared) to decibel (dB) units
581
+
582
+ This computes the scaling ``10 * log10(S / ref)`` in a numerically
583
+ stable way.
584
+
585
+ Parameters
586
+ ----------
587
+ S : np.ndarray
588
+ input power
589
+
590
+ ref : scalar or callable
591
+ If scalar, the amplitude ``abs(S)`` is scaled relative to ``ref``::
592
+
593
+ 10 * log10(S / ref)
594
+
595
+ Zeros in the output correspond to positions where ``S == ref``.
596
+
597
+ If callable, the reference value is computed as ``ref(S)``.
598
+
599
+ amin : float > 0 [scalar]
600
+ minimum threshold for ``abs(S)`` and ``ref``
601
+
602
+ top_db : float >= 0 [scalar]
603
+ threshold the output at ``top_db`` below the peak:
604
+ ``max(10 * log10(S)) - top_db``
605
+
606
+ Returns
607
+ -------
608
+ S_db : np.ndarray
609
+ ``S_db ~= 10 * log10(S) - 10 * log10(ref)``
610
+
611
+ See Also
612
+ --------
613
+ perceptual_weighting
614
+ db_to_power
615
+ amplitude_to_db
616
+ db_to_amplitude
617
+
618
+ Notes
619
+ -----
620
+ This function caches at level 30.
621
+
622
+
623
+ Examples
624
+ --------
625
+ Get a power spectrogram from a waveform ``y``
626
+
627
+ >>> y, sr = librosa.load(librosa.ex('trumpet'))
628
+ >>> S = np.abs(librosa.stft(y))
629
+ >>> librosa.power_to_db(S**2)
630
+ array([[-41.809, -41.809, ..., -41.809, -41.809],
631
+ [-41.809, -41.809, ..., -41.809, -41.809],
632
+ ...,
633
+ [-41.809, -41.809, ..., -41.809, -41.809],
634
+ [-41.809, -41.809, ..., -41.809, -41.809]], dtype=float32)
635
+
636
+ Compute dB relative to peak power
637
+
638
+ >>> librosa.power_to_db(S**2, ref=np.max)
639
+ array([[-80., -80., ..., -80., -80.],
640
+ [-80., -80., ..., -80., -80.],
641
+ ...,
642
+ [-80., -80., ..., -80., -80.],
643
+ [-80., -80., ..., -80., -80.]], dtype=float32)
644
+
645
+ Or compare to median power
646
+
647
+ >>> librosa.power_to_db(S**2, ref=np.median)
648
+ array([[16.578, 16.578, ..., 16.578, 16.578],
649
+ [16.578, 16.578, ..., 16.578, 16.578],
650
+ ...,
651
+ [16.578, 16.578, ..., 16.578, 16.578],
652
+ [16.578, 16.578, ..., 16.578, 16.578]], dtype=float32)
653
+
654
+
655
+ And plot the results
656
+
657
+ >>> import matplotlib.pyplot as plt
658
+ >>> fig, ax = plt.subplots(nrows=2, sharex=True, sharey=True)
659
+ >>> imgpow = librosa.display.specshow(S**2, sr=sr, y_axis='log', x_axis='time',
660
+ ... ax=ax[0])
661
+ >>> ax[0].set(title='Power spectrogram')
662
+ >>> ax[0].label_outer()
663
+ >>> imgdb = librosa.display.specshow(librosa.power_to_db(S**2, ref=np.max),
664
+ ... sr=sr, y_axis='log', x_axis='time', ax=ax[1])
665
+ >>> ax[1].set(title='Log-Power spectrogram')
666
+ >>> fig.colorbar(imgpow, ax=ax[0])
667
+ >>> fig.colorbar(imgdb, ax=ax[1], format="%+2.0f dB")
668
+ """
669
+
670
+ S = torch.asarray(S)
671
+
672
+ if amin <= 0:
673
+ raise AssertionError("amin must be strictly positive")
674
+
675
+ # if np.issubdtype(S.dtype, np.complexfloating):
676
+ # warnings.warn(
677
+ # "power_to_db was called on complex input so phase "
678
+ # "information will be discarded. To suppress this warning, "
679
+ # "call power_to_db(np.abs(D)**2) instead."
680
+ # )
681
+ # magnitude = np.abs(S)
682
+ # else:
683
+ magnitude = S
684
+
685
+ if callable(ref):
686
+ # User supplied a function to calculate reference power
687
+ ref_value = ref(magnitude)
688
+ else:
689
+ ref_value = torch.abs(ref)
690
+
691
+ log_spec = 10.0 * torch.log10(torch.maximum(torch.tensor(amin), magnitude))
692
+ log_spec -= 10.0 * torch.log10(torch.maximum(torch.tensor(amin), ref_value))
693
+
694
+ if top_db is not None:
695
+ if top_db < 0:
696
+ raise AssertionError("top_db must be non-negative")
697
+ log_spec = torch.maximum(log_spec, log_spec.max() - top_db)
698
+
699
+ return log_spec
700
+
701
+
702
+ def frames_to_samples(frames, hop_length=512, n_fft=None):
703
+ """Converts frame indices to audio sample indices.
704
+
705
+ Parameters
706
+ ----------
707
+ frames : number or np.ndarray [shape=(n,)]
708
+ frame index or vector of frame indices
709
+
710
+ hop_length : int > 0 [scalar]
711
+ number of samples between successive frames
712
+
713
+ n_fft : None or int > 0 [scalar]
714
+ Optional: length of the FFT window.
715
+ If given, time conversion will include an offset of ``n_fft // 2``
716
+ to counteract windowing effects when using a non-centered STFT.
717
+
718
+ Returns
719
+ -------
720
+ times : number or np.ndarray
721
+ time (in samples) of each given frame number::
722
+
723
+ times[i] = frames[i] * hop_length
724
+
725
+ See Also
726
+ --------
727
+ frames_to_time : convert frame indices to time values
728
+ samples_to_frames : convert sample indices to frame indices
729
+
730
+ Examples
731
+ --------
732
+ >>> y, sr = librosa.load(librosa.ex('choice'))
733
+ >>> tempo, beats = librosa.beat.beat_track(y, sr=sr)
734
+ >>> beat_samples = librosa.frames_to_samples(beats)
735
+ """
736
+
737
+ offset = 0
738
+ if n_fft is not None:
739
+ offset = int(n_fft // 2)
740
+
741
+ return (torch.asarray(frames) * hop_length + offset).to(torch.int)
mars5/utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import logging
3
+
4
+ def length_to_mask(length, offsets, max_len=None):
5
+ """
6
+ Convert tensor of lengths into a mask.
7
+
8
+ Args:
9
+ length (Tensor): a tensor of lengths, shape = (batch_size,)
10
+ offsets (Tensor): a tensor of offsets, shape = (batch_size,)
11
+ max_len (int, optional): maximum length to be considered
12
+
13
+ Returns:
14
+ mask (Tensor): a mask tensor, shape = (batch_size, max_len),
15
+ True in masked positions, False otherwise.
16
+ """
17
+ # get the batch size
18
+ batch_size = length.size(0)
19
+
20
+ # if maximum length is not provided, then compute it from the 'length' tensor.
21
+ if max_len is None:
22
+ max_len = length.max().item()
23
+
24
+ # Create a tensor of size `(batch_size, max_len)` filled with `True`.
25
+ mask = torch.ones(size=(batch_size, max_len), dtype=torch.bool, device=length.device)
26
+
27
+ # Create a tensor with consecutive numbers.
28
+ range_tensor = torch.arange(max_len, device=length.device)
29
+
30
+ # Expand the dim of 'length' tensor and 'offset' tensor to make it `(batch_size, max_len)`.
31
+ # The added dimension will be used for broadcasting.
32
+ length_exp = length.unsqueeze(-1)
33
+ offsets_exp = offsets.unsqueeze(-1)
34
+
35
+ # Create a boolean mask where `False` represents valid positions and `True` represents padding.
36
+ mask = (range_tensor < offsets_exp) | (~(range_tensor < length_exp))
37
+
38
+ return mask
39
+
40
+
41
+ def construct_padding_mask(input_tensor, pad_token):
42
+ return (input_tensor == pad_token).cumsum(dim=1) > 0
43
+
44
+
45
+ def nuke_weight_norm(module):
46
+ """
47
+ Recursively remove weight normalization from a module and its children.
48
+
49
+ Args:
50
+ module (torch.nn.Module): The module from which to remove weight normalization.
51
+ """
52
+ # Remove weight norm from current module if it exists
53
+ try:
54
+ torch.nn.utils.remove_weight_norm(module)
55
+ logging.debug(f"Removed weight norm from {module.__class__.__name__}")
56
+ except ValueError:
57
+ # Ignore if the module does not have weight norm applied.
58
+ pass
59
+
60
+ # Recursively call the function on children modules
61
+ for child in module.children():
62
+ nuke_weight_norm(child)
mars5_demo.ipynb ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "!pip install --upgrade vocos encodec librosa"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 16,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import pprint\n",
19
+ "import IPython.display as ipd\n",
20
+ "import torch\n",
21
+ "import librosa"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "# load model\n",
31
+ "mars5, config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True)"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {},
37
+ "source": [
38
+ "Now that the model is loaded, pick a reference audio to clone from. If you want to use deep clone, also specify its transcript. "
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "# download example ref audio\n",
48
+ "!wget -O example.wav https://github.com/Camb-ai/mars5-tts/raw/master/docs/assets/example_ref.wav "
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "wav, sr = librosa.load('./example.wav', \n",
58
+ " sr=mars5.sr, mono=True)\n",
59
+ "wav = torch.from_numpy(wav)\n",
60
+ "ref_transcript = \"We actually haven't managed to meet demand.\"\n",
61
+ "print(\"Reference audio:\")\n",
62
+ "ipd.display(ipd.Audio(wav.numpy(), rate=mars5.sr))\n",
63
+ "print(f\"Reference transcript: {ref_transcript}\")"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "deep_clone = True # set to False if you don't know prompt transcript or want fast inference.\n",
73
+ "# Below you can tune other inference settings, like top_k, temperature, top_p, etc...\n",
74
+ "cfg = config_class(deep_clone=deep_clone, rep_penalty_window=100,\n",
75
+ " top_k=100, temperature=0.7, freq_penalty=3)\n",
76
+ "\n",
77
+ "ar_codes, wav_out = mars5.tts(\"The quick brown rat.\", wav, \n",
78
+ " ref_transcript,\n",
79
+ " cfg=cfg)\n",
80
+ "\n",
81
+ "print('Synthesized output audio:')\n",
82
+ "ipd.Audio(wav_out.numpy(), rate=mars5.sr)"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "metadata": {},
88
+ "source": [
89
+ "You can see all the inference settings available to tune in the inference config here:"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "pprint.pprint(config_class())"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "metadata": {},
104
+ "source": [
105
+ "You can also listen to the vocoded raw coarse codes, for debugging purposes:"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "ar_wav = mars5.vocode(ar_codes.cpu()[:, None])\n",
115
+ "ipd.Audio(ar_wav.numpy(), rate=mars5.sr)"
116
+ ]
117
+ }
118
+ ],
119
+ "metadata": {
120
+ "kernelspec": {
121
+ "display_name": "matt-py311",
122
+ "language": "python",
123
+ "name": "python3"
124
+ },
125
+ "language_info": {
126
+ "codemirror_mode": {
127
+ "name": "ipython",
128
+ "version": 3
129
+ },
130
+ "file_extension": ".py",
131
+ "mimetype": "text/x-python",
132
+ "name": "python",
133
+ "nbconvert_exporter": "python",
134
+ "pygments_lexer": "ipython3",
135
+ "version": "3.11.9"
136
+ }
137
+ },
138
+ "nbformat": 4,
139
+ "nbformat_minor": 2
140
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ torchaudio
4
+ numpy
5
+ regex
6
+ librosa
7
+ vocos
8
+ encodec