haoheliu commited on
Commit
ea270ed
·
1 Parent(s): 9cfb0dd

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +9 -0
  2. LICENSE +437 -0
  3. MANIFEST.in +2 -0
  4. README.md +16 -6
  5. app.py +361 -0
  6. audioldm2/__init__.py +2 -0
  7. audioldm2/audiomae_gen/__init__.py +1 -0
  8. audioldm2/audiomae_gen/sequence_input.py +429 -0
  9. audioldm2/audiomae_gen/utils.py +27 -0
  10. audioldm2/clap/__init__.py +0 -0
  11. audioldm2/clap/open_clip/__init__.py +25 -0
  12. audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  13. audioldm2/clap/open_clip/factory.py +276 -0
  14. audioldm2/clap/open_clip/feature_fusion.py +192 -0
  15. audioldm2/clap/open_clip/htsat.py +1304 -0
  16. audioldm2/clap/open_clip/loss.py +397 -0
  17. audioldm2/clap/open_clip/model.py +931 -0
  18. audioldm2/clap/open_clip/model_configs/HTSAT-base.json +23 -0
  19. audioldm2/clap/open_clip/model_configs/HTSAT-large.json +23 -0
  20. audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +23 -0
  21. audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json +23 -0
  22. audioldm2/clap/open_clip/model_configs/PANN-10.json +23 -0
  23. audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json +23 -0
  24. audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +23 -0
  25. audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +23 -0
  26. audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json +23 -0
  27. audioldm2/clap/open_clip/model_configs/PANN-14.json +23 -0
  28. audioldm2/clap/open_clip/model_configs/PANN-6.json +23 -0
  29. audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json +22 -0
  30. audioldm2/clap/open_clip/model_configs/RN101.json +21 -0
  31. audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json +22 -0
  32. audioldm2/clap/open_clip/model_configs/RN50.json +21 -0
  33. audioldm2/clap/open_clip/model_configs/RN50x16.json +21 -0
  34. audioldm2/clap/open_clip/model_configs/RN50x4.json +21 -0
  35. audioldm2/clap/open_clip/model_configs/ViT-B-16.json +16 -0
  36. audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
  37. audioldm2/clap/open_clip/model_configs/ViT-B-32.json +16 -0
  38. audioldm2/clap/open_clip/model_configs/ViT-L-14.json +16 -0
  39. audioldm2/clap/open_clip/openai.py +156 -0
  40. audioldm2/clap/open_clip/pann_model.py +697 -0
  41. audioldm2/clap/open_clip/pretrained.py +167 -0
  42. audioldm2/clap/open_clip/timm_model.py +112 -0
  43. audioldm2/clap/open_clip/tokenizer.py +197 -0
  44. audioldm2/clap/open_clip/transform.py +45 -0
  45. audioldm2/clap/open_clip/utils.py +356 -0
  46. audioldm2/clap/training/__init__.py +0 -0
  47. audioldm2/clap/training/audioset_textmap.npy +3 -0
  48. audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz +3 -0
  49. audioldm2/clap/training/data.py +865 -0
  50. audioldm2/clap/training/params.py +563 -0
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ *.pyc
2
+ __pycache__
3
+ test.py
4
+ flagged
5
+ output
6
+ gradio_cached*
7
+ dist*
8
+ *egg-info
9
+ build*
LICENSE ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.cp
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58
+ Public License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
63
+ ("Public License"). To the extent this Public License may be
64
+ interpreted as a contract, You are granted the Licensed Rights in
65
+ consideration of Your acceptance of these terms and conditions, and the
66
+ Licensor grants You such rights in consideration of benefits the
67
+ Licensor receives from making the Licensed Material available under
68
+ these terms and conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. BY-NC-SA Compatible License means a license listed at
88
+ creativecommons.org/compatiblelicenses, approved by Creative
89
+ Commons as essentially the equivalent of this Public License.
90
+
91
+ d. Copyright and Similar Rights means copyright and/or similar rights
92
+ closely related to copyright including, without limitation,
93
+ performance, broadcast, sound recording, and Sui Generis Database
94
+ Rights, without regard to how the rights are labeled or
95
+ categorized. For purposes of this Public License, the rights
96
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
97
+ Rights.
98
+
99
+ e. Effective Technological Measures means those measures that, in the
100
+ absence of proper authority, may not be circumvented under laws
101
+ fulfilling obligations under Article 11 of the WIPO Copyright
102
+ Treaty adopted on December 20, 1996, and/or similar international
103
+ agreements.
104
+
105
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
106
+ any other exception or limitation to Copyright and Similar Rights
107
+ that applies to Your use of the Licensed Material.
108
+
109
+ g. License Elements means the license attributes listed in the name
110
+ of a Creative Commons Public License. The License Elements of this
111
+ Public License are Attribution, NonCommercial, and ShareAlike.
112
+
113
+ h. Licensed Material means the artistic or literary work, database,
114
+ or other material to which the Licensor applied this Public
115
+ License.
116
+
117
+ i. Licensed Rights means the rights granted to You subject to the
118
+ terms and conditions of this Public License, which are limited to
119
+ all Copyright and Similar Rights that apply to Your use of the
120
+ Licensed Material and that the Licensor has authority to license.
121
+
122
+ j. Licensor means the individual(s) or entity(ies) granting rights
123
+ under this Public License.
124
+
125
+ k. NonCommercial means not primarily intended for or directed towards
126
+ commercial advantage or monetary compensation. For purposes of
127
+ this Public License, the exchange of the Licensed Material for
128
+ other material subject to Copyright and Similar Rights by digital
129
+ file-sharing or similar means is NonCommercial provided there is
130
+ no payment of monetary compensation in connection with the
131
+ exchange.
132
+
133
+ l. Share means to provide material to the public by any means or
134
+ process that requires permission under the Licensed Rights, such
135
+ as reproduction, public display, public performance, distribution,
136
+ dissemination, communication, or importation, and to make material
137
+ available to the public including in ways that members of the
138
+ public may access the material from a place and at a time
139
+ individually chosen by them.
140
+
141
+ m. Sui Generis Database Rights means rights other than copyright
142
+ resulting from Directive 96/9/EC of the European Parliament and of
143
+ the Council of 11 March 1996 on the legal protection of databases,
144
+ as amended and/or succeeded, as well as other essentially
145
+ equivalent rights anywhere in the world.
146
+
147
+ n. You means the individual or entity exercising the Licensed Rights
148
+ under this Public License. Your has a corresponding meaning.
149
+
150
+
151
+ Section 2 -- Scope.
152
+
153
+ a. License grant.
154
+
155
+ 1. Subject to the terms and conditions of this Public License,
156
+ the Licensor hereby grants You a worldwide, royalty-free,
157
+ non-sublicensable, non-exclusive, irrevocable license to
158
+ exercise the Licensed Rights in the Licensed Material to:
159
+
160
+ a. reproduce and Share the Licensed Material, in whole or
161
+ in part, for NonCommercial purposes only; and
162
+
163
+ b. produce, reproduce, and Share Adapted Material for
164
+ NonCommercial purposes only.
165
+
166
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
167
+ Exceptions and Limitations apply to Your use, this Public
168
+ License does not apply, and You do not need to comply with
169
+ its terms and conditions.
170
+
171
+ 3. Term. The term of this Public License is specified in Section
172
+ 6(a).
173
+
174
+ 4. Media and formats; technical modifications allowed. The
175
+ Licensor authorizes You to exercise the Licensed Rights in
176
+ all media and formats whether now known or hereafter created,
177
+ and to make technical modifications necessary to do so. The
178
+ Licensor waives and/or agrees not to assert any right or
179
+ authority to forbid You from making technical modifications
180
+ necessary to exercise the Licensed Rights, including
181
+ technical modifications necessary to circumvent Effective
182
+ Technological Measures. For purposes of this Public License,
183
+ simply making modifications authorized by this Section 2(a)
184
+ (4) never produces Adapted Material.
185
+
186
+ 5. Downstream recipients.
187
+
188
+ a. Offer from the Licensor -- Licensed Material. Every
189
+ recipient of the Licensed Material automatically
190
+ receives an offer from the Licensor to exercise the
191
+ Licensed Rights under the terms and conditions of this
192
+ Public License.
193
+
194
+ b. Additional offer from the Licensor -- Adapted Material.
195
+ Every recipient of Adapted Material from You
196
+ automatically receives an offer from the Licensor to
197
+ exercise the Licensed Rights in the Adapted Material
198
+ under the conditions of the Adapter's License You apply.
199
+
200
+ c. No downstream restrictions. You may not offer or impose
201
+ any additional or different terms or conditions on, or
202
+ apply any Effective Technological Measures to, the
203
+ Licensed Material if doing so restricts exercise of the
204
+ Licensed Rights by any recipient of the Licensed
205
+ Material.
206
+
207
+ 6. No endorsement. Nothing in this Public License constitutes or
208
+ may be construed as permission to assert or imply that You
209
+ are, or that Your use of the Licensed Material is, connected
210
+ with, or sponsored, endorsed, or granted official status by,
211
+ the Licensor or others designated to receive attribution as
212
+ provided in Section 3(a)(1)(A)(i).
213
+
214
+ b. Other rights.
215
+
216
+ 1. Moral rights, such as the right of integrity, are not
217
+ licensed under this Public License, nor are publicity,
218
+ privacy, and/or other similar personality rights; however, to
219
+ the extent possible, the Licensor waives and/or agrees not to
220
+ assert any such rights held by the Licensor to the limited
221
+ extent necessary to allow You to exercise the Licensed
222
+ Rights, but not otherwise.
223
+
224
+ 2. Patent and trademark rights are not licensed under this
225
+ Public License.
226
+
227
+ 3. To the extent possible, the Licensor waives any right to
228
+ collect royalties from You for the exercise of the Licensed
229
+ Rights, whether directly or through a collecting society
230
+ under any voluntary or waivable statutory or compulsory
231
+ licensing scheme. In all other cases the Licensor expressly
232
+ reserves any right to collect such royalties, including when
233
+ the Licensed Material is used other than for NonCommercial
234
+ purposes.
235
+
236
+
237
+ Section 3 -- License Conditions.
238
+
239
+ Your exercise of the Licensed Rights is expressly made subject to the
240
+ following conditions.
241
+
242
+ a. Attribution.
243
+
244
+ 1. If You Share the Licensed Material (including in modified
245
+ form), You must:
246
+
247
+ a. retain the following if it is supplied by the Licensor
248
+ with the Licensed Material:
249
+
250
+ i. identification of the creator(s) of the Licensed
251
+ Material and any others designated to receive
252
+ attribution, in any reasonable manner requested by
253
+ the Licensor (including by pseudonym if
254
+ designated);
255
+
256
+ ii. a copyright notice;
257
+
258
+ iii. a notice that refers to this Public License;
259
+
260
+ iv. a notice that refers to the disclaimer of
261
+ warranties;
262
+
263
+ v. a URI or hyperlink to the Licensed Material to the
264
+ extent reasonably practicable;
265
+
266
+ b. indicate if You modified the Licensed Material and
267
+ retain an indication of any previous modifications; and
268
+
269
+ c. indicate the Licensed Material is licensed under this
270
+ Public License, and include the text of, or the URI or
271
+ hyperlink to, this Public License.
272
+
273
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
274
+ reasonable manner based on the medium, means, and context in
275
+ which You Share the Licensed Material. For example, it may be
276
+ reasonable to satisfy the conditions by providing a URI or
277
+ hyperlink to a resource that includes the required
278
+ information.
279
+ 3. If requested by the Licensor, You must remove any of the
280
+ information required by Section 3(a)(1)(A) to the extent
281
+ reasonably practicable.
282
+
283
+ b. ShareAlike.
284
+
285
+ In addition to the conditions in Section 3(a), if You Share
286
+ Adapted Material You produce, the following conditions also apply.
287
+
288
+ 1. The Adapter's License You apply must be a Creative Commons
289
+ license with the same License Elements, this version or
290
+ later, or a BY-NC-SA Compatible License.
291
+
292
+ 2. You must include the text of, or the URI or hyperlink to, the
293
+ Adapter's License You apply. You may satisfy this condition
294
+ in any reasonable manner based on the medium, means, and
295
+ context in which You Share Adapted Material.
296
+
297
+ 3. You may not offer or impose any additional or different terms
298
+ or conditions on, or apply any Effective Technological
299
+ Measures to, Adapted Material that restrict exercise of the
300
+ rights granted under the Adapter's License You apply.
301
+
302
+
303
+ Section 4 -- Sui Generis Database Rights.
304
+
305
+ Where the Licensed Rights include Sui Generis Database Rights that
306
+ apply to Your use of the Licensed Material:
307
+
308
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309
+ to extract, reuse, reproduce, and Share all or a substantial
310
+ portion of the contents of the database for NonCommercial purposes
311
+ only;
312
+
313
+ b. if You include all or a substantial portion of the database
314
+ contents in a database in which You have Sui Generis Database
315
+ Rights, then the database in which You have Sui Generis Database
316
+ Rights (but not its individual contents) is Adapted Material,
317
+ including for purposes of Section 3(b); and
318
+
319
+ c. You must comply with the conditions in Section 3(a) if You Share
320
+ all or a substantial portion of the contents of the database.
321
+
322
+ For the avoidance of doubt, this Section 4 supplements and does not
323
+ replace Your obligations under this Public License where the Licensed
324
+ Rights include other Copyright and Similar Rights.
325
+
326
+
327
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328
+
329
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339
+
340
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349
+
350
+ c. The disclaimer of warranties and limitation of liability provided
351
+ above shall be interpreted in a manner that, to the extent
352
+ possible, most closely approximates an absolute disclaimer and
353
+ waiver of all liability.
354
+
355
+
356
+ Section 6 -- Term and Termination.
357
+
358
+ a. This Public License applies for the term of the Copyright and
359
+ Similar Rights licensed here. However, if You fail to comply with
360
+ this Public License, then Your rights under this Public License
361
+ terminate automatically.
362
+
363
+ b. Where Your right to use the Licensed Material has terminated under
364
+ Section 6(a), it reinstates:
365
+
366
+ 1. automatically as of the date the violation is cured, provided
367
+ it is cured within 30 days of Your discovery of the
368
+ violation; or
369
+
370
+ 2. upon express reinstatement by the Licensor.
371
+
372
+ For the avoidance of doubt, this Section 6(b) does not affect any
373
+ right the Licensor may have to seek remedies for Your violations
374
+ of this Public License.
375
+
376
+ c. For the avoidance of doubt, the Licensor may also offer the
377
+ Licensed Material under separate terms or conditions or stop
378
+ distributing the Licensed Material at any time; however, doing so
379
+ will not terminate this Public License.
380
+
381
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382
+ License.
383
+
384
+
385
+ Section 7 -- Other Terms and Conditions.
386
+
387
+ a. The Licensor shall not be bound by any additional or different
388
+ terms or conditions communicated by You unless expressly agreed.
389
+
390
+ b. Any arrangements, understandings, or agreements regarding the
391
+ Licensed Material not stated herein are separate from and
392
+ independent of the terms and conditions of this Public License.
393
+
394
+
395
+ Section 8 -- Interpretation.
396
+
397
+ a. For the avoidance of doubt, this Public License does not, and
398
+ shall not be interpreted to, reduce, limit, restrict, or impose
399
+ conditions on any use of the Licensed Material that could lawfully
400
+ be made without permission under this Public License.
401
+
402
+ b. To the extent possible, if any provision of this Public License is
403
+ deemed unenforceable, it shall be automatically reformed to the
404
+ minimum extent necessary to make it enforceable. If the provision
405
+ cannot be reformed, it shall be severed from this Public License
406
+ without affecting the enforceability of the remaining terms and
407
+ conditions.
408
+
409
+ c. No term or condition of this Public License will be waived and no
410
+ failure to comply consented to unless expressly agreed to by the
411
+ Licensor.
412
+
413
+ d. Nothing in this Public License constitutes or may be interpreted
414
+ as a limitation upon, or waiver of, any privileges and immunities
415
+ that apply to the Licensor or You, including from the legal
416
+ processes of any jurisdiction or authority.
417
+
418
+ =======================================================================
419
+
420
+ Creative Commons is not a party to its public
421
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
422
+ its public licenses to material it publishes and in those instances
423
+ will be considered the “Licensor.” The text of the Creative Commons
424
+ public licenses is dedicated to the public domain under the CC0 Public
425
+ Domain Dedication. Except for the limited purpose of indicating that
426
+ material is shared under a Creative Commons public license or as
427
+ otherwise permitted by the Creative Commons policies published at
428
+ creativecommons.org/policies, Creative Commons does not authorize the
429
+ use of the trademark "Creative Commons" or any other trademark or logo
430
+ of Creative Commons without its prior written consent including,
431
+ without limitation, in connection with any unauthorized modifications
432
+ to any of its public licenses or any other arrangements,
433
+ understandings, or agreements concerning use of licensed material. For
434
+ the avoidance of doubt, this paragraph does not form part of the
435
+ public licenses.
436
+
437
+ Creative Commons may be contacted at creativecommons.org.
MANIFEST.in ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ include *.py LICENSE README.md
2
+ recursive-include audioldm2 *.txt *.py *.gz *.npy *.json
README.md CHANGED
@@ -1,13 +1,23 @@
1
  ---
2
- title: Audioldm2 Text2audio Text2music
3
- emoji: 👁
4
- colorFrom: gray
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.39.0
8
  app_file: app.py
9
  pinned: false
10
- license: cc-by-nc-nd-4.0
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: AudioLDM2 Text2Audio Text2Music Generation
3
+ emoji: 🔊
4
+ colorFrom: indigo
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.27.0
8
  app_file: app.py
9
  pinned: false
10
+ license: bigscience-openrail-m
11
+ duplicated_from: haoheliu/audioldm2-text2audio-text2music
12
+
13
  ---
14
 
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
16
+
17
+ ## Reference
18
+ Part of the code from this repo is borrowed from the following repos. We would like to thank the authors of them for their contribution.
19
+
20
+ > https://github.com/LAION-AI/CLAP
21
+ > https://github.com/CompVis/stable-diffusion
22
+ > https://github.com/v-iashin/SpecVQGAN
23
+ > https://github.com/toshas/torch-fidelity
app.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+ import torch
3
+ import os
4
+
5
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
6
+
7
+ import gradio as gr
8
+ from audioldm2 import text_to_audio, build_model
9
+ from share_btn import community_icon_html, loading_icon_html, share_js
10
+
11
+ model_id = "haoheliu/audioldm2-full"
12
+ hf_hub_download(repo_id="haoheliu/audioldm2-full", filename="audioldm2-full.pth")
13
+
14
+ audioldm = None
15
+ current_model_name = None
16
+
17
+ def text2audio(
18
+ text,
19
+ guidance_scale,
20
+ random_seed,
21
+ n_candidates,
22
+ model_name="audioldm2-full",
23
+ ):
24
+ global audioldm, current_model_name
25
+ torch.set_float32_matmul_precision("high")
26
+
27
+ if audioldm is None or model_name != current_model_name:
28
+ audioldm = build_model(model_name=model_name)
29
+ current_model_name = model_name
30
+ audioldm = torch.compile(audioldm)
31
+
32
+ # print(text, length, guidance_scale)
33
+ waveform = text_to_audio(
34
+ latent_diffusion=audioldm,
35
+ text=text,
36
+ seed=random_seed,
37
+ duration=10,
38
+ guidance_scale=guidance_scale,
39
+ n_candidate_gen_per_text=int(n_candidates),
40
+ ) # [bs, 1, samples]
41
+ waveform = [
42
+ gr.make_waveform((16000, wave[0]), bg_image="bg.png") for wave in waveform
43
+ ]
44
+ # waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))]
45
+ if len(waveform) == 1:
46
+ waveform = waveform[0]
47
+ return waveform
48
+
49
+ css = """
50
+ a {
51
+ color: inherit;
52
+ text-decoration: underline;
53
+ }
54
+ .gradio-container {
55
+ font-family: 'IBM Plex Sans', sans-serif;
56
+ }
57
+ .gr-button {
58
+ color: white;
59
+ border-color: #000000;
60
+ background: #000000;
61
+ }
62
+ input[type='range'] {
63
+ accent-color: #000000;
64
+ }
65
+ .dark input[type='range'] {
66
+ accent-color: #dfdfdf;
67
+ }
68
+ .container {
69
+ max-width: 730px;
70
+ margin: auto;
71
+ padding-top: 1.5rem;
72
+ }
73
+ #gallery {
74
+ min-height: 22rem;
75
+ margin-bottom: 15px;
76
+ margin-left: auto;
77
+ margin-right: auto;
78
+ border-bottom-right-radius: .5rem !important;
79
+ border-bottom-left-radius: .5rem !important;
80
+ }
81
+ #gallery>div>.h-full {
82
+ min-height: 20rem;
83
+ }
84
+ .details:hover {
85
+ text-decoration: underline;
86
+ }
87
+ .gr-button {
88
+ white-space: nowrap;
89
+ }
90
+ .gr-button:focus {
91
+ border-color: rgb(147 197 253 / var(--tw-border-opacity));
92
+ outline: none;
93
+ box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
94
+ --tw-border-opacity: 1;
95
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
96
+ --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
97
+ --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
98
+ --tw-ring-opacity: .5;
99
+ }
100
+ #advanced-btn {
101
+ font-size: .7rem !important;
102
+ line-height: 19px;
103
+ margin-top: 12px;
104
+ margin-bottom: 12px;
105
+ padding: 2px 8px;
106
+ border-radius: 14px !important;
107
+ }
108
+ #advanced-options {
109
+ margin-bottom: 20px;
110
+ }
111
+ .footer {
112
+ margin-bottom: 45px;
113
+ margin-top: 35px;
114
+ text-align: center;
115
+ border-bottom: 1px solid #e5e5e5;
116
+ }
117
+ .footer>p {
118
+ font-size: .8rem;
119
+ display: inline-block;
120
+ padding: 0 10px;
121
+ transform: translateY(10px);
122
+ background: white;
123
+ }
124
+ .dark .footer {
125
+ border-color: #303030;
126
+ }
127
+ .dark .footer>p {
128
+ background: #0b0f19;
129
+ }
130
+ .acknowledgments h4{
131
+ margin: 1.25em 0 .25em 0;
132
+ font-weight: bold;
133
+ font-size: 115%;
134
+ }
135
+ #container-advanced-btns{
136
+ display: flex;
137
+ flex-wrap: wrap;
138
+ justify-content: space-between;
139
+ align-items: center;
140
+ }
141
+ .animate-spin {
142
+ animation: spin 1s linear infinite;
143
+ }
144
+ @keyframes spin {
145
+ from {
146
+ transform: rotate(0deg);
147
+ }
148
+ to {
149
+ transform: rotate(360deg);
150
+ }
151
+ }
152
+ #share-btn-container {
153
+ display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
154
+ margin-top: 10px;
155
+ margin-left: auto;
156
+ }
157
+ #share-btn {
158
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
159
+ }
160
+ #share-btn * {
161
+ all: unset;
162
+ }
163
+ #share-btn-container div:nth-child(-n+2){
164
+ width: auto !important;
165
+ min-height: 0px !important;
166
+ }
167
+ #share-btn-container .wrap {
168
+ display: none !important;
169
+ }
170
+ .gr-form{
171
+ flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
172
+ }
173
+ #prompt-container{
174
+ gap: 0;
175
+ }
176
+ #generated_id{
177
+ min-height: 700px
178
+ }
179
+ #setting_id{
180
+ margin-bottom: 12px;
181
+ text-align: center;
182
+ font-weight: 900;
183
+ }
184
+ """
185
+ iface = gr.Blocks(css=css)
186
+
187
+ with iface:
188
+ gr.HTML(
189
+ """
190
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
191
+ <div
192
+ style="
193
+ display: inline-flex;
194
+ align-items: center;
195
+ gap: 0.8rem;
196
+ font-size: 1.75rem;
197
+ "
198
+ >
199
+ <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
200
+ AudioLDM 2: A General Framework for Audio, Music, and Speech Generation
201
+ </h1>
202
+ </div>
203
+ <p style="margin-bottom: 10px; font-size: 94%">
204
+ <a href="https://arxiv.org/abs/2301.12503">[Paper]</a> <a href="https://audioldm.github.io/">[Project page]</a>
205
+ </p>
206
+ </div>
207
+ """
208
+ )
209
+ gr.HTML(
210
+ """
211
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
212
+ AudioLDM 2: A General Framework for Audio, Music, and Speech Generation
213
+ </h1>
214
+ <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
215
+ <br/>
216
+ <a href="https://huggingface.co/spaces/haoheliu/audioldm2-text2audio-text2music?duplicate=true">
217
+ <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
218
+ <p/>
219
+ """
220
+ )
221
+ with gr.Group():
222
+ with gr.Box():
223
+ ############# Input
224
+ textbox = gr.Textbox(
225
+ value="A forest of wind chimes singing a soothing melody in the breeze.",
226
+ max_lines=1,
227
+ label="Input your text here. Your text is important for the audio quality. Please ensure it is descriptive by using more adjectives.",
228
+ elem_id="prompt-in",
229
+ )
230
+
231
+ with gr.Accordion("Click to modify detailed configurations", open=False):
232
+ seed = gr.Number(
233
+ value=45,
234
+ label="Change this value (any integer number) will lead to a different generation result.",
235
+ )
236
+ # duration = gr.Slider(
237
+ # 10, 10, value=10, step=2.5, label="Duration (seconds)"
238
+ # )
239
+ guidance_scale = gr.Slider(
240
+ 0,
241
+ 6,
242
+ value=3.5,
243
+ step=0.5,
244
+ label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
245
+ )
246
+ n_candidates = gr.Slider(
247
+ 1,
248
+ 3,
249
+ value=3,
250
+ step=1,
251
+ label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
252
+ )
253
+ # model_name = gr.Dropdown(
254
+ # ["audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full","audioldm-s-full-v2", "audioldm-s-full", "audioldm-l-full"], value="audioldm-m-full", label="Choose the model to use. audioldm-m-text-ft and audioldm-s-text-ft are recommanded. -s- means small, -m- means medium and -l- means large",
255
+ # )
256
+ ############# Output
257
+ # outputs=gr.Audio(label="Output", type="numpy")
258
+ outputs = gr.Video(label="Output", elem_id="output-video")
259
+
260
+ # with gr.Group(elem_id="container-advanced-btns"):
261
+ # # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
262
+ # with gr.Group(elem_id="share-btn-container"):
263
+ # community_icon = gr.HTML(community_icon_html, visible=False)
264
+ # loading_icon = gr.HTML(loading_icon_html, visible=False)
265
+ # share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
266
+ # outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")]
267
+ btn = gr.Button("Submit").style(full_width=True)
268
+
269
+ with gr.Group(elem_id="share-btn-container", visible=False):
270
+ community_icon = gr.HTML(community_icon_html)
271
+ loading_icon = gr.HTML(loading_icon_html)
272
+ share_button = gr.Button("Share to community", elem_id="share-btn")
273
+
274
+ # btn.click(text2audio, inputs=[
275
+ # textbox, duration, guidance_scale, seed, n_candidates, model_name], outputs=[outputs])
276
+ btn.click(
277
+ text2audio,
278
+ inputs=[textbox, guidance_scale, seed, n_candidates],
279
+ outputs=[outputs],
280
+ )
281
+
282
+ share_button.click(None, [], [], _js=share_js)
283
+ gr.HTML(
284
+ """
285
+ <div class="footer" style="text-align: center; max-width: 700px; margin: 0 auto;">
286
+ <p>Follow the latest update of AudioLDM on our<a href="https://github.com/haoheliu/AudioLDM" style="text-decoration: underline;" target="_blank"> Github repo</a>
287
+ </p>
288
+ <br>
289
+ <p>Model by <a href="https://twitter.com/LiuHaohe" style="text-decoration: underline;" target="_blank">Haohe Liu</a></p>
290
+ <br>
291
+ </div>
292
+ """
293
+ )
294
+ gr.Examples(
295
+ [
296
+ [
297
+ "An excited crowd cheering at a sports game.",
298
+ 3.5,
299
+ 45,
300
+ 3,
301
+ "audioldm2-full",
302
+ ],
303
+ [
304
+ "A cat is meowing for attention.",
305
+ 3.5,
306
+ 45,
307
+ 3,
308
+ "audioldm2-full",
309
+ ],
310
+ [
311
+ "Birds singing sweetly in a blooming garden.",
312
+ 3.5,
313
+ 45,
314
+ 3,
315
+ "audioldm2-full",
316
+ ],
317
+ [
318
+ "A modern synthesizer creating futuristic soundscapes.",
319
+ 3.5,
320
+ 45,
321
+ 3,
322
+ "audioldm2-full",
323
+ ],
324
+ [
325
+ "The vibrant beat of Brazilian samba drums.",
326
+ 3.5,
327
+ 45,
328
+ 3,
329
+ "audioldm2-full",
330
+ ],
331
+ ],
332
+ fn=text2audio,
333
+ # inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
334
+ inputs=[textbox, guidance_scale, seed, n_candidates],
335
+ outputs=[outputs],
336
+ cache_examples=True,
337
+ )
338
+ gr.HTML(
339
+ """
340
+ <div class="acknowledgements">
341
+ <p>Essential Tricks for Enhancing the Quality of Your Generated Audio</p>
342
+ <p>1. Try to use more adjectives to describe your sound. For example: "A man is speaking clearly and slowly in a large room" is better than "A man is speaking". This can make sure AudioLDM understands what you want.</p>
343
+ <p>2. Try to use different random seeds, which can affect the generation quality significantly sometimes.</p>
344
+ <p>3. It's better to use general terms like 'man' or 'woman' instead of specific names for individuals or abstract objects that humans may not be familiar with, such as 'mummy'.</p>
345
+ </div>
346
+ """
347
+ )
348
+
349
+ with gr.Accordion("Additional information", open=False):
350
+ gr.HTML(
351
+ """
352
+ <div class="acknowledgments">
353
+ <p> We build the model with data from <a href="http://research.google.com/audioset/">AudioSet</a>, <a href="https://freesound.org/">Freesound</a> and <a href="https://sound-effects.bbcrewind.co.uk/">BBC Sound Effect library</a>. We share this demo based on the <a href="https://assets.publishing.service.gov.uk/government/uploads/system/uploads/attachment_data/file/375954/Research.pdf">UK copyright exception</a> of data for academic research. </p>
354
+ </div>
355
+ """
356
+ )
357
+ # <p>This demo is strictly for research demo purpose only. For commercial use please <a href="haoheliu@gmail.com">contact us</a>.</p>
358
+
359
+ iface.queue(concurrency_count=3)
360
+ # iface.launch(debug=True)
361
+ iface.launch(debug=True, share=True)
audioldm2/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .utils import seed_everything, save_wave, get_time, get_duration, read_list
2
+ from .pipeline import *
audioldm2/audiomae_gen/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sequence_input import Sequence2AudioMAE
audioldm2/audiomae_gen/sequence_input.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from audioldm2.latent_diffusion.util import (
4
+ instantiate_from_config,
5
+ )
6
+
7
+ # from latent_diffusion.modules.encoders.modules import CLAPAudioEmbeddingClassifierFreev2
8
+ from transformers import GPT2Config, GPT2Model
9
+ import torch.optim.lr_scheduler as lr_scheduler
10
+
11
+ class Sequence2AudioMAE(nn.Module):
12
+ def __init__(
13
+ self,
14
+ base_learning_rate,
15
+ sequence_gen_length,
16
+ sequence_input_key,
17
+ sequence_input_embed_dim,
18
+ cond_stage_config,
19
+ optimizer_type="AdamW",
20
+ use_warmup=True,
21
+ use_ar_gen_loss=False,
22
+ use_audiomae_linear=False,
23
+ target_tokens_mask_ratio=0.0,
24
+ random_mask_ratio=False,
25
+ **kwargs
26
+ ):
27
+ super().__init__()
28
+ assert use_audiomae_linear == False
29
+ self.random_mask_ratio = random_mask_ratio
30
+ self.learning_rate = base_learning_rate
31
+ self.cond_stage_config = cond_stage_config
32
+ self.use_audiomae_linear = use_audiomae_linear
33
+ self.optimizer_type = optimizer_type
34
+ self.use_warmup = use_warmup
35
+ self.use_ar_gen_loss = use_ar_gen_loss
36
+ # Even though the LDM can be conditioned on mutliple pooling rate
37
+ # Our model always predict the higest pooling rate
38
+
39
+ # self.time_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["time_pooling_factors"])
40
+ # self.freq_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["freq_pooling_factors"])
41
+ # self.mae_token_num = int(512/(self.time_pool*self.freq_pool))
42
+
43
+ self.mae_token_num = sequence_gen_length
44
+ self.sequence_input_key = sequence_input_key
45
+ self.sequence_input_embed_dim = sequence_input_embed_dim
46
+ self.target_tokens_mask_ratio = target_tokens_mask_ratio
47
+
48
+ self.start_of_sequence_tokens = nn.Embedding(32, 768)
49
+ self.end_of_sequence_tokens = nn.Embedding(32, 768)
50
+
51
+ self.input_sequence_embed_linear = nn.ModuleList([])
52
+ self.initial_learning_rate = None
53
+
54
+ for dim in self.sequence_input_embed_dim:
55
+ self.input_sequence_embed_linear.append(nn.Linear(dim, 768))
56
+
57
+ self.cond_stage_models = nn.ModuleList([])
58
+ self.instantiate_cond_stage(cond_stage_config)
59
+ self.initialize_param_check_toolkit()
60
+
61
+ # configuration = GPT2Config(n_layer=1) # TODO
62
+ # self.model=GPT2Model(configuration)
63
+ ###################
64
+ # self.model=nn.Linear(768,768, bias=False) # TODO change the model
65
+ # with torch.no_grad():
66
+ # self.model.weight.copy_(torch.eye(768))
67
+ ###################
68
+ self.model = GPT2Model(GPT2Config.from_pretrained("gpt2"))
69
+ ###################
70
+ # self.model = nn.LSTM(input_size=768, hidden_size=768, num_layers=1,bias=False) # TODO
71
+
72
+ # self.loss_fn = nn.MSELoss()
73
+ self.loss_fn = nn.L1Loss()
74
+
75
+ self.logger_save_dir = None
76
+ self.logger_exp_name = None
77
+ self.logger_exp_group_name = None
78
+ self.logger_version = None
79
+
80
+ def set_log_dir(self, save_dir, exp_group_name, exp_name):
81
+ self.logger_save_dir = save_dir
82
+ self.logger_exp_group_name = exp_group_name
83
+ self.logger_exp_name = exp_name
84
+
85
+ def cfg_uncond(self, batch_size):
86
+ unconditional_conditioning = {}
87
+ for key in self.cond_stage_model_metadata:
88
+ model_idx = self.cond_stage_model_metadata[key]["model_idx"]
89
+ unconditional_conditioning[key] = self.cond_stage_models[
90
+ model_idx
91
+ ].get_unconditional_condition(batch_size)
92
+ assert (
93
+ "crossattn_audiomae_pooled" in unconditional_conditioning.keys()
94
+ ), "The module is not initialized with AudioMAE"
95
+ unconditional_conditioning[
96
+ "crossattn_clap_to_audiomae_feature"
97
+ ] = unconditional_conditioning["crossattn_audiomae_pooled"]
98
+ return unconditional_conditioning
99
+
100
+ def configure_optimizers(self):
101
+ lr = float(self.learning_rate)
102
+ # params = list(self.model.parameters()) + list(self.input_sequence_embed_linear.parameters())
103
+ params = list(self.parameters())
104
+
105
+ # opt = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.98), eps=1e-9)
106
+ opt = eval(self.optimizer_type)(params, lr=lr)
107
+ scheduler = lr_scheduler.StepLR(opt, step_size=10, gamma=0.8)
108
+ return [opt], [scheduler]
109
+
110
+ def add_sos_eos_tokens(self, _id, sequence, attn_mask):
111
+ batchsize = sequence.size(0)
112
+
113
+ new_attn_mask_step = torch.ones((batchsize, 1)).to(sequence.device)
114
+ key_id = torch.tensor([_id]).to(sequence.device)
115
+
116
+ # Add two more steps to attn mask
117
+ new_attn_mask = torch.cat(
118
+ [new_attn_mask_step, attn_mask, new_attn_mask_step], dim=1
119
+ )
120
+
121
+ # Add two more tokens in the sequence
122
+ sos_token = self.start_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
123
+ eos_token = self.end_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
124
+ new_sequence = torch.cat([sos_token, sequence, eos_token], dim=1)
125
+ return new_sequence, new_attn_mask
126
+
127
+ def truncate_sequence_and_mask(self, sequence, mask, max_len=512):
128
+ if sequence.size(1) > max_len:
129
+ print(
130
+ "The input sequence length to GPT-2 model is too long:",
131
+ sequence.size(1),
132
+ )
133
+ return sequence[:, :max_len], mask[:, :max_len]
134
+ else:
135
+ return sequence, mask
136
+
137
+ def get_input_sequence_and_mask(self, cond_dict):
138
+ input_embeds = None
139
+ input_embeds_attn_mask = None
140
+ for _id, sequence_key in enumerate(self.sequence_input_key):
141
+ assert sequence_key in cond_dict.keys(), (
142
+ "Invalid sequence key %s" % sequence_key
143
+ )
144
+ cond_embed = cond_dict[sequence_key]
145
+ if isinstance(cond_embed, list):
146
+ assert (
147
+ len(cond_embed) == 2
148
+ ), "The crossattn returned list should have length 2, including embed and attn_mask"
149
+ item_input_embeds, item_attn_mask = cond_embed
150
+
151
+ item_input_embeds = self.input_sequence_embed_linear[_id](
152
+ item_input_embeds
153
+ )
154
+
155
+ item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
156
+ _id, item_input_embeds, item_attn_mask
157
+ )
158
+
159
+ if input_embeds is None and input_embeds_attn_mask is None:
160
+ input_embeds, input_embeds_attn_mask = (
161
+ item_input_embeds,
162
+ item_attn_mask,
163
+ )
164
+ else:
165
+ input_embeds = torch.cat(
166
+ [input_embeds, item_input_embeds], dim=1
167
+ ) # The 1-st dimension is time steps
168
+ input_embeds_attn_mask = torch.cat(
169
+ [input_embeds_attn_mask, item_attn_mask], dim=1
170
+ ) # The 1-st dimension is time steps
171
+ else:
172
+ assert isinstance(cond_embed, torch.Tensor)
173
+ cond_embed = self.input_sequence_embed_linear[_id](cond_embed)
174
+ attn_mask = torch.ones((cond_embed.size(0), cond_embed.size(1))).to(
175
+ cond_embed.device
176
+ )
177
+
178
+ item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
179
+ _id, cond_embed, attn_mask
180
+ )
181
+
182
+ if input_embeds is None and input_embeds_attn_mask is None:
183
+ input_embeds, input_embeds_attn_mask = (
184
+ item_input_embeds,
185
+ item_attn_mask,
186
+ )
187
+ else:
188
+ input_embeds, input_embeds_attn_mask = torch.cat(
189
+ [input_embeds, item_input_embeds], dim=1
190
+ ), torch.cat([input_embeds_attn_mask, item_attn_mask], dim=1)
191
+
192
+ assert input_embeds is not None and input_embeds_attn_mask is not None
193
+
194
+ input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask(
195
+ input_embeds, input_embeds_attn_mask, int(1024 - self.mae_token_num)
196
+ )
197
+ cond_sequence_end_time_idx = input_embeds.size(
198
+ 1
199
+ ) # The index that we start to collect the output embeds
200
+
201
+ return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx
202
+
203
+ def warmup_step(self):
204
+ if self.initial_learning_rate is None:
205
+ self.initial_learning_rate = float(self.learning_rate)
206
+
207
+ # Only the first parameter group
208
+ if self.global_step <= 1000:
209
+ if self.global_step == 0:
210
+ print(
211
+ "Warming up learning rate start with %s"
212
+ % self.initial_learning_rate
213
+ )
214
+ self.trainer.optimizers[0].param_groups[0]["lr"] = (
215
+ self.global_step / 1000
216
+ ) * self.initial_learning_rate
217
+ else:
218
+ # TODO set learning rate here
219
+ self.trainer.optimizers[0].param_groups[0][
220
+ "lr"
221
+ ] = self.initial_learning_rate
222
+
223
+ def mask_target_sequence(self, target_embeds, target_embeds_attn_mask):
224
+ time_seq_mask = None
225
+ if self.target_tokens_mask_ratio > 1e-4:
226
+ batchsize, time_seq_len, embed_dim = target_embeds.size()
227
+ _, time_seq_len = target_embeds_attn_mask.size()
228
+ # Generate random mask
229
+ if self.random_mask_ratio:
230
+ mask_ratio = torch.rand(1).item() * self.target_tokens_mask_ratio
231
+ else:
232
+ mask_ratio = self.target_tokens_mask_ratio
233
+
234
+ time_seq_mask = (torch.rand((batchsize, time_seq_len)) > mask_ratio).to(
235
+ target_embeds.device
236
+ )
237
+ # Mask the target embedding
238
+ target_embeds = target_embeds * time_seq_mask.unsqueeze(-1)
239
+ target_embeds_attn_mask = target_embeds_attn_mask * time_seq_mask
240
+ return target_embeds, target_embeds_attn_mask, time_seq_mask
241
+
242
+ def generate_partial(self, batch, cond_dict=None, no_grad=False):
243
+ if cond_dict is None:
244
+ cond_dict = self.get_input(batch)
245
+
246
+ print("Generate partially prompted audio with in-context learning")
247
+ # self.model.train()
248
+ # assert self.model.training==True
249
+
250
+ target_embeds, target_embeds_attn_mask = (
251
+ cond_dict["crossattn_audiomae_pooled"][0],
252
+ cond_dict["crossattn_audiomae_pooled"][1],
253
+ )
254
+
255
+ target_time_steps = target_embeds.size(1)
256
+
257
+ (
258
+ input_embeds,
259
+ input_embeds_attn_mask,
260
+ cond_sequence_end_time_idx,
261
+ ) = self.get_input_sequence_and_mask(cond_dict)
262
+
263
+ model_input = torch.cat(
264
+ [input_embeds, target_embeds[:, : target_time_steps // 4, :]], dim=1
265
+ )
266
+ model_input_mask = torch.cat(
267
+ [
268
+ input_embeds_attn_mask,
269
+ target_embeds_attn_mask[:, : target_time_steps // 4],
270
+ ],
271
+ dim=1,
272
+ )
273
+
274
+ steps = self.mae_token_num
275
+
276
+ for _ in range(3 * steps // 4):
277
+ output = self.model(
278
+ inputs_embeds=model_input, attention_mask=model_input_mask
279
+ )["last_hidden_state"]
280
+ # Update the model input
281
+ model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
282
+ # Update the attention mask
283
+ attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
284
+ model_input.device
285
+ )
286
+ model_input_mask = torch.cat(
287
+ [model_input_mask, attention_mask_new_step], dim=1
288
+ )
289
+
290
+ output = model_input[:, cond_sequence_end_time_idx:]
291
+
292
+ return output, cond_dict
293
+
294
+ def generate(self, batch, cond_dict=None, no_grad=False):
295
+ if cond_dict is None:
296
+ cond_dict = self.get_input(batch)
297
+
298
+ # self.model.train()
299
+ # print("!!!!!!!!!!!!!train")
300
+
301
+ (
302
+ input_embeds,
303
+ input_embeds_attn_mask,
304
+ cond_sequence_end_time_idx,
305
+ ) = self.get_input_sequence_and_mask(cond_dict)
306
+ model_input = input_embeds
307
+ model_input_mask = input_embeds_attn_mask
308
+
309
+ steps = self.mae_token_num
310
+
311
+ for _ in range(steps):
312
+ output = self.model(
313
+ inputs_embeds=model_input, attention_mask=model_input_mask
314
+ )["last_hidden_state"]
315
+ # Update the model input
316
+ model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
317
+ # Update the attention mask
318
+ attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
319
+ model_input.device
320
+ )
321
+ model_input_mask = torch.cat(
322
+ [model_input_mask, attention_mask_new_step], dim=1
323
+ )
324
+
325
+ return model_input[:, cond_sequence_end_time_idx:], cond_dict
326
+
327
+ def get_input_item(self, batch, k):
328
+ fname, text, waveform, stft, fbank = (
329
+ batch["fname"],
330
+ batch["text"],
331
+ batch["waveform"],
332
+ batch["stft"],
333
+ batch["log_mel_spec"],
334
+ )
335
+ ret = {}
336
+
337
+ ret["fbank"] = (
338
+ fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
339
+ )
340
+ ret["stft"] = stft.to(memory_format=torch.contiguous_format).float()
341
+ # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float()
342
+ ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
343
+ ret["text"] = list(text)
344
+ ret["fname"] = fname
345
+
346
+ for key in batch.keys():
347
+ if key not in ret.keys():
348
+ ret[key] = batch[key]
349
+
350
+ return ret[k]
351
+
352
+ def get_input(self, batch):
353
+ cond_dict = {}
354
+ if len(self.cond_stage_model_metadata.keys()) > 0:
355
+ unconditional_cfg = False
356
+
357
+ for cond_model_key in self.cond_stage_model_metadata.keys():
358
+ cond_stage_key = self.cond_stage_model_metadata[cond_model_key][
359
+ "cond_stage_key"
360
+ ]
361
+
362
+ # if(not self.training):
363
+ # if(isinstance(self.cond_stage_models[self.cond_stage_model_metadata[cond_model_key]["model_idx"]], CLAPAudioEmbeddingClassifierFreev2)):
364
+ # assert cond_stage_key == "text" # CLAP model should use text for evaluation
365
+
366
+ # The original data for conditioning
367
+ xc = self.get_input_item(batch, cond_stage_key)
368
+ if type(xc) == torch.Tensor:
369
+ xc = xc.to(self.device)
370
+
371
+ c = self.get_learned_conditioning(
372
+ xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
373
+ )
374
+ cond_dict[cond_model_key] = c
375
+
376
+ return cond_dict
377
+
378
+ def instantiate_cond_stage(self, config):
379
+ self.cond_stage_model_metadata = {}
380
+
381
+ for i, cond_model_key in enumerate(config.keys()):
382
+ model = instantiate_from_config(config[cond_model_key])
383
+ self.cond_stage_models.append(model)
384
+ self.cond_stage_model_metadata[cond_model_key] = {
385
+ "model_idx": i,
386
+ "cond_stage_key": config[cond_model_key]["cond_stage_key"],
387
+ "conditioning_key": config[cond_model_key]["conditioning_key"],
388
+ }
389
+
390
+ def get_learned_conditioning(self, c, key, unconditional_cfg):
391
+ assert key in self.cond_stage_model_metadata.keys()
392
+
393
+ # Classifier-free guidance
394
+ if not unconditional_cfg:
395
+ c = self.cond_stage_models[
396
+ self.cond_stage_model_metadata[key]["model_idx"]
397
+ ](c)
398
+ else:
399
+ if isinstance(c, torch.Tensor):
400
+ batchsize = c.size(0)
401
+ elif isinstance(c, list):
402
+ batchsize = len(c)
403
+ else:
404
+ raise NotImplementedError()
405
+ c = self.cond_stage_models[
406
+ self.cond_stage_model_metadata[key]["model_idx"]
407
+ ].get_unconditional_condition(batchsize)
408
+
409
+ return c
410
+
411
+ def initialize_param_check_toolkit(self):
412
+ self.tracked_steps = 0
413
+ self.param_dict = {}
414
+
415
+ def statistic_require_grad_tensor_number(self, module, name=None):
416
+ requires_grad_num = 0
417
+ total_num = 0
418
+ require_grad_tensor = None
419
+ for p in module.parameters():
420
+ if p.requires_grad:
421
+ requires_grad_num += 1
422
+ if require_grad_tensor is None:
423
+ require_grad_tensor = p
424
+ total_num += 1
425
+ print(
426
+ "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)"
427
+ % (name, requires_grad_num, total_num, requires_grad_num / total_num)
428
+ )
429
+ return require_grad_tensor
audioldm2/audiomae_gen/utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
+ class Prenet(nn.Module):
5
+ def __init__(self, in_dim, sizes=[256, 128], dropout_rate=0.5):
6
+ super(Prenet, self).__init__()
7
+ in_sizes = [in_dim] + sizes[:-1]
8
+ self.layers = nn.ModuleList(
9
+ [
10
+ nn.Linear(in_size, out_size)
11
+ for (in_size, out_size) in zip(in_sizes, sizes)
12
+ ]
13
+ )
14
+ self.relu = nn.ReLU()
15
+ self.dropout = nn.Dropout(dropout_rate)
16
+
17
+ def forward(self, inputs):
18
+ for linear in self.layers:
19
+ inputs = self.dropout(self.relu(linear(inputs)))
20
+ return inputs
21
+
22
+
23
+ if __name__ == "__main__":
24
+ model = Prenet(in_dim=128, sizes=[256, 256, 128])
25
+ import ipdb
26
+
27
+ ipdb.set_trace()
audioldm2/clap/__init__.py ADDED
File without changes
audioldm2/clap/open_clip/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .factory import (
2
+ list_models,
3
+ create_model,
4
+ create_model_and_transforms,
5
+ add_model_config,
6
+ )
7
+ from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8
+ from .model import (
9
+ CLAP,
10
+ CLAPTextCfg,
11
+ CLAPVisionCfg,
12
+ CLAPAudioCfp,
13
+ convert_weights_to_fp16,
14
+ trace_model,
15
+ )
16
+ from .openai import load_openai_model, list_openai_models
17
+ from .pretrained import (
18
+ list_pretrained,
19
+ list_pretrained_tag_models,
20
+ list_pretrained_model_tags,
21
+ get_pretrained_url,
22
+ download_pretrained,
23
+ )
24
+ from .tokenizer import SimpleTokenizer, tokenize
25
+ from .transform import image_transform
audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
audioldm2/clap/open_clip/factory.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import re
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+
8
+ import torch
9
+
10
+ from .model import CLAP, convert_weights_to_fp16
11
+ from .openai import load_openai_model
12
+ from .pretrained import get_pretrained_url, download_pretrained
13
+ from .transform import image_transform
14
+
15
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
16
+ _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
17
+
18
+
19
+ def _natural_key(string_):
20
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
21
+
22
+
23
+ def _rescan_model_configs():
24
+ global _MODEL_CONFIGS
25
+
26
+ config_ext = (".json",)
27
+ config_files = []
28
+ for config_path in _MODEL_CONFIG_PATHS:
29
+ if config_path.is_file() and config_path.suffix in config_ext:
30
+ config_files.append(config_path)
31
+ elif config_path.is_dir():
32
+ for ext in config_ext:
33
+ config_files.extend(config_path.glob(f"*{ext}"))
34
+
35
+ for cf in config_files:
36
+ if os.path.basename(cf)[0] == ".":
37
+ continue # Ignore hidden files
38
+
39
+ with open(cf, "r") as f:
40
+ model_cfg = json.load(f)
41
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
42
+ _MODEL_CONFIGS[cf.stem] = model_cfg
43
+
44
+ _MODEL_CONFIGS = {
45
+ k: v
46
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
47
+ }
48
+
49
+
50
+ _rescan_model_configs() # initial populate of model config registry
51
+
52
+
53
+ def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
54
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
55
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
56
+ state_dict = checkpoint["state_dict"]
57
+ else:
58
+ state_dict = checkpoint
59
+ if skip_params:
60
+ if next(iter(state_dict.items()))[0].startswith("module"):
61
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
62
+ # for k in state_dict:
63
+ # if k.startswith('transformer'):
64
+ # v = state_dict.pop(k)
65
+ # state_dict['text_branch.' + k[12:]] = v
66
+ return state_dict
67
+
68
+
69
+ def create_model(
70
+ amodel_name: str,
71
+ tmodel_name: str,
72
+ pretrained: str = "",
73
+ precision: str = "fp32",
74
+ device: torch.device = torch.device("cpu"),
75
+ jit: bool = False,
76
+ force_quick_gelu: bool = False,
77
+ openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
78
+ skip_params=True,
79
+ pretrained_audio: str = "",
80
+ pretrained_text: str = "",
81
+ enable_fusion: bool = False,
82
+ fusion_type: str = "None"
83
+ # pretrained_image: bool = False,
84
+ ):
85
+ amodel_name = amodel_name.replace(
86
+ "/", "-"
87
+ ) # for callers using old naming with / in ViT names
88
+ pretrained_orig = pretrained
89
+ pretrained = pretrained.lower()
90
+ if pretrained == "openai":
91
+ if amodel_name in _MODEL_CONFIGS:
92
+ logging.info(f"Loading {amodel_name} model config.")
93
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
94
+ else:
95
+ logging.error(
96
+ f"Model config for {amodel_name} not found; available models {list_models()}."
97
+ )
98
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
99
+
100
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
101
+ # Hard Code in model name
102
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
103
+ model = load_openai_model(
104
+ "ViT-B-16",
105
+ model_cfg,
106
+ device=device,
107
+ jit=jit,
108
+ cache_dir=openai_model_cache_dir,
109
+ enable_fusion=enable_fusion,
110
+ fusion_type=fusion_type,
111
+ )
112
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
113
+ if precision == "amp" or precision == "fp32":
114
+ model = model.float()
115
+ else:
116
+ if amodel_name in _MODEL_CONFIGS:
117
+ logging.info(f"Loading {amodel_name} model config.")
118
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
119
+ else:
120
+ logging.error(
121
+ f"Model config for {amodel_name} not found; available models {list_models()}."
122
+ )
123
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
124
+
125
+ if force_quick_gelu:
126
+ # override for use of QuickGELU on non-OpenAI transformer models
127
+ model_cfg["quick_gelu"] = True
128
+
129
+ # if pretrained_image:
130
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
131
+ # # pretrained weight loading for timm models set via vision_cfg
132
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
133
+ # else:
134
+ # assert False, 'pretrained image towers currently only supported for timm models'
135
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
136
+ model_cfg["enable_fusion"] = enable_fusion
137
+ model_cfg["fusion_type"] = fusion_type
138
+ model = CLAP(**model_cfg)
139
+
140
+ if pretrained:
141
+ checkpoint_path = ""
142
+ url = get_pretrained_url(amodel_name, pretrained)
143
+ if url:
144
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
145
+ elif os.path.exists(pretrained_orig):
146
+ checkpoint_path = pretrained_orig
147
+ if checkpoint_path:
148
+ logging.info(
149
+ f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
150
+ )
151
+ ckpt = load_state_dict(checkpoint_path, skip_params=True)
152
+ model.load_state_dict(ckpt)
153
+ param_names = [n for n, p in model.named_parameters()]
154
+ # for n in param_names:
155
+ # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
156
+ else:
157
+ logging.warning(
158
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
159
+ )
160
+ raise RuntimeError(
161
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
162
+ )
163
+
164
+ if pretrained_audio:
165
+ if amodel_name.startswith("PANN"):
166
+ if "Cnn14_mAP" in pretrained_audio: # official checkpoint
167
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
168
+ audio_ckpt = audio_ckpt["model"]
169
+ keys = list(audio_ckpt.keys())
170
+ for key in keys:
171
+ if (
172
+ "spectrogram_extractor" not in key
173
+ and "logmel_extractor" not in key
174
+ ):
175
+ v = audio_ckpt.pop(key)
176
+ audio_ckpt["audio_branch." + key] = v
177
+ elif os.path.basename(pretrained_audio).startswith(
178
+ "PANN"
179
+ ): # checkpoint trained via HTSAT codebase
180
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
181
+ audio_ckpt = audio_ckpt["state_dict"]
182
+ keys = list(audio_ckpt.keys())
183
+ for key in keys:
184
+ if key.startswith("sed_model"):
185
+ v = audio_ckpt.pop(key)
186
+ audio_ckpt["audio_branch." + key[10:]] = v
187
+ elif os.path.basename(pretrained_audio).startswith(
188
+ "finetuned"
189
+ ): # checkpoint trained via linear probe codebase
190
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
191
+ else:
192
+ raise ValueError("Unknown audio checkpoint")
193
+ elif amodel_name.startswith("HTSAT"):
194
+ if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
195
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
196
+ audio_ckpt = audio_ckpt["state_dict"]
197
+ keys = list(audio_ckpt.keys())
198
+ for key in keys:
199
+ if key.startswith("sed_model") and (
200
+ "spectrogram_extractor" not in key
201
+ and "logmel_extractor" not in key
202
+ ):
203
+ v = audio_ckpt.pop(key)
204
+ audio_ckpt["audio_branch." + key[10:]] = v
205
+ elif os.path.basename(pretrained_audio).startswith(
206
+ "HTSAT"
207
+ ): # checkpoint trained via HTSAT codebase
208
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
209
+ audio_ckpt = audio_ckpt["state_dict"]
210
+ keys = list(audio_ckpt.keys())
211
+ for key in keys:
212
+ if key.startswith("sed_model"):
213
+ v = audio_ckpt.pop(key)
214
+ audio_ckpt["audio_branch." + key[10:]] = v
215
+ elif os.path.basename(pretrained_audio).startswith(
216
+ "finetuned"
217
+ ): # checkpoint trained via linear probe codebase
218
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
219
+ else:
220
+ raise ValueError("Unknown audio checkpoint")
221
+ else:
222
+ raise f"this audio encoder pretrained checkpoint is not support"
223
+
224
+ model.load_state_dict(audio_ckpt, strict=False)
225
+ logging.info(
226
+ f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
227
+ )
228
+ param_names = [n for n, p in model.named_parameters()]
229
+ for n in param_names:
230
+ print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
231
+
232
+ model.to(device=device)
233
+ if precision == "fp16":
234
+ assert device.type != "cpu"
235
+ convert_weights_to_fp16(model)
236
+
237
+ if jit:
238
+ model = torch.jit.script(model)
239
+
240
+ return model, model_cfg
241
+
242
+
243
+ def create_model_and_transforms(
244
+ model_name: str,
245
+ pretrained: str = "",
246
+ precision: str = "fp32",
247
+ device: torch.device = torch.device("cpu"),
248
+ jit: bool = False,
249
+ force_quick_gelu: bool = False,
250
+ # pretrained_image: bool = False,
251
+ ):
252
+ model = create_model(
253
+ model_name,
254
+ pretrained,
255
+ precision,
256
+ device,
257
+ jit,
258
+ force_quick_gelu=force_quick_gelu,
259
+ # pretrained_image=pretrained_image
260
+ )
261
+ preprocess_train = image_transform(model.visual.image_size, is_train=True)
262
+ preprocess_val = image_transform(model.visual.image_size, is_train=False)
263
+ return model, preprocess_train, preprocess_val
264
+
265
+
266
+ def list_models():
267
+ """enumerate available model architectures based on config files"""
268
+ return list(_MODEL_CONFIGS.keys())
269
+
270
+
271
+ def add_model_config(path):
272
+ """add model config path or file and update registry"""
273
+ if not isinstance(path, Path):
274
+ path = Path(path)
275
+ _MODEL_CONFIG_PATHS.append(path)
276
+ _rescan_model_configs()
audioldm2/clap/open_clip/feature_fusion.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Feature Fusion for Varible-Length Data Processing
3
+ AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
+ According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class DAF(nn.Module):
12
+ """
13
+ 直接相加 DirectAddFuse
14
+ """
15
+
16
+ def __init__(self):
17
+ super(DAF, self).__init__()
18
+
19
+ def forward(self, x, residual):
20
+ return x + residual
21
+
22
+
23
+ class iAFF(nn.Module):
24
+ """
25
+ 多特征融合 iAFF
26
+ """
27
+
28
+ def __init__(self, channels=64, r=4, type="2D"):
29
+ super(iAFF, self).__init__()
30
+ inter_channels = int(channels // r)
31
+
32
+ if type == "1D":
33
+ # 本地注意力
34
+ self.local_att = nn.Sequential(
35
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
+ nn.BatchNorm1d(inter_channels),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
+ nn.BatchNorm1d(channels),
40
+ )
41
+
42
+ # 全局注意力
43
+ self.global_att = nn.Sequential(
44
+ nn.AdaptiveAvgPool1d(1),
45
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
+ nn.BatchNorm1d(inter_channels),
47
+ nn.ReLU(inplace=True),
48
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
+ nn.BatchNorm1d(channels),
50
+ )
51
+
52
+ # 第二次本地注意力
53
+ self.local_att2 = nn.Sequential(
54
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
+ nn.BatchNorm1d(inter_channels),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
+ nn.BatchNorm1d(channels),
59
+ )
60
+ # 第二次全局注意力
61
+ self.global_att2 = nn.Sequential(
62
+ nn.AdaptiveAvgPool1d(1),
63
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
+ nn.BatchNorm1d(inter_channels),
65
+ nn.ReLU(inplace=True),
66
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
+ nn.BatchNorm1d(channels),
68
+ )
69
+ elif type == "2D":
70
+ # 本地注意力
71
+ self.local_att = nn.Sequential(
72
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
+ nn.BatchNorm2d(inter_channels),
74
+ nn.ReLU(inplace=True),
75
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
+ nn.BatchNorm2d(channels),
77
+ )
78
+
79
+ # 全局注意力
80
+ self.global_att = nn.Sequential(
81
+ nn.AdaptiveAvgPool2d(1),
82
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
+ nn.BatchNorm2d(inter_channels),
84
+ nn.ReLU(inplace=True),
85
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
+ nn.BatchNorm2d(channels),
87
+ )
88
+
89
+ # 第二次本地注意力
90
+ self.local_att2 = nn.Sequential(
91
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
+ nn.BatchNorm2d(inter_channels),
93
+ nn.ReLU(inplace=True),
94
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
+ nn.BatchNorm2d(channels),
96
+ )
97
+ # 第二次全局注意力
98
+ self.global_att2 = nn.Sequential(
99
+ nn.AdaptiveAvgPool2d(1),
100
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
+ nn.BatchNorm2d(inter_channels),
102
+ nn.ReLU(inplace=True),
103
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
+ nn.BatchNorm2d(channels),
105
+ )
106
+ else:
107
+ raise f"the type is not supported"
108
+
109
+ self.sigmoid = nn.Sigmoid()
110
+
111
+ def forward(self, x, residual):
112
+ flag = False
113
+ xa = x + residual
114
+ if xa.size(0) == 1:
115
+ xa = torch.cat([xa, xa], dim=0)
116
+ flag = True
117
+ xl = self.local_att(xa)
118
+ xg = self.global_att(xa)
119
+ xlg = xl + xg
120
+ wei = self.sigmoid(xlg)
121
+ xi = x * wei + residual * (1 - wei)
122
+
123
+ xl2 = self.local_att2(xi)
124
+ xg2 = self.global_att(xi)
125
+ xlg2 = xl2 + xg2
126
+ wei2 = self.sigmoid(xlg2)
127
+ xo = x * wei2 + residual * (1 - wei2)
128
+ if flag:
129
+ xo = xo[0].unsqueeze(0)
130
+ return xo
131
+
132
+
133
+ class AFF(nn.Module):
134
+ """
135
+ 多特征融合 AFF
136
+ """
137
+
138
+ def __init__(self, channels=64, r=4, type="2D"):
139
+ super(AFF, self).__init__()
140
+ inter_channels = int(channels // r)
141
+
142
+ if type == "1D":
143
+ self.local_att = nn.Sequential(
144
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
+ nn.BatchNorm1d(inter_channels),
146
+ nn.ReLU(inplace=True),
147
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
+ nn.BatchNorm1d(channels),
149
+ )
150
+ self.global_att = nn.Sequential(
151
+ nn.AdaptiveAvgPool1d(1),
152
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
+ nn.BatchNorm1d(inter_channels),
154
+ nn.ReLU(inplace=True),
155
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
+ nn.BatchNorm1d(channels),
157
+ )
158
+ elif type == "2D":
159
+ self.local_att = nn.Sequential(
160
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
+ nn.BatchNorm2d(inter_channels),
162
+ nn.ReLU(inplace=True),
163
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
+ nn.BatchNorm2d(channels),
165
+ )
166
+ self.global_att = nn.Sequential(
167
+ nn.AdaptiveAvgPool2d(1),
168
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
+ nn.BatchNorm2d(inter_channels),
170
+ nn.ReLU(inplace=True),
171
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
+ nn.BatchNorm2d(channels),
173
+ )
174
+ else:
175
+ raise f"the type is not supported."
176
+
177
+ self.sigmoid = nn.Sigmoid()
178
+
179
+ def forward(self, x, residual):
180
+ flag = False
181
+ xa = x + residual
182
+ if xa.size(0) == 1:
183
+ xa = torch.cat([xa, xa], dim=0)
184
+ flag = True
185
+ xl = self.local_att(xa)
186
+ xg = self.global_att(xa)
187
+ xlg = xl + xg
188
+ wei = self.sigmoid(xlg)
189
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
190
+ if flag:
191
+ xo = xo[0].unsqueeze(0)
192
+ return xo
audioldm2/clap/open_clip/htsat.py ADDED
@@ -0,0 +1,1304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ke Chen
2
+ # knutchen@ucsd.edu
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Some layers designed on the model
5
+ # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
6
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from itertools import repeat
11
+ import collections.abc
12
+ import math
13
+ import warnings
14
+
15
+ from torch.nn.init import _calculate_fan_in_and_fan_out
16
+ import torch.utils.checkpoint as checkpoint
17
+
18
+ import random
19
+
20
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
21
+ from torchlibrosa.augmentation import SpecAugmentation
22
+
23
+ from itertools import repeat
24
+ from .utils import do_mixup, interpolate
25
+
26
+ from .feature_fusion import iAFF, AFF, DAF
27
+
28
+
29
+ # from PyTorch internals
30
+ def _ntuple(n):
31
+ def parse(x):
32
+ if isinstance(x, collections.abc.Iterable):
33
+ return x
34
+ return tuple(repeat(x, n))
35
+
36
+ return parse
37
+
38
+
39
+ to_1tuple = _ntuple(1)
40
+ to_2tuple = _ntuple(2)
41
+ to_3tuple = _ntuple(3)
42
+ to_4tuple = _ntuple(4)
43
+ to_ntuple = _ntuple
44
+
45
+
46
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
47
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
48
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
49
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
50
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
51
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
52
+ 'survival rate' as the argument.
53
+ """
54
+ if drop_prob == 0.0 or not training:
55
+ return x
56
+ keep_prob = 1 - drop_prob
57
+ shape = (x.shape[0],) + (1,) * (
58
+ x.ndim - 1
59
+ ) # work with diff dim tensors, not just 2D ConvNets
60
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
61
+ random_tensor.floor_() # binarize
62
+ output = x.div(keep_prob) * random_tensor
63
+ return output
64
+
65
+
66
+ class DropPath(nn.Module):
67
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
68
+
69
+ def __init__(self, drop_prob=None):
70
+ super(DropPath, self).__init__()
71
+ self.drop_prob = drop_prob
72
+
73
+ def forward(self, x):
74
+ return drop_path(x, self.drop_prob, self.training)
75
+
76
+
77
+ class PatchEmbed(nn.Module):
78
+ """2D Image to Patch Embedding"""
79
+
80
+ def __init__(
81
+ self,
82
+ img_size=224,
83
+ patch_size=16,
84
+ in_chans=3,
85
+ embed_dim=768,
86
+ norm_layer=None,
87
+ flatten=True,
88
+ patch_stride=16,
89
+ enable_fusion=False,
90
+ fusion_type="None",
91
+ ):
92
+ super().__init__()
93
+ img_size = to_2tuple(img_size)
94
+ patch_size = to_2tuple(patch_size)
95
+ patch_stride = to_2tuple(patch_stride)
96
+ self.img_size = img_size
97
+ self.patch_size = patch_size
98
+ self.patch_stride = patch_stride
99
+ self.grid_size = (
100
+ img_size[0] // patch_stride[0],
101
+ img_size[1] // patch_stride[1],
102
+ )
103
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
104
+ self.flatten = flatten
105
+ self.in_chans = in_chans
106
+ self.embed_dim = embed_dim
107
+
108
+ self.enable_fusion = enable_fusion
109
+ self.fusion_type = fusion_type
110
+
111
+ padding = (
112
+ (patch_size[0] - patch_stride[0]) // 2,
113
+ (patch_size[1] - patch_stride[1]) // 2,
114
+ )
115
+
116
+ if (self.enable_fusion) and (self.fusion_type == "channel_map"):
117
+ self.proj = nn.Conv2d(
118
+ in_chans * 4,
119
+ embed_dim,
120
+ kernel_size=patch_size,
121
+ stride=patch_stride,
122
+ padding=padding,
123
+ )
124
+ else:
125
+ self.proj = nn.Conv2d(
126
+ in_chans,
127
+ embed_dim,
128
+ kernel_size=patch_size,
129
+ stride=patch_stride,
130
+ padding=padding,
131
+ )
132
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
133
+
134
+ if (self.enable_fusion) and (
135
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
136
+ ):
137
+ self.mel_conv2d = nn.Conv2d(
138
+ in_chans,
139
+ embed_dim,
140
+ kernel_size=(patch_size[0], patch_size[1] * 3),
141
+ stride=(patch_stride[0], patch_stride[1] * 3),
142
+ padding=padding,
143
+ )
144
+ if self.fusion_type == "daf_2d":
145
+ self.fusion_model = DAF()
146
+ elif self.fusion_type == "aff_2d":
147
+ self.fusion_model = AFF(channels=embed_dim, type="2D")
148
+ elif self.fusion_type == "iaff_2d":
149
+ self.fusion_model = iAFF(channels=embed_dim, type="2D")
150
+
151
+ def forward(self, x, longer_idx=None):
152
+ if (self.enable_fusion) and (
153
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
154
+ ):
155
+ global_x = x[:, 0:1, :, :]
156
+
157
+ # global processing
158
+ B, C, H, W = global_x.shape
159
+ assert (
160
+ H == self.img_size[0] and W == self.img_size[1]
161
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
162
+ global_x = self.proj(global_x)
163
+ TW = global_x.size(-1)
164
+ if len(longer_idx) > 0:
165
+ # local processing
166
+ local_x = x[longer_idx, 1:, :, :].contiguous()
167
+ B, C, H, W = local_x.shape
168
+ local_x = local_x.view(B * C, 1, H, W)
169
+ local_x = self.mel_conv2d(local_x)
170
+ local_x = local_x.view(
171
+ B, C, local_x.size(1), local_x.size(2), local_x.size(3)
172
+ )
173
+ local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
174
+ TB, TC, TH, _ = local_x.size()
175
+ if local_x.size(-1) < TW:
176
+ local_x = torch.cat(
177
+ [
178
+ local_x,
179
+ torch.zeros(
180
+ (TB, TC, TH, TW - local_x.size(-1)),
181
+ device=global_x.device,
182
+ ),
183
+ ],
184
+ dim=-1,
185
+ )
186
+ else:
187
+ local_x = local_x[:, :, :, :TW]
188
+
189
+ global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
190
+ x = global_x
191
+ else:
192
+ B, C, H, W = x.shape
193
+ assert (
194
+ H == self.img_size[0] and W == self.img_size[1]
195
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
196
+ x = self.proj(x)
197
+
198
+ if self.flatten:
199
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
200
+ x = self.norm(x)
201
+ return x
202
+
203
+
204
+ class Mlp(nn.Module):
205
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
206
+
207
+ def __init__(
208
+ self,
209
+ in_features,
210
+ hidden_features=None,
211
+ out_features=None,
212
+ act_layer=nn.GELU,
213
+ drop=0.0,
214
+ ):
215
+ super().__init__()
216
+ out_features = out_features or in_features
217
+ hidden_features = hidden_features or in_features
218
+ self.fc1 = nn.Linear(in_features, hidden_features)
219
+ self.act = act_layer()
220
+ self.fc2 = nn.Linear(hidden_features, out_features)
221
+ self.drop = nn.Dropout(drop)
222
+
223
+ def forward(self, x):
224
+ x = self.fc1(x)
225
+ x = self.act(x)
226
+ x = self.drop(x)
227
+ x = self.fc2(x)
228
+ x = self.drop(x)
229
+ return x
230
+
231
+
232
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
233
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
234
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
235
+ def norm_cdf(x):
236
+ # Computes standard normal cumulative distribution function
237
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
238
+
239
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
240
+ warnings.warn(
241
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
242
+ "The distribution of values may be incorrect.",
243
+ stacklevel=2,
244
+ )
245
+
246
+ with torch.no_grad():
247
+ # Values are generated by using a truncated uniform distribution and
248
+ # then using the inverse CDF for the normal distribution.
249
+ # Get upper and lower cdf values
250
+ l = norm_cdf((a - mean) / std)
251
+ u = norm_cdf((b - mean) / std)
252
+
253
+ # Uniformly fill tensor with values from [l, u], then translate to
254
+ # [2l-1, 2u-1].
255
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
256
+
257
+ # Use inverse cdf transform for normal distribution to get truncated
258
+ # standard normal
259
+ tensor.erfinv_()
260
+
261
+ # Transform to proper mean, std
262
+ tensor.mul_(std * math.sqrt(2.0))
263
+ tensor.add_(mean)
264
+
265
+ # Clamp to ensure it's in the proper range
266
+ tensor.clamp_(min=a, max=b)
267
+ return tensor
268
+
269
+
270
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
271
+ # type: (Tensor, float, float, float, float) -> Tensor
272
+ r"""Fills the input Tensor with values drawn from a truncated
273
+ normal distribution. The values are effectively drawn from the
274
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
275
+ with values outside :math:`[a, b]` redrawn until they are within
276
+ the bounds. The method used for generating the random values works
277
+ best when :math:`a \leq \text{mean} \leq b`.
278
+ Args:
279
+ tensor: an n-dimensional `torch.Tensor`
280
+ mean: the mean of the normal distribution
281
+ std: the standard deviation of the normal distribution
282
+ a: the minimum cutoff value
283
+ b: the maximum cutoff value
284
+ Examples:
285
+ >>> w = torch.empty(3, 5)
286
+ >>> nn.init.trunc_normal_(w)
287
+ """
288
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
289
+
290
+
291
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
292
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
293
+ if mode == "fan_in":
294
+ denom = fan_in
295
+ elif mode == "fan_out":
296
+ denom = fan_out
297
+ elif mode == "fan_avg":
298
+ denom = (fan_in + fan_out) / 2
299
+
300
+ variance = scale / denom
301
+
302
+ if distribution == "truncated_normal":
303
+ # constant is stddev of standard normal truncated to (-2, 2)
304
+ trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
305
+ elif distribution == "normal":
306
+ tensor.normal_(std=math.sqrt(variance))
307
+ elif distribution == "uniform":
308
+ bound = math.sqrt(3 * variance)
309
+ tensor.uniform_(-bound, bound)
310
+ else:
311
+ raise ValueError(f"invalid distribution {distribution}")
312
+
313
+
314
+ def lecun_normal_(tensor):
315
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
316
+
317
+
318
+ def window_partition(x, window_size):
319
+ """
320
+ Args:
321
+ x: (B, H, W, C)
322
+ window_size (int): window size
323
+ Returns:
324
+ windows: (num_windows*B, window_size, window_size, C)
325
+ """
326
+ B, H, W, C = x.shape
327
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
328
+ windows = (
329
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
330
+ )
331
+ return windows
332
+
333
+
334
+ def window_reverse(windows, window_size, H, W):
335
+ """
336
+ Args:
337
+ windows: (num_windows*B, window_size, window_size, C)
338
+ window_size (int): Window size
339
+ H (int): Height of image
340
+ W (int): Width of image
341
+ Returns:
342
+ x: (B, H, W, C)
343
+ """
344
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
345
+ x = windows.view(
346
+ B, H // window_size, W // window_size, window_size, window_size, -1
347
+ )
348
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
349
+ return x
350
+
351
+
352
+ class WindowAttention(nn.Module):
353
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
354
+ It supports both of shifted and non-shifted window.
355
+ Args:
356
+ dim (int): Number of input channels.
357
+ window_size (tuple[int]): The height and width of the window.
358
+ num_heads (int): Number of attention heads.
359
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
360
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
361
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
362
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
363
+ """
364
+
365
+ def __init__(
366
+ self,
367
+ dim,
368
+ window_size,
369
+ num_heads,
370
+ qkv_bias=True,
371
+ qk_scale=None,
372
+ attn_drop=0.0,
373
+ proj_drop=0.0,
374
+ ):
375
+ super().__init__()
376
+ self.dim = dim
377
+ self.window_size = window_size # Wh, Ww
378
+ self.num_heads = num_heads
379
+ head_dim = dim // num_heads
380
+ self.scale = qk_scale or head_dim**-0.5
381
+
382
+ # define a parameter table of relative position bias
383
+ self.relative_position_bias_table = nn.Parameter(
384
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
385
+ ) # 2*Wh-1 * 2*Ww-1, nH
386
+
387
+ # get pair-wise relative position index for each token inside the window
388
+ coords_h = torch.arange(self.window_size[0])
389
+ coords_w = torch.arange(self.window_size[1])
390
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
391
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
392
+ relative_coords = (
393
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
394
+ ) # 2, Wh*Ww, Wh*Ww
395
+ relative_coords = relative_coords.permute(
396
+ 1, 2, 0
397
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
398
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
399
+ relative_coords[:, :, 1] += self.window_size[1] - 1
400
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
401
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
402
+ self.register_buffer("relative_position_index", relative_position_index)
403
+
404
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
405
+ self.attn_drop = nn.Dropout(attn_drop)
406
+ self.proj = nn.Linear(dim, dim)
407
+ self.proj_drop = nn.Dropout(proj_drop)
408
+
409
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
410
+ self.softmax = nn.Softmax(dim=-1)
411
+
412
+ def forward(self, x, mask=None):
413
+ """
414
+ Args:
415
+ x: input features with shape of (num_windows*B, N, C)
416
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
417
+ """
418
+ B_, N, C = x.shape
419
+ qkv = (
420
+ self.qkv(x)
421
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
422
+ .permute(2, 0, 3, 1, 4)
423
+ )
424
+ q, k, v = (
425
+ qkv[0],
426
+ qkv[1],
427
+ qkv[2],
428
+ ) # make torchscript happy (cannot use tensor as tuple)
429
+
430
+ q = q * self.scale
431
+ attn = q @ k.transpose(-2, -1)
432
+
433
+ relative_position_bias = self.relative_position_bias_table[
434
+ self.relative_position_index.view(-1)
435
+ ].view(
436
+ self.window_size[0] * self.window_size[1],
437
+ self.window_size[0] * self.window_size[1],
438
+ -1,
439
+ ) # Wh*Ww,Wh*Ww,nH
440
+ relative_position_bias = relative_position_bias.permute(
441
+ 2, 0, 1
442
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
443
+ attn = attn + relative_position_bias.unsqueeze(0)
444
+
445
+ if mask is not None:
446
+ nW = mask.shape[0]
447
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
448
+ 1
449
+ ).unsqueeze(0)
450
+ attn = attn.view(-1, self.num_heads, N, N)
451
+ attn = self.softmax(attn)
452
+ else:
453
+ attn = self.softmax(attn)
454
+
455
+ attn = self.attn_drop(attn)
456
+
457
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
458
+ x = self.proj(x)
459
+ x = self.proj_drop(x)
460
+ return x, attn
461
+
462
+ def extra_repr(self):
463
+ return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
464
+
465
+
466
+ # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
467
+ class SwinTransformerBlock(nn.Module):
468
+ r"""Swin Transformer Block.
469
+ Args:
470
+ dim (int): Number of input channels.
471
+ input_resolution (tuple[int]): Input resulotion.
472
+ num_heads (int): Number of attention heads.
473
+ window_size (int): Window size.
474
+ shift_size (int): Shift size for SW-MSA.
475
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
476
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
477
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
478
+ drop (float, optional): Dropout rate. Default: 0.0
479
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
480
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
481
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
482
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
483
+ """
484
+
485
+ def __init__(
486
+ self,
487
+ dim,
488
+ input_resolution,
489
+ num_heads,
490
+ window_size=7,
491
+ shift_size=0,
492
+ mlp_ratio=4.0,
493
+ qkv_bias=True,
494
+ qk_scale=None,
495
+ drop=0.0,
496
+ attn_drop=0.0,
497
+ drop_path=0.0,
498
+ act_layer=nn.GELU,
499
+ norm_layer=nn.LayerNorm,
500
+ norm_before_mlp="ln",
501
+ ):
502
+ super().__init__()
503
+ self.dim = dim
504
+ self.input_resolution = input_resolution
505
+ self.num_heads = num_heads
506
+ self.window_size = window_size
507
+ self.shift_size = shift_size
508
+ self.mlp_ratio = mlp_ratio
509
+ self.norm_before_mlp = norm_before_mlp
510
+ if min(self.input_resolution) <= self.window_size:
511
+ # if window size is larger than input resolution, we don't partition windows
512
+ self.shift_size = 0
513
+ self.window_size = min(self.input_resolution)
514
+ assert (
515
+ 0 <= self.shift_size < self.window_size
516
+ ), "shift_size must in 0-window_size"
517
+
518
+ self.norm1 = norm_layer(dim)
519
+ self.attn = WindowAttention(
520
+ dim,
521
+ window_size=to_2tuple(self.window_size),
522
+ num_heads=num_heads,
523
+ qkv_bias=qkv_bias,
524
+ qk_scale=qk_scale,
525
+ attn_drop=attn_drop,
526
+ proj_drop=drop,
527
+ )
528
+
529
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
530
+ if self.norm_before_mlp == "ln":
531
+ self.norm2 = nn.LayerNorm(dim)
532
+ elif self.norm_before_mlp == "bn":
533
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
534
+ 1, 2
535
+ )
536
+ else:
537
+ raise NotImplementedError
538
+ mlp_hidden_dim = int(dim * mlp_ratio)
539
+ self.mlp = Mlp(
540
+ in_features=dim,
541
+ hidden_features=mlp_hidden_dim,
542
+ act_layer=act_layer,
543
+ drop=drop,
544
+ )
545
+
546
+ if self.shift_size > 0:
547
+ # calculate attention mask for SW-MSA
548
+ H, W = self.input_resolution
549
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
550
+ h_slices = (
551
+ slice(0, -self.window_size),
552
+ slice(-self.window_size, -self.shift_size),
553
+ slice(-self.shift_size, None),
554
+ )
555
+ w_slices = (
556
+ slice(0, -self.window_size),
557
+ slice(-self.window_size, -self.shift_size),
558
+ slice(-self.shift_size, None),
559
+ )
560
+ cnt = 0
561
+ for h in h_slices:
562
+ for w in w_slices:
563
+ img_mask[:, h, w, :] = cnt
564
+ cnt += 1
565
+
566
+ mask_windows = window_partition(
567
+ img_mask, self.window_size
568
+ ) # nW, window_size, window_size, 1
569
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
570
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
571
+ attn_mask = attn_mask.masked_fill(
572
+ attn_mask != 0, float(-100.0)
573
+ ).masked_fill(attn_mask == 0, float(0.0))
574
+ else:
575
+ attn_mask = None
576
+
577
+ self.register_buffer("attn_mask", attn_mask)
578
+
579
+ def forward(self, x):
580
+ # pdb.set_trace()
581
+ H, W = self.input_resolution
582
+ # print("H: ", H)
583
+ # print("W: ", W)
584
+ # pdb.set_trace()
585
+ B, L, C = x.shape
586
+ # assert L == H * W, "input feature has wrong size"
587
+
588
+ shortcut = x
589
+ x = self.norm1(x)
590
+ x = x.view(B, H, W, C)
591
+
592
+ # cyclic shift
593
+ if self.shift_size > 0:
594
+ shifted_x = torch.roll(
595
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
596
+ )
597
+ else:
598
+ shifted_x = x
599
+
600
+ # partition windows
601
+ x_windows = window_partition(
602
+ shifted_x, self.window_size
603
+ ) # nW*B, window_size, window_size, C
604
+ x_windows = x_windows.view(
605
+ -1, self.window_size * self.window_size, C
606
+ ) # nW*B, window_size*window_size, C
607
+
608
+ # W-MSA/SW-MSA
609
+ attn_windows, attn = self.attn(
610
+ x_windows, mask=self.attn_mask
611
+ ) # nW*B, window_size*window_size, C
612
+
613
+ # merge windows
614
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
615
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
616
+
617
+ # reverse cyclic shift
618
+ if self.shift_size > 0:
619
+ x = torch.roll(
620
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
621
+ )
622
+ else:
623
+ x = shifted_x
624
+ x = x.view(B, H * W, C)
625
+
626
+ # FFN
627
+ x = shortcut + self.drop_path(x)
628
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
629
+
630
+ return x, attn
631
+
632
+ def extra_repr(self):
633
+ return (
634
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
635
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
636
+ )
637
+
638
+
639
+ class PatchMerging(nn.Module):
640
+ r"""Patch Merging Layer.
641
+ Args:
642
+ input_resolution (tuple[int]): Resolution of input feature.
643
+ dim (int): Number of input channels.
644
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
645
+ """
646
+
647
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
648
+ super().__init__()
649
+ self.input_resolution = input_resolution
650
+ self.dim = dim
651
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
652
+ self.norm = norm_layer(4 * dim)
653
+
654
+ def forward(self, x):
655
+ """
656
+ x: B, H*W, C
657
+ """
658
+ H, W = self.input_resolution
659
+ B, L, C = x.shape
660
+ assert L == H * W, "input feature has wrong size"
661
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
662
+
663
+ x = x.view(B, H, W, C)
664
+
665
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
666
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
667
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
668
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
669
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
670
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
671
+
672
+ x = self.norm(x)
673
+ x = self.reduction(x)
674
+
675
+ return x
676
+
677
+ def extra_repr(self):
678
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
679
+
680
+
681
+ class BasicLayer(nn.Module):
682
+ """A basic Swin Transformer layer for one stage.
683
+ Args:
684
+ dim (int): Number of input channels.
685
+ input_resolution (tuple[int]): Input resolution.
686
+ depth (int): Number of blocks.
687
+ num_heads (int): Number of attention heads.
688
+ window_size (int): Local window size.
689
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
690
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
691
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
692
+ drop (float, optional): Dropout rate. Default: 0.0
693
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
694
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
695
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
696
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
697
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
698
+ """
699
+
700
+ def __init__(
701
+ self,
702
+ dim,
703
+ input_resolution,
704
+ depth,
705
+ num_heads,
706
+ window_size,
707
+ mlp_ratio=4.0,
708
+ qkv_bias=True,
709
+ qk_scale=None,
710
+ drop=0.0,
711
+ attn_drop=0.0,
712
+ drop_path=0.0,
713
+ norm_layer=nn.LayerNorm,
714
+ downsample=None,
715
+ use_checkpoint=False,
716
+ norm_before_mlp="ln",
717
+ ):
718
+ super().__init__()
719
+ self.dim = dim
720
+ self.input_resolution = input_resolution
721
+ self.depth = depth
722
+ self.use_checkpoint = use_checkpoint
723
+
724
+ # build blocks
725
+ self.blocks = nn.ModuleList(
726
+ [
727
+ SwinTransformerBlock(
728
+ dim=dim,
729
+ input_resolution=input_resolution,
730
+ num_heads=num_heads,
731
+ window_size=window_size,
732
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
733
+ mlp_ratio=mlp_ratio,
734
+ qkv_bias=qkv_bias,
735
+ qk_scale=qk_scale,
736
+ drop=drop,
737
+ attn_drop=attn_drop,
738
+ drop_path=drop_path[i]
739
+ if isinstance(drop_path, list)
740
+ else drop_path,
741
+ norm_layer=norm_layer,
742
+ norm_before_mlp=norm_before_mlp,
743
+ )
744
+ for i in range(depth)
745
+ ]
746
+ )
747
+
748
+ # patch merging layer
749
+ if downsample is not None:
750
+ self.downsample = downsample(
751
+ input_resolution, dim=dim, norm_layer=norm_layer
752
+ )
753
+ else:
754
+ self.downsample = None
755
+
756
+ def forward(self, x):
757
+ attns = []
758
+ for blk in self.blocks:
759
+ if self.use_checkpoint:
760
+ x = checkpoint.checkpoint(blk, x)
761
+ else:
762
+ x, attn = blk(x)
763
+ if not self.training:
764
+ attns.append(attn.unsqueeze(0))
765
+ if self.downsample is not None:
766
+ x = self.downsample(x)
767
+ if not self.training:
768
+ attn = torch.cat(attns, dim=0)
769
+ attn = torch.mean(attn, dim=0)
770
+ return x, attn
771
+
772
+ def extra_repr(self):
773
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
774
+
775
+
776
+ # The Core of HTSAT
777
+ class HTSAT_Swin_Transformer(nn.Module):
778
+ r"""HTSAT based on the Swin Transformer
779
+ Args:
780
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
781
+ patch_size (int | tuple(int)): Patch size. Default: 4
782
+ path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
783
+ in_chans (int): Number of input image channels. Default: 1 (mono)
784
+ num_classes (int): Number of classes for classification head. Default: 527
785
+ embed_dim (int): Patch embedding dimension. Default: 96
786
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
787
+ num_heads (tuple(int)): Number of attention heads in different layers.
788
+ window_size (int): Window size. Default: 8
789
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
790
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
791
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
792
+ drop_rate (float): Dropout rate. Default: 0
793
+ attn_drop_rate (float): Attention dropout rate. Default: 0
794
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
795
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
796
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
797
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
798
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
799
+ config (module): The configuration Module from config.py
800
+ """
801
+
802
+ def __init__(
803
+ self,
804
+ spec_size=256,
805
+ patch_size=4,
806
+ patch_stride=(4, 4),
807
+ in_chans=1,
808
+ num_classes=527,
809
+ embed_dim=96,
810
+ depths=[2, 2, 6, 2],
811
+ num_heads=[4, 8, 16, 32],
812
+ window_size=8,
813
+ mlp_ratio=4.0,
814
+ qkv_bias=True,
815
+ qk_scale=None,
816
+ drop_rate=0.0,
817
+ attn_drop_rate=0.0,
818
+ drop_path_rate=0.1,
819
+ norm_layer=nn.LayerNorm,
820
+ ape=False,
821
+ patch_norm=True,
822
+ use_checkpoint=False,
823
+ norm_before_mlp="ln",
824
+ config=None,
825
+ enable_fusion=False,
826
+ fusion_type="None",
827
+ **kwargs,
828
+ ):
829
+ super(HTSAT_Swin_Transformer, self).__init__()
830
+
831
+ self.config = config
832
+ self.spec_size = spec_size
833
+ self.patch_stride = patch_stride
834
+ self.patch_size = patch_size
835
+ self.window_size = window_size
836
+ self.embed_dim = embed_dim
837
+ self.depths = depths
838
+ self.ape = ape
839
+ self.in_chans = in_chans
840
+ self.num_classes = num_classes
841
+ self.num_heads = num_heads
842
+ self.num_layers = len(self.depths)
843
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
844
+
845
+ self.drop_rate = drop_rate
846
+ self.attn_drop_rate = attn_drop_rate
847
+ self.drop_path_rate = drop_path_rate
848
+
849
+ self.qkv_bias = qkv_bias
850
+ self.qk_scale = None
851
+
852
+ self.patch_norm = patch_norm
853
+ self.norm_layer = norm_layer if self.patch_norm else None
854
+ self.norm_before_mlp = norm_before_mlp
855
+ self.mlp_ratio = mlp_ratio
856
+
857
+ self.use_checkpoint = use_checkpoint
858
+
859
+ self.enable_fusion = enable_fusion
860
+ self.fusion_type = fusion_type
861
+
862
+ # process mel-spec ; used only once
863
+ self.freq_ratio = self.spec_size // self.config.mel_bins
864
+ window = "hann"
865
+ center = True
866
+ pad_mode = "reflect"
867
+ ref = 1.0
868
+ amin = 1e-10
869
+ top_db = None
870
+ self.interpolate_ratio = 32 # Downsampled ratio
871
+ # Spectrogram extractor
872
+ self.spectrogram_extractor = Spectrogram(
873
+ n_fft=config.window_size,
874
+ hop_length=config.hop_size,
875
+ win_length=config.window_size,
876
+ window=window,
877
+ center=center,
878
+ pad_mode=pad_mode,
879
+ freeze_parameters=True,
880
+ )
881
+ # Logmel feature extractor
882
+ self.logmel_extractor = LogmelFilterBank(
883
+ sr=config.sample_rate,
884
+ n_fft=config.window_size,
885
+ n_mels=config.mel_bins,
886
+ fmin=config.fmin,
887
+ fmax=config.fmax,
888
+ ref=ref,
889
+ amin=amin,
890
+ top_db=top_db,
891
+ freeze_parameters=True,
892
+ )
893
+ # Spec augmenter
894
+ self.spec_augmenter = SpecAugmentation(
895
+ time_drop_width=64,
896
+ time_stripes_num=2,
897
+ freq_drop_width=8,
898
+ freq_stripes_num=2,
899
+ ) # 2 2
900
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
901
+
902
+ # split spctrogram into non-overlapping patches
903
+ self.patch_embed = PatchEmbed(
904
+ img_size=self.spec_size,
905
+ patch_size=self.patch_size,
906
+ in_chans=self.in_chans,
907
+ embed_dim=self.embed_dim,
908
+ norm_layer=self.norm_layer,
909
+ patch_stride=patch_stride,
910
+ enable_fusion=self.enable_fusion,
911
+ fusion_type=self.fusion_type,
912
+ )
913
+
914
+ num_patches = self.patch_embed.num_patches
915
+ patches_resolution = self.patch_embed.grid_size
916
+ self.patches_resolution = patches_resolution
917
+
918
+ # absolute position embedding
919
+ if self.ape:
920
+ self.absolute_pos_embed = nn.Parameter(
921
+ torch.zeros(1, num_patches, self.embed_dim)
922
+ )
923
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
924
+
925
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
926
+
927
+ # stochastic depth
928
+ dpr = [
929
+ x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
930
+ ] # stochastic depth decay rule
931
+
932
+ # build layers
933
+ self.layers = nn.ModuleList()
934
+ for i_layer in range(self.num_layers):
935
+ layer = BasicLayer(
936
+ dim=int(self.embed_dim * 2**i_layer),
937
+ input_resolution=(
938
+ patches_resolution[0] // (2**i_layer),
939
+ patches_resolution[1] // (2**i_layer),
940
+ ),
941
+ depth=self.depths[i_layer],
942
+ num_heads=self.num_heads[i_layer],
943
+ window_size=self.window_size,
944
+ mlp_ratio=self.mlp_ratio,
945
+ qkv_bias=self.qkv_bias,
946
+ qk_scale=self.qk_scale,
947
+ drop=self.drop_rate,
948
+ attn_drop=self.attn_drop_rate,
949
+ drop_path=dpr[
950
+ sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
951
+ ],
952
+ norm_layer=self.norm_layer,
953
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
954
+ use_checkpoint=use_checkpoint,
955
+ norm_before_mlp=self.norm_before_mlp,
956
+ )
957
+ self.layers.append(layer)
958
+
959
+ self.norm = self.norm_layer(self.num_features)
960
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
961
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
962
+
963
+ SF = (
964
+ self.spec_size
965
+ // (2 ** (len(self.depths) - 1))
966
+ // self.patch_stride[0]
967
+ // self.freq_ratio
968
+ )
969
+ self.tscam_conv = nn.Conv2d(
970
+ in_channels=self.num_features,
971
+ out_channels=self.num_classes,
972
+ kernel_size=(SF, 3),
973
+ padding=(0, 1),
974
+ )
975
+ self.head = nn.Linear(num_classes, num_classes)
976
+
977
+ if (self.enable_fusion) and (
978
+ self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
979
+ ):
980
+ self.mel_conv1d = nn.Sequential(
981
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
982
+ nn.BatchNorm1d(64),
983
+ )
984
+ if self.fusion_type == "daf_1d":
985
+ self.fusion_model = DAF()
986
+ elif self.fusion_type == "aff_1d":
987
+ self.fusion_model = AFF(channels=64, type="1D")
988
+ elif self.fusion_type == "iaff_1d":
989
+ self.fusion_model = iAFF(channels=64, type="1D")
990
+
991
+ self.apply(self._init_weights)
992
+
993
+ def _init_weights(self, m):
994
+ if isinstance(m, nn.Linear):
995
+ trunc_normal_(m.weight, std=0.02)
996
+ if isinstance(m, nn.Linear) and m.bias is not None:
997
+ nn.init.constant_(m.bias, 0)
998
+ elif isinstance(m, nn.LayerNorm):
999
+ nn.init.constant_(m.bias, 0)
1000
+ nn.init.constant_(m.weight, 1.0)
1001
+
1002
+ @torch.jit.ignore
1003
+ def no_weight_decay(self):
1004
+ return {"absolute_pos_embed"}
1005
+
1006
+ @torch.jit.ignore
1007
+ def no_weight_decay_keywords(self):
1008
+ return {"relative_position_bias_table"}
1009
+
1010
+ def forward_features(self, x, longer_idx=None):
1011
+ # A deprecated optimization for using a hierarchical output from different blocks
1012
+
1013
+ frames_num = x.shape[2]
1014
+ x = self.patch_embed(x, longer_idx=longer_idx)
1015
+ if self.ape:
1016
+ x = x + self.absolute_pos_embed
1017
+ x = self.pos_drop(x)
1018
+ for i, layer in enumerate(self.layers):
1019
+ x, attn = layer(x)
1020
+ # for x
1021
+ x = self.norm(x)
1022
+ B, N, C = x.shape
1023
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
1024
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
1025
+ x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
1026
+ B, C, F, T = x.shape
1027
+ # group 2D CNN
1028
+ c_freq_bin = F // self.freq_ratio
1029
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
1030
+ x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
1031
+ # get latent_output
1032
+ fine_grained_latent_output = torch.mean(x, dim=2)
1033
+ fine_grained_latent_output = interpolate(
1034
+ fine_grained_latent_output.permute(0, 2, 1).contiguous(),
1035
+ 8 * self.patch_stride[1],
1036
+ )
1037
+
1038
+ latent_output = self.avgpool(torch.flatten(x, 2))
1039
+ latent_output = torch.flatten(latent_output, 1)
1040
+
1041
+ # display the attention map, if needed
1042
+
1043
+ x = self.tscam_conv(x)
1044
+ x = torch.flatten(x, 2) # B, C, T
1045
+
1046
+ fpx = interpolate(
1047
+ torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
1048
+ )
1049
+
1050
+ x = self.avgpool(x)
1051
+ x = torch.flatten(x, 1)
1052
+
1053
+ output_dict = {
1054
+ "framewise_output": fpx, # already sigmoided
1055
+ "clipwise_output": torch.sigmoid(x),
1056
+ "fine_grained_embedding": fine_grained_latent_output,
1057
+ "embedding": latent_output,
1058
+ }
1059
+
1060
+ return output_dict
1061
+
1062
+ def crop_wav(self, x, crop_size, spe_pos=None):
1063
+ time_steps = x.shape[2]
1064
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
1065
+ for i in range(len(x)):
1066
+ if spe_pos is None:
1067
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
1068
+ else:
1069
+ crop_pos = spe_pos
1070
+ tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
1071
+ return tx
1072
+
1073
+ # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
1074
+ def reshape_wav2img(self, x):
1075
+ B, C, T, F = x.shape
1076
+ target_T = int(self.spec_size * self.freq_ratio)
1077
+ target_F = self.spec_size // self.freq_ratio
1078
+ assert (
1079
+ T <= target_T and F <= target_F
1080
+ ), "the wav size should less than or equal to the swin input size"
1081
+ # to avoid bicubic zero error
1082
+ if T < target_T:
1083
+ x = nn.functional.interpolate(
1084
+ x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
1085
+ )
1086
+ if F < target_F:
1087
+ x = nn.functional.interpolate(
1088
+ x, (x.shape[2], target_F), mode="bicubic", align_corners=True
1089
+ )
1090
+ x = x.permute(0, 1, 3, 2).contiguous()
1091
+ x = x.reshape(
1092
+ x.shape[0],
1093
+ x.shape[1],
1094
+ x.shape[2],
1095
+ self.freq_ratio,
1096
+ x.shape[3] // self.freq_ratio,
1097
+ )
1098
+ # print(x.shape)
1099
+ x = x.permute(0, 1, 3, 2, 4).contiguous()
1100
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
1101
+ return x
1102
+
1103
+ # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
1104
+ def repeat_wat2img(self, x, cur_pos):
1105
+ B, C, T, F = x.shape
1106
+ target_T = int(self.spec_size * self.freq_ratio)
1107
+ target_F = self.spec_size // self.freq_ratio
1108
+ assert (
1109
+ T <= target_T and F <= target_F
1110
+ ), "the wav size should less than or equal to the swin input size"
1111
+ # to avoid bicubic zero error
1112
+ if T < target_T:
1113
+ x = nn.functional.interpolate(
1114
+ x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
1115
+ )
1116
+ if F < target_F:
1117
+ x = nn.functional.interpolate(
1118
+ x, (x.shape[2], target_F), mode="bicubic", align_corners=True
1119
+ )
1120
+ x = x.permute(0, 1, 3, 2).contiguous() # B C F T
1121
+ x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
1122
+ x = x.repeat(repeats=(1, 1, 4, 1))
1123
+ return x
1124
+
1125
+ def forward(
1126
+ self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
1127
+ ): # out_feat_keys: List[str] = None):
1128
+ if self.enable_fusion and x["longer"].sum() == 0:
1129
+ # if no audio is longer than 10s, then randomly select one audio to be longer
1130
+ x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
1131
+
1132
+ if not self.enable_fusion:
1133
+ x = x["waveform"].to(device=device, non_blocking=True)
1134
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1135
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1136
+ x = x.transpose(1, 3)
1137
+ x = self.bn0(x)
1138
+ x = x.transpose(1, 3)
1139
+ if self.training:
1140
+ x = self.spec_augmenter(x)
1141
+
1142
+ if self.training and mixup_lambda is not None:
1143
+ x = do_mixup(x, mixup_lambda)
1144
+
1145
+ x = self.reshape_wav2img(x)
1146
+ output_dict = self.forward_features(x)
1147
+ else:
1148
+ longer_list = x["longer"].to(device=device, non_blocking=True)
1149
+ x = x["mel_fusion"].to(device=device, non_blocking=True)
1150
+ x = x.transpose(1, 3)
1151
+ x = self.bn0(x)
1152
+ x = x.transpose(1, 3)
1153
+ longer_list_idx = torch.where(longer_list)[0]
1154
+ if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
1155
+ new_x = x[:, 0:1, :, :].clone().contiguous()
1156
+ if len(longer_list_idx) > 0:
1157
+ # local processing
1158
+ fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
1159
+ FB, FC, FT, FF = fusion_x_local.size()
1160
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
1161
+ fusion_x_local = torch.permute(
1162
+ fusion_x_local, (0, 2, 1)
1163
+ ).contiguous()
1164
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
1165
+ fusion_x_local = fusion_x_local.view(
1166
+ FB, FC, FF, fusion_x_local.size(-1)
1167
+ )
1168
+ fusion_x_local = (
1169
+ torch.permute(fusion_x_local, (0, 2, 1, 3))
1170
+ .contiguous()
1171
+ .flatten(2)
1172
+ )
1173
+ if fusion_x_local.size(-1) < FT:
1174
+ fusion_x_local = torch.cat(
1175
+ [
1176
+ fusion_x_local,
1177
+ torch.zeros(
1178
+ (FB, FF, FT - fusion_x_local.size(-1)),
1179
+ device=device,
1180
+ ),
1181
+ ],
1182
+ dim=-1,
1183
+ )
1184
+ else:
1185
+ fusion_x_local = fusion_x_local[:, :, :FT]
1186
+ # 1D fusion
1187
+ new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
1188
+ new_x[longer_list_idx] = self.fusion_model(
1189
+ new_x[longer_list_idx], fusion_x_local
1190
+ )
1191
+ x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
1192
+ else:
1193
+ x = new_x
1194
+
1195
+ elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
1196
+ x = x # no change
1197
+
1198
+ if self.training:
1199
+ x = self.spec_augmenter(x)
1200
+ if self.training and mixup_lambda is not None:
1201
+ x = do_mixup(x, mixup_lambda)
1202
+
1203
+ x = self.reshape_wav2img(x)
1204
+ output_dict = self.forward_features(x, longer_idx=longer_list_idx)
1205
+
1206
+ # if infer_mode:
1207
+ # # in infer mode. we need to handle different length audio input
1208
+ # frame_num = x.shape[2]
1209
+ # target_T = int(self.spec_size * self.freq_ratio)
1210
+ # repeat_ratio = math.floor(target_T / frame_num)
1211
+ # x = x.repeat(repeats=(1,1,repeat_ratio,1))
1212
+ # x = self.reshape_wav2img(x)
1213
+ # output_dict = self.forward_features(x)
1214
+ # else:
1215
+ # if x.shape[2] > self.freq_ratio * self.spec_size:
1216
+ # if self.training:
1217
+ # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
1218
+ # x = self.reshape_wav2img(x)
1219
+ # output_dict = self.forward_features(x)
1220
+ # else:
1221
+ # # Change: Hard code here
1222
+ # overlap_size = (x.shape[2] - 1) // 4
1223
+ # output_dicts = []
1224
+ # crop_size = (x.shape[2] - 1) // 2
1225
+ # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
1226
+ # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
1227
+ # tx = self.reshape_wav2img(tx)
1228
+ # output_dicts.append(self.forward_features(tx))
1229
+ # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1230
+ # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1231
+ # for d in output_dicts:
1232
+ # clipwise_output += d["clipwise_output"]
1233
+ # framewise_output += d["framewise_output"]
1234
+ # clipwise_output = clipwise_output / len(output_dicts)
1235
+ # framewise_output = framewise_output / len(output_dicts)
1236
+ # output_dict = {
1237
+ # 'framewise_output': framewise_output,
1238
+ # 'clipwise_output': clipwise_output
1239
+ # }
1240
+ # else: # this part is typically used, and most easy one
1241
+ # x = self.reshape_wav2img(x)
1242
+ # output_dict = self.forward_features(x)
1243
+ # x = self.head(x)
1244
+
1245
+ # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T
1246
+
1247
+ return output_dict
1248
+
1249
+
1250
+ def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
1251
+ try:
1252
+ assert audio_cfg.model_name in [
1253
+ "tiny",
1254
+ "base",
1255
+ "large",
1256
+ ], "model name for HTS-AT is wrong!"
1257
+ if audio_cfg.model_name == "tiny":
1258
+ model = HTSAT_Swin_Transformer(
1259
+ spec_size=256,
1260
+ patch_size=4,
1261
+ patch_stride=(4, 4),
1262
+ num_classes=audio_cfg.class_num,
1263
+ embed_dim=96,
1264
+ depths=[2, 2, 6, 2],
1265
+ num_heads=[4, 8, 16, 32],
1266
+ window_size=8,
1267
+ config=audio_cfg,
1268
+ enable_fusion=enable_fusion,
1269
+ fusion_type=fusion_type,
1270
+ )
1271
+ elif audio_cfg.model_name == "base":
1272
+ model = HTSAT_Swin_Transformer(
1273
+ spec_size=256,
1274
+ patch_size=4,
1275
+ patch_stride=(4, 4),
1276
+ num_classes=audio_cfg.class_num,
1277
+ embed_dim=128,
1278
+ depths=[2, 2, 12, 2],
1279
+ num_heads=[4, 8, 16, 32],
1280
+ window_size=8,
1281
+ config=audio_cfg,
1282
+ enable_fusion=enable_fusion,
1283
+ fusion_type=fusion_type,
1284
+ )
1285
+ elif audio_cfg.model_name == "large":
1286
+ model = HTSAT_Swin_Transformer(
1287
+ spec_size=256,
1288
+ patch_size=4,
1289
+ patch_stride=(4, 4),
1290
+ num_classes=audio_cfg.class_num,
1291
+ embed_dim=256,
1292
+ depths=[2, 2, 12, 2],
1293
+ num_heads=[4, 8, 16, 32],
1294
+ window_size=8,
1295
+ config=audio_cfg,
1296
+ enable_fusion=enable_fusion,
1297
+ fusion_type=fusion_type,
1298
+ )
1299
+
1300
+ return model
1301
+ except:
1302
+ raise RuntimeError(
1303
+ f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
1304
+ )
audioldm2/clap/open_clip/loss.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed.nn
3
+ from torch import distributed as dist, nn as nn
4
+ from torch.nn import functional as F
5
+ import numpy as np
6
+ from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
7
+
8
+ try:
9
+ import horovod.torch as hvd
10
+ except ImportError:
11
+ hvd = None
12
+
13
+
14
+ def gather_features(
15
+ audio_features,
16
+ text_features,
17
+ audio_features_mlp=None,
18
+ text_features_mlp=None,
19
+ local_loss=False,
20
+ gather_with_grad=False,
21
+ rank=0,
22
+ world_size=1,
23
+ use_horovod=False,
24
+ mlp_loss=False,
25
+ ):
26
+ if use_horovod:
27
+ assert hvd is not None, "Please install horovod"
28
+ if gather_with_grad:
29
+ all_audio_features = hvd.allgather(audio_features)
30
+ all_text_features = hvd.allgather(text_features)
31
+ if mlp_loss:
32
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
33
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
34
+ else:
35
+ with torch.no_grad():
36
+ all_audio_features = hvd.allgather(audio_features)
37
+ all_text_features = hvd.allgather(text_features)
38
+ if mlp_loss:
39
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
40
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
41
+ if not local_loss:
42
+ # ensure grads for local rank when all_* features don't have a gradient
43
+ gathered_audio_features = list(
44
+ all_audio_features.chunk(world_size, dim=0)
45
+ )
46
+ gathered_text_features = list(
47
+ all_text_features.chunk(world_size, dim=0)
48
+ )
49
+ gathered_audio_features[rank] = audio_features
50
+ gathered_text_features[rank] = text_features
51
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
52
+ all_text_features = torch.cat(gathered_text_features, dim=0)
53
+ if mlp_loss:
54
+ gathered_audio_features_mlp = list(
55
+ all_audio_features_mlp.chunk(world_size, dim=0)
56
+ )
57
+ gathered_text_features_mlp = list(
58
+ all_text_features_mlp.chunk(world_size, dim=0)
59
+ )
60
+ gathered_audio_features_mlp[rank] = audio_features_mlp
61
+ gathered_text_features_mlp[rank] = text_features_mlp
62
+ all_audio_features_mlp = torch.cat(
63
+ gathered_audio_features_mlp, dim=0
64
+ )
65
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
66
+ else:
67
+ # We gather tensors from all gpus
68
+ if gather_with_grad:
69
+ all_audio_features = torch.cat(
70
+ torch.distributed.nn.all_gather(audio_features), dim=0
71
+ )
72
+ all_text_features = torch.cat(
73
+ torch.distributed.nn.all_gather(text_features), dim=0
74
+ )
75
+ if mlp_loss:
76
+ all_audio_features_mlp = torch.cat(
77
+ torch.distributed.nn.all_gather(audio_features_mlp), dim=0
78
+ )
79
+ all_text_features_mlp = torch.cat(
80
+ torch.distributed.nn.all_gather(text_features_mlp), dim=0
81
+ )
82
+ else:
83
+ gathered_audio_features = [
84
+ torch.zeros_like(audio_features) for _ in range(world_size)
85
+ ]
86
+ gathered_text_features = [
87
+ torch.zeros_like(text_features) for _ in range(world_size)
88
+ ]
89
+ dist.all_gather(gathered_audio_features, audio_features)
90
+ dist.all_gather(gathered_text_features, text_features)
91
+ if mlp_loss:
92
+ gathered_audio_features_mlp = [
93
+ torch.zeros_like(audio_features_mlp) for _ in range(world_size)
94
+ ]
95
+ gathered_text_features_mlp = [
96
+ torch.zeros_like(text_features_mlp) for _ in range(world_size)
97
+ ]
98
+ dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
99
+ dist.all_gather(gathered_text_features_mlp, text_features_mlp)
100
+ if not local_loss:
101
+ # ensure grads for local rank when all_* features don't have a gradient
102
+ gathered_audio_features[rank] = audio_features
103
+ gathered_text_features[rank] = text_features
104
+ if mlp_loss:
105
+ gathered_audio_features_mlp[rank] = audio_features_mlp
106
+ gathered_text_features_mlp[rank] = text_features_mlp
107
+
108
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
109
+ all_text_features = torch.cat(gathered_text_features, dim=0)
110
+ if mlp_loss:
111
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
112
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
113
+ if mlp_loss:
114
+ return (
115
+ all_audio_features,
116
+ all_text_features,
117
+ all_audio_features_mlp,
118
+ all_text_features_mlp,
119
+ )
120
+ else:
121
+ return all_audio_features, all_text_features
122
+
123
+
124
+ class ClipLoss(nn.Module):
125
+ def __init__(
126
+ self,
127
+ local_loss=False,
128
+ gather_with_grad=False,
129
+ cache_labels=False,
130
+ rank=0,
131
+ world_size=1,
132
+ use_horovod=False,
133
+ mlp_loss=False,
134
+ weight_loss_kappa=0,
135
+ ):
136
+ super().__init__()
137
+ self.local_loss = local_loss
138
+ self.gather_with_grad = gather_with_grad
139
+ self.cache_labels = cache_labels
140
+ self.rank = rank
141
+ self.world_size = world_size
142
+ self.use_horovod = use_horovod
143
+ self.mlp_loss = mlp_loss
144
+ self.weighted_loss = bool(weight_loss_kappa != 0)
145
+ self.weight_loss_kappa = weight_loss_kappa
146
+ # cache state
147
+ self.prev_num_logits = 0
148
+ self.labels = {}
149
+
150
+ def forward(
151
+ self,
152
+ audio_features,
153
+ text_features,
154
+ logit_scale_a,
155
+ logit_scale_t=None,
156
+ audio_features_mlp=None,
157
+ text_features_mlp=None,
158
+ ):
159
+ device = audio_features.device
160
+ if self.mlp_loss:
161
+ if self.world_size > 1:
162
+ (
163
+ all_audio_features,
164
+ all_text_features,
165
+ all_audio_features_mlp,
166
+ all_text_features_mlp,
167
+ ) = gather_features(
168
+ audio_features=audio_features,
169
+ text_features=text_features,
170
+ audio_features_mlp=audio_features_mlp,
171
+ text_features_mlp=text_features_mlp,
172
+ local_loss=self.local_loss,
173
+ gather_with_grad=self.gather_with_grad,
174
+ rank=self.rank,
175
+ world_size=self.world_size,
176
+ use_horovod=self.use_horovod,
177
+ mlp_loss=self.mlp_loss,
178
+ )
179
+ if self.local_loss:
180
+ a_logits_per_audio = (
181
+ logit_scale_a * audio_features @ all_text_features_mlp.T
182
+ )
183
+ a_logits_per_text = (
184
+ logit_scale_a * text_features_mlp @ all_audio_features.T
185
+ )
186
+ t_logits_per_audio = (
187
+ logit_scale_t * audio_features_mlp @ all_text_features.T
188
+ )
189
+ t_logits_per_text = (
190
+ logit_scale_t * text_features @ all_audio_features_mlp.T
191
+ )
192
+ else:
193
+ a_logits_per_audio = (
194
+ logit_scale_a * all_audio_features @ all_text_features_mlp.T
195
+ )
196
+ a_logits_per_text = a_logits_per_audio.T
197
+ t_logits_per_audio = (
198
+ logit_scale_t * all_audio_features_mlp @ all_text_features.T
199
+ )
200
+ t_logits_per_text = t_logits_per_audio.T
201
+ else:
202
+ a_logits_per_audio = (
203
+ logit_scale_a * audio_features @ text_features_mlp.T
204
+ )
205
+ a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
206
+ t_logits_per_audio = (
207
+ logit_scale_t * audio_features_mlp @ text_features.T
208
+ )
209
+ t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
210
+
211
+ # calculated ground-truth and cache if enabled
212
+ num_logits = a_logits_per_audio.shape[0]
213
+ if self.prev_num_logits != num_logits or device not in self.labels:
214
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
215
+ if self.world_size > 1 and self.local_loss:
216
+ labels = labels + num_logits * self.rank
217
+ if self.cache_labels:
218
+ self.labels[device] = labels
219
+ self.prev_num_logits = num_logits
220
+ else:
221
+ labels = self.labels[device]
222
+
223
+ if not self.weighted_loss:
224
+ total_loss = (
225
+ F.cross_entropy(a_logits_per_audio, labels)
226
+ + F.cross_entropy(a_logits_per_text, labels)
227
+ + F.cross_entropy(t_logits_per_audio, labels)
228
+ + F.cross_entropy(t_logits_per_text, labels)
229
+ ) / 4
230
+ else:
231
+ audio_weight = (audio_features @ audio_features.T).detach()
232
+ audio_weight = (
233
+ torch.exp(
234
+ torch.sum(audio_weight, axis=1)
235
+ / (self.weight_loss_kappa * len(audio_weight))
236
+ )
237
+ ).detach()
238
+ text_weight = (text_features @ text_features.T).detach()
239
+ text_weight = (
240
+ torch.exp(
241
+ torch.sum(text_weight, axis=1)
242
+ / (self.weight_loss_kappa * len(text_features))
243
+ )
244
+ ).detach()
245
+ total_loss = (
246
+ F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
247
+ + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
248
+ + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
249
+ + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
250
+ ) / 4
251
+ else:
252
+ if self.world_size > 1:
253
+ all_audio_features, all_text_features = gather_features(
254
+ audio_features=audio_features,
255
+ text_features=text_features,
256
+ local_loss=self.local_loss,
257
+ gather_with_grad=self.gather_with_grad,
258
+ rank=self.rank,
259
+ world_size=self.world_size,
260
+ use_horovod=self.use_horovod,
261
+ mlp_loss=self.mlp_loss,
262
+ )
263
+
264
+ if self.local_loss:
265
+ logits_per_audio = (
266
+ logit_scale_a * audio_features @ all_text_features.T
267
+ )
268
+ logits_per_text = (
269
+ logit_scale_a * text_features @ all_audio_features.T
270
+ )
271
+ else:
272
+ logits_per_audio = (
273
+ logit_scale_a * all_audio_features @ all_text_features.T
274
+ )
275
+ logits_per_text = logits_per_audio.T
276
+ else:
277
+ logits_per_audio = logit_scale_a * audio_features @ text_features.T
278
+ logits_per_text = logit_scale_a * text_features @ audio_features.T
279
+
280
+ # calculated ground-truth and cache if enabled
281
+ num_logits = logits_per_audio.shape[0]
282
+ if self.prev_num_logits != num_logits or device not in self.labels:
283
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
284
+ if self.world_size > 1 and self.local_loss:
285
+ labels = labels + num_logits * self.rank
286
+ if self.cache_labels:
287
+ self.labels[device] = labels
288
+ self.prev_num_logits = num_logits
289
+ else:
290
+ labels = self.labels[device]
291
+ if not self.weighted_loss:
292
+ total_loss = (
293
+ F.cross_entropy(logits_per_audio, labels)
294
+ + F.cross_entropy(logits_per_text, labels)
295
+ ) / 2
296
+ else:
297
+ audio_weight = (all_audio_features @ all_audio_features.T).detach()
298
+ audio_weight = (
299
+ torch.exp(
300
+ torch.sum(audio_weight, axis=1)
301
+ / (self.weight_loss_kappa * len(all_audio_features))
302
+ )
303
+ ).detach()
304
+ text_weight = (all_text_features @ all_text_features.T).detach()
305
+ text_weight = (
306
+ torch.exp(
307
+ torch.sum(text_weight, axis=1)
308
+ / (self.weight_loss_kappa * len(all_text_features))
309
+ )
310
+ ).detach()
311
+ total_loss = (
312
+ F.cross_entropy(logits_per_audio, labels, weight=text_weight)
313
+ + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
314
+ ) / 2
315
+ return total_loss
316
+
317
+
318
+ def lp_gather_features(pred, target, world_size=1, use_horovod=False):
319
+ if use_horovod:
320
+ assert hvd is not None, "Please install horovod"
321
+ with torch.no_grad():
322
+ all_preds = hvd.allgather(pred)
323
+ all_targets = hvd.allgath(target)
324
+ else:
325
+ gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
326
+ gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
327
+
328
+ dist.all_gather(gathered_preds, pred)
329
+ dist.all_gather(gathered_targets, target)
330
+ all_preds = torch.cat(gathered_preds, dim=0)
331
+ all_targets = torch.cat(gathered_targets, dim=0)
332
+
333
+ return all_preds, all_targets
334
+
335
+
336
+ def get_map(pred, target):
337
+ pred = torch.sigmoid(pred).numpy()
338
+ target = target.numpy()
339
+ return np.mean(average_precision_score(target, pred, average=None))
340
+
341
+
342
+ def get_acc(pred, target):
343
+ pred = torch.argmax(pred, 1).numpy()
344
+ target = torch.argmax(target, 1).numpy()
345
+ return accuracy_score(target, pred)
346
+
347
+
348
+ def get_mauc(pred, target):
349
+ pred = torch.sigmoid(pred).numpy()
350
+ target = target.numpy()
351
+ return np.mean(roc_auc_score(target, pred, average=None))
352
+
353
+
354
+ class LPMetrics(object):
355
+ def __init__(self, metric_names=["map", "acc", "mauc"]):
356
+ self.metrics = []
357
+ for name in metric_names:
358
+ self.metrics.append(self.get_metric(name))
359
+ self.metric_names = metric_names
360
+
361
+ def get_metric(self, name):
362
+ if name == "map":
363
+ return get_map
364
+ elif name == "acc":
365
+ return get_acc
366
+ elif name == "mauc":
367
+ return get_mauc
368
+ else:
369
+ raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
370
+
371
+ def evaluate_mertics(self, pred, target):
372
+ metric_dict = {}
373
+ for i in range(len(self.metric_names)):
374
+ metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
375
+ return metric_dict
376
+
377
+
378
+ def calc_celoss(pred, target):
379
+ target = torch.argmax(target, 1).long()
380
+ return nn.CrossEntropyLoss()(pred, target)
381
+
382
+
383
+ class LPLoss(nn.Module):
384
+ def __init__(self, loss_name):
385
+ super().__init__()
386
+ if loss_name == "bce":
387
+ self.loss_func = nn.BCEWithLogitsLoss()
388
+ elif loss_name == "ce":
389
+ self.loss_func = calc_celoss
390
+ elif loss_name == "mse":
391
+ self.loss_func = nn.MSELoss()
392
+ else:
393
+ raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
394
+
395
+ def forward(self, pred, target):
396
+ loss = self.loss_func(pred, target)
397
+ return loss
audioldm2/clap/open_clip/model.py ADDED
@@ -0,0 +1,931 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLAP Model
2
+
3
+ Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ Adapted to the Audio Task.
5
+ """
6
+
7
+ from collections import OrderedDict
8
+ from dataclasses import dataclass
9
+ from typing import Tuple, Union, Callable, Optional
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+
16
+ import logging
17
+ from .utils import freeze_batch_norm_2d
18
+
19
+ from .pann_model import create_pann_model
20
+ from .htsat import create_htsat_model
21
+ from transformers import BertModel, RobertaModel, BartModel, RobertaConfig
22
+
23
+
24
+ class MLPLayers(nn.Module):
25
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
26
+ super(MLPLayers, self).__init__()
27
+ self.nonlin = nonlin
28
+ self.dropout = dropout
29
+
30
+ sequence = []
31
+ for u0, u1 in zip(units[:-1], units[1:]):
32
+ sequence.append(nn.Linear(u0, u1))
33
+ sequence.append(self.nonlin)
34
+ sequence.append(nn.Dropout(self.dropout))
35
+ sequence = sequence[:-2]
36
+
37
+ self.sequential = nn.Sequential(*sequence)
38
+
39
+ def forward(self, X):
40
+ X = self.sequential(X)
41
+ return X
42
+
43
+
44
+ class Bottleneck(nn.Module):
45
+ expansion = 4
46
+
47
+ def __init__(self, inplanes, planes, stride=1):
48
+ super().__init__()
49
+
50
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
51
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
52
+ self.bn1 = nn.BatchNorm2d(planes)
53
+
54
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
55
+ self.bn2 = nn.BatchNorm2d(planes)
56
+
57
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
58
+
59
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
60
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
61
+
62
+ self.relu = nn.ReLU(inplace=True)
63
+ self.downsample = None
64
+ self.stride = stride
65
+
66
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
67
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
68
+ self.downsample = nn.Sequential(
69
+ OrderedDict(
70
+ [
71
+ ("-1", nn.AvgPool2d(stride)),
72
+ (
73
+ "0",
74
+ nn.Conv2d(
75
+ inplanes,
76
+ planes * self.expansion,
77
+ 1,
78
+ stride=1,
79
+ bias=False,
80
+ ),
81
+ ),
82
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
83
+ ]
84
+ )
85
+ )
86
+
87
+ def forward(self, x: torch.Tensor):
88
+ identity = x
89
+
90
+ out = self.relu(self.bn1(self.conv1(x)))
91
+ out = self.relu(self.bn2(self.conv2(out)))
92
+ out = self.avgpool(out)
93
+ out = self.bn3(self.conv3(out))
94
+
95
+ if self.downsample is not None:
96
+ identity = self.downsample(x)
97
+
98
+ out += identity
99
+ out = self.relu(out)
100
+ return out
101
+
102
+
103
+ class AttentionPool2d(nn.Module):
104
+ def __init__(
105
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
106
+ ):
107
+ super().__init__()
108
+ self.positional_embedding = nn.Parameter(
109
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
110
+ )
111
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
112
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
113
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
114
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
115
+ self.num_heads = num_heads
116
+
117
+ def forward(self, x):
118
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
119
+ 2, 0, 1
120
+ ) # NCHW -> (HW)NC
121
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
122
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
123
+ x, _ = F.multi_head_attention_forward(
124
+ query=x,
125
+ key=x,
126
+ value=x,
127
+ embed_dim_to_check=x.shape[-1],
128
+ num_heads=self.num_heads,
129
+ q_proj_weight=self.q_proj.weight,
130
+ k_proj_weight=self.k_proj.weight,
131
+ v_proj_weight=self.v_proj.weight,
132
+ in_proj_weight=None,
133
+ in_proj_bias=torch.cat(
134
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
135
+ ),
136
+ bias_k=None,
137
+ bias_v=None,
138
+ add_zero_attn=False,
139
+ dropout_p=0,
140
+ out_proj_weight=self.c_proj.weight,
141
+ out_proj_bias=self.c_proj.bias,
142
+ use_separate_proj_weight=True,
143
+ training=self.training,
144
+ need_weights=False,
145
+ )
146
+
147
+ return x[0]
148
+
149
+
150
+ class ModifiedResNet(nn.Module):
151
+ """
152
+ A ResNet class that is similar to torchvision's but contains the following changes:
153
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
154
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
155
+ - The final pooling layer is a QKV attention instead of an average pool
156
+ """
157
+
158
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
159
+ super().__init__()
160
+ self.output_dim = output_dim
161
+ self.image_size = image_size
162
+
163
+ # the 3-layer stem
164
+ self.conv1 = nn.Conv2d(
165
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
166
+ )
167
+ self.bn1 = nn.BatchNorm2d(width // 2)
168
+ self.conv2 = nn.Conv2d(
169
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
170
+ )
171
+ self.bn2 = nn.BatchNorm2d(width // 2)
172
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
173
+ self.bn3 = nn.BatchNorm2d(width)
174
+ self.avgpool = nn.AvgPool2d(2)
175
+ self.relu = nn.ReLU(inplace=True)
176
+
177
+ # residual layers
178
+ self._inplanes = width # this is a *mutable* variable used during construction
179
+ self.layer1 = self._make_layer(width, layers[0])
180
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
181
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
182
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
183
+
184
+ embed_dim = width * 32 # the ResNet feature dimension
185
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
186
+
187
+ self.init_parameters()
188
+
189
+ def _make_layer(self, planes, blocks, stride=1):
190
+ layers = [Bottleneck(self._inplanes, planes, stride)]
191
+
192
+ self._inplanes = planes * Bottleneck.expansion
193
+ for _ in range(1, blocks):
194
+ layers.append(Bottleneck(self._inplanes, planes))
195
+
196
+ return nn.Sequential(*layers)
197
+
198
+ def init_parameters(self):
199
+ if self.attnpool is not None:
200
+ std = self.attnpool.c_proj.in_features**-0.5
201
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
202
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
203
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
204
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
205
+
206
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
207
+ for name, param in resnet_block.named_parameters():
208
+ if name.endswith("bn3.weight"):
209
+ nn.init.zeros_(param)
210
+
211
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
212
+ assert (
213
+ unlocked_groups == 0
214
+ ), "partial locking not currently supported for this model"
215
+ for param in self.parameters():
216
+ param.requires_grad = False
217
+ if freeze_bn_stats:
218
+ freeze_batch_norm_2d(self)
219
+
220
+ def stem(self, x):
221
+ for conv, bn in [
222
+ (self.conv1, self.bn1),
223
+ (self.conv2, self.bn2),
224
+ (self.conv3, self.bn3),
225
+ ]:
226
+ x = self.relu(bn(conv(x)))
227
+ x = self.avgpool(x)
228
+ return x
229
+
230
+ def forward(self, x):
231
+ x = self.stem(x)
232
+ x = self.layer1(x)
233
+ x = self.layer2(x)
234
+ x = self.layer3(x)
235
+ x = self.layer4(x)
236
+ x = self.attnpool(x)
237
+
238
+ return x
239
+
240
+
241
+ class LayerNorm(nn.LayerNorm):
242
+ """Subclass torch's LayerNorm to handle fp16."""
243
+
244
+ def forward(self, x: torch.Tensor):
245
+ orig_type = x.dtype
246
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
247
+ return x.to(orig_type)
248
+
249
+
250
+ class QuickGELU(nn.Module):
251
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
252
+ def forward(self, x: torch.Tensor):
253
+ return x * torch.sigmoid(1.702 * x)
254
+
255
+
256
+ class ResidualAttentionBlock(nn.Module):
257
+ def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
258
+ super().__init__()
259
+
260
+ self.attn = nn.MultiheadAttention(d_model, n_head)
261
+ self.ln_1 = LayerNorm(d_model)
262
+ self.mlp = nn.Sequential(
263
+ OrderedDict(
264
+ [
265
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
266
+ ("gelu", act_layer()),
267
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
268
+ ]
269
+ )
270
+ )
271
+ self.ln_2 = LayerNorm(d_model)
272
+
273
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
274
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
275
+
276
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
277
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
278
+ x = x + self.mlp(self.ln_2(x))
279
+ return x
280
+
281
+
282
+ class Transformer(nn.Module):
283
+ def __init__(
284
+ self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
285
+ ):
286
+ super().__init__()
287
+ self.width = width
288
+ self.layers = layers
289
+ self.resblocks = nn.ModuleList(
290
+ [
291
+ ResidualAttentionBlock(width, heads, act_layer=act_layer)
292
+ for _ in range(layers)
293
+ ]
294
+ )
295
+
296
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
297
+ for r in self.resblocks:
298
+ x = r(x, attn_mask=attn_mask)
299
+ return x
300
+
301
+
302
+ class VisualTransformer(nn.Module):
303
+ def __init__(
304
+ self,
305
+ image_size: int,
306
+ patch_size: int,
307
+ width: int,
308
+ layers: int,
309
+ heads: int,
310
+ output_dim: int,
311
+ act_layer: Callable = nn.GELU,
312
+ ):
313
+ super().__init__()
314
+ self.image_size = image_size
315
+ self.output_dim = output_dim
316
+ self.conv1 = nn.Conv2d(
317
+ in_channels=3,
318
+ out_channels=width,
319
+ kernel_size=patch_size,
320
+ stride=patch_size,
321
+ bias=False,
322
+ )
323
+
324
+ scale = width**-0.5
325
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
326
+ self.positional_embedding = nn.Parameter(
327
+ scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
328
+ )
329
+ self.ln_pre = LayerNorm(width)
330
+
331
+ self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
332
+
333
+ self.ln_post = LayerNorm(width)
334
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
335
+
336
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
337
+ assert (
338
+ unlocked_groups == 0
339
+ ), "partial locking not currently supported for this model"
340
+ for param in self.parameters():
341
+ param.requires_grad = False
342
+
343
+ def forward(self, x: torch.Tensor):
344
+ x = self.conv1(x) # shape = [*, width, grid, grid]
345
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
346
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
347
+ x = torch.cat(
348
+ [
349
+ self.class_embedding.to(x.dtype)
350
+ + torch.zeros(
351
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
352
+ ),
353
+ x,
354
+ ],
355
+ dim=1,
356
+ ) # shape = [*, grid ** 2 + 1, width]
357
+ x = x + self.positional_embedding.to(x.dtype)
358
+ x = self.ln_pre(x)
359
+
360
+ x = x.permute(1, 0, 2) # NLD -> LND
361
+ x = self.text_branch(x)
362
+ x = x.permute(1, 0, 2) # LND -> NLD
363
+
364
+ x = self.ln_post(x[:, 0, :])
365
+
366
+ if self.proj is not None:
367
+ x = x @ self.proj
368
+
369
+ return x
370
+
371
+
372
+ @dataclass
373
+ class CLAPVisionCfg:
374
+ layers: Union[Tuple[int, int, int, int], int] = 12
375
+ width: int = 768
376
+ patch_size: int = 16
377
+ image_size: Union[Tuple[int, int], int] = 224
378
+ timm_model_name: str = (
379
+ None # a valid model name overrides layers, width, patch_size
380
+ )
381
+ timm_model_pretrained: bool = (
382
+ False # use (imagenet) pretrained weights for named model
383
+ )
384
+ timm_pool: str = (
385
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
386
+ )
387
+ timm_proj: str = (
388
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
389
+ )
390
+
391
+
392
+ # Audio Config Class
393
+ @dataclass
394
+ class CLAPAudioCfp:
395
+ model_type: str = "PANN"
396
+ model_name: str = "Cnn14"
397
+ sample_rate: int = 48000
398
+ # Param
399
+ audio_length: int = 1024
400
+ window_size: int = 1024
401
+ hop_size: int = 1024
402
+ fmin: int = 50
403
+ fmax: int = 14000
404
+ class_num: int = 527
405
+ mel_bins: int = 64
406
+ clip_samples: int = 480000
407
+
408
+
409
+ @dataclass
410
+ class CLAPTextCfg:
411
+ context_length: int
412
+ vocab_size: int
413
+ width: int
414
+ heads: int
415
+ layers: int
416
+ model_type: str
417
+
418
+
419
+ class CLAP(nn.Module):
420
+ def __init__(
421
+ self,
422
+ embed_dim: int,
423
+ audio_cfg: CLAPAudioCfp,
424
+ text_cfg: CLAPTextCfg,
425
+ quick_gelu: bool = False,
426
+ enable_fusion: bool = False,
427
+ fusion_type: str = "None",
428
+ joint_embed_shape: int = 512,
429
+ mlp_act: str = "relu",
430
+ ):
431
+ super().__init__()
432
+ if isinstance(audio_cfg, dict):
433
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
434
+ if isinstance(text_cfg, dict):
435
+ text_cfg = CLAPTextCfg(**text_cfg)
436
+
437
+ self.audio_cfg = audio_cfg
438
+ self.text_cfg = text_cfg
439
+ self.enable_fusion = enable_fusion
440
+ self.fusion_type = fusion_type
441
+ self.joint_embed_shape = joint_embed_shape
442
+ self.mlp_act = mlp_act
443
+
444
+ self.context_length = text_cfg.context_length
445
+
446
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
447
+ # memory efficient in recent PyTorch releases (>= 1.10).
448
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
449
+ act_layer = QuickGELU if quick_gelu else nn.GELU
450
+
451
+ if mlp_act == "relu":
452
+ mlp_act_layer = nn.ReLU()
453
+ elif mlp_act == "gelu":
454
+ mlp_act_layer = nn.GELU()
455
+ else:
456
+ raise NotImplementedError
457
+
458
+ # audio branch
459
+ # audio branch parameters
460
+ if audio_cfg.model_type == "PANN":
461
+ self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
462
+ elif audio_cfg.model_type == "HTSAT":
463
+ self.audio_branch = create_htsat_model(
464
+ audio_cfg, enable_fusion, fusion_type
465
+ )
466
+ else:
467
+ logging.error(f"Model config for {audio_cfg.model_type} not found")
468
+ raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
469
+
470
+ # text branch
471
+ # text branch parameters
472
+ if text_cfg.model_type == "transformer":
473
+ self.text_branch = Transformer(
474
+ width=text_cfg.width,
475
+ layers=text_cfg.layers,
476
+ heads=text_cfg.heads,
477
+ act_layer=act_layer,
478
+ )
479
+ self.vocab_size = text_cfg.vocab_size
480
+ self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
481
+ self.positional_embedding = nn.Parameter(
482
+ torch.empty(self.context_length, text_cfg.width)
483
+ )
484
+ self.ln_final = LayerNorm(text_cfg.width)
485
+ self.text_transform = MLPLayers(
486
+ units=[
487
+ self.joint_embed_shape,
488
+ self.joint_embed_shape,
489
+ self.joint_embed_shape,
490
+ ],
491
+ dropout=0.1,
492
+ )
493
+ self.text_projection = nn.Sequential(
494
+ nn.Linear(text_cfg.width, self.joint_embed_shape),
495
+ mlp_act_layer,
496
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
497
+ )
498
+ elif text_cfg.model_type == "bert":
499
+ self.text_branch = BertModel.from_pretrained("bert-base-uncased")
500
+ self.text_transform = MLPLayers(
501
+ units=[
502
+ self.joint_embed_shape,
503
+ self.joint_embed_shape,
504
+ self.joint_embed_shape,
505
+ ],
506
+ dropout=0.1,
507
+ )
508
+ self.text_projection = nn.Sequential(
509
+ nn.Linear(768, self.joint_embed_shape),
510
+ mlp_act_layer,
511
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
512
+ )
513
+ elif text_cfg.model_type == "roberta":
514
+ self.text_branch = RobertaModel(
515
+ RobertaConfig.from_pretrained("roberta-base")
516
+ )
517
+ self.text_transform = MLPLayers(
518
+ units=[
519
+ self.joint_embed_shape,
520
+ self.joint_embed_shape,
521
+ self.joint_embed_shape,
522
+ ],
523
+ dropout=0.1,
524
+ )
525
+ self.text_projection = nn.Sequential(
526
+ nn.Linear(768, self.joint_embed_shape),
527
+ mlp_act_layer,
528
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
529
+ )
530
+ elif text_cfg.model_type == "bart":
531
+ self.text_branch = BartModel.from_pretrained("facebook/bart-base")
532
+ self.text_transform = MLPLayers(
533
+ units=[
534
+ self.joint_embed_shape,
535
+ self.joint_embed_shape,
536
+ self.joint_embed_shape,
537
+ ],
538
+ dropout=0.1,
539
+ )
540
+ self.text_projection = nn.Sequential(
541
+ nn.Linear(768, self.joint_embed_shape),
542
+ mlp_act_layer,
543
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
544
+ )
545
+ else:
546
+ logging.error(f"Model config for {text_cfg.model_type} not found")
547
+ raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
548
+ self.text_branch_type = text_cfg.model_type
549
+ # text branch parameters
550
+
551
+ # audio branch parameters
552
+ self.audio_transform = MLPLayers(
553
+ units=[
554
+ self.joint_embed_shape,
555
+ self.joint_embed_shape,
556
+ self.joint_embed_shape,
557
+ ],
558
+ dropout=0.1,
559
+ )
560
+
561
+ # below here is text branch parameters
562
+
563
+ # ============================================================================================================
564
+ self.audio_projection = nn.Sequential(
565
+ nn.Linear(embed_dim, self.joint_embed_shape),
566
+ mlp_act_layer,
567
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
568
+ )
569
+
570
+ self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
571
+ self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
572
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
573
+
574
+ self.init_text_branch_parameters()
575
+
576
+ def init_text_branch_parameters(self):
577
+ if self.text_branch_type == "transformer":
578
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
579
+ nn.init.normal_(self.positional_embedding, std=0.01)
580
+ proj_std = (self.text_branch.width**-0.5) * (
581
+ (2 * self.text_branch.layers) ** -0.5
582
+ )
583
+ attn_std = self.text_branch.width**-0.5
584
+ fc_std = (2 * self.text_branch.width) ** -0.5
585
+ for block in self.text_branch.resblocks:
586
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
587
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
588
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
589
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
590
+ if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
591
+ self.text_branch.embeddings.word_embeddings.weight.shape[-1]
592
+ elif self.text_branch_type == "bart":
593
+ self.text_branch.shared.weight.shape[-1]
594
+ else:
595
+ self.text_branch.width
596
+ nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
597
+ nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
598
+
599
+ # deprecated
600
+ # if hasattr(self.visual, 'init_parameters'):
601
+ # self.visual.init_parameters()
602
+
603
+ # if self.text_projection is not None:
604
+ # nn.init.normal_(self.text_projection, std=width**-0.5)
605
+
606
+ def build_attention_mask(self):
607
+ # lazily create causal attention mask, with full attention between the vision tokens
608
+ # pytorch uses additive attention mask; fill with -inf
609
+ mask = torch.empty(self.context_length, self.context_length)
610
+ mask.fill_(float("-inf"))
611
+ mask.triu_(1) # zero out the lower diagonal
612
+ return mask
613
+
614
+ def encode_audio(self, audio, device):
615
+ return self.audio_branch(
616
+ audio, mixup_lambda=None, device=device
617
+ ) # mix lambda needs to add
618
+
619
+ # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
620
+ # tmp = {}
621
+ # for k in x[0].keys():
622
+ # tmp[k] = []
623
+ # for i in range(len(x)):
624
+ # tmp[k].append(x[i][k][:77])
625
+ # for k in x[0].keys():
626
+ # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
627
+ # return tmp
628
+
629
+ def encode_text(self, text, device):
630
+ if self.text_branch_type == "transformer":
631
+ text = text.to(device=device, non_blocking=True)
632
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
633
+
634
+ x = x + self.positional_embedding
635
+ x = x.permute(1, 0, 2) # NLD -> LND
636
+ x = self.text_branch(x, attn_mask=self.attn_mask)
637
+ x = x.permute(1, 0, 2) # LND -> NLD
638
+ x = self.ln_final(x)
639
+
640
+ # x.shape = [batch_size, n_ctx, transformer.width]
641
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
642
+ x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
643
+ elif self.text_branch_type == "bert":
644
+ # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
645
+ # text = BatchEncoding(text)
646
+ x = self.text_branch(
647
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
648
+ attention_mask=text["attention_mask"].to(
649
+ device=device, non_blocking=True
650
+ ),
651
+ token_type_ids=text["token_type_ids"].to(
652
+ device=device, non_blocking=True
653
+ ),
654
+ )["pooler_output"]
655
+ x = self.text_projection(x)
656
+ elif self.text_branch_type == "roberta":
657
+ x = self.text_branch(
658
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
659
+ attention_mask=text["attention_mask"].to(
660
+ device=device, non_blocking=True
661
+ ),
662
+ )["pooler_output"]
663
+ x = self.text_projection(x)
664
+ elif self.text_branch_type == "bart":
665
+ x = torch.mean(
666
+ self.text_branch(
667
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
668
+ attention_mask=text["attention_mask"].to(
669
+ device=device, non_blocking=True
670
+ ),
671
+ )["encoder_last_hidden_state"],
672
+ axis=1,
673
+ )
674
+ x = self.text_projection(x)
675
+ else:
676
+ logging.error(f"Model type {self.text_branch_type} not found")
677
+ raise RuntimeError(f"Model type {self.text_branch_type} not found.")
678
+ return x
679
+
680
+ def forward(self, audio, text, device=None):
681
+ """Forward audio and text into the CLAP
682
+
683
+ Parameters
684
+ ----------
685
+ audio: torch.Tensor (batch_size, audio_length)
686
+ the time-domain audio input / the batch of mel_spec and longer list.
687
+ text: torch.Tensor () // need to add
688
+ the text token input
689
+ """
690
+ if device is None:
691
+ if audio is not None:
692
+ device = audio.device
693
+ elif text is not None:
694
+ device = text.device
695
+ if audio is None and text is None:
696
+ # a hack to get the logit scale
697
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
698
+ elif audio is None:
699
+ return self.encode_text(text, device=device)
700
+ elif text is None:
701
+ return self.audio_projection(
702
+ self.encode_audio(audio, device=device)["embedding"]
703
+ )
704
+ audio_features = self.audio_projection(
705
+ self.encode_audio(audio, device=device)["embedding"]
706
+ )
707
+ audio_features = F.normalize(audio_features, dim=-1)
708
+
709
+ text_features = self.encode_text(text, device=device)
710
+ # print("text_features", text_features)
711
+ # print("text_features.shape", text_features.shape)
712
+ # print("text_features.type", type(text_features))
713
+ text_features = F.normalize(text_features, dim=-1)
714
+
715
+ audio_features_mlp = self.audio_transform(audio_features)
716
+ text_features_mlp = self.text_transform(text_features)
717
+ # Four outputs: audio features (basic & MLP), text features (basic & MLP)
718
+ return (
719
+ audio_features,
720
+ text_features,
721
+ audio_features_mlp,
722
+ text_features_mlp,
723
+ self.logit_scale_a.exp(),
724
+ self.logit_scale_t.exp(),
725
+ )
726
+
727
+ def get_logit_scale(self):
728
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
729
+
730
+ def get_text_embedding(self, data):
731
+ """Get the text embedding from the model
732
+
733
+ Parameters
734
+ ----------
735
+ data: torch.Tensor
736
+ a tensor of text embedding
737
+
738
+ Returns
739
+ ----------
740
+ text_embed: torch.Tensor
741
+ a tensor of text_embeds (N, D)
742
+
743
+ """
744
+ device = next(self.parameters()).device
745
+ for k in data:
746
+ data[k] = data[k].to(device)
747
+ text_embeds = self.encode_text(data, device=device)
748
+ text_embeds = F.normalize(text_embeds, dim=-1)
749
+
750
+ return text_embeds
751
+
752
+ def get_audio_embedding(self, data):
753
+ """Get the audio embedding from the model
754
+
755
+ Parameters
756
+ ----------
757
+ data: a list of dict
758
+ the audio input dict list from 'get_audio_feature' method
759
+
760
+ Returns
761
+ ----------
762
+ audio_embed: torch.Tensor
763
+ a tensor of audio_embeds (N, D)
764
+
765
+ """
766
+ device = next(self.parameters()).device
767
+ # input_dict = {}
768
+ # keys = data[0].keys()
769
+ # for k in keys:
770
+ # input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
771
+ # device
772
+ # )
773
+ audio_embeds = self.audio_projection(
774
+ self.encode_audio(data, device=device)["embedding"]
775
+ )
776
+ audio_embeds = F.normalize(audio_embeds, dim=-1)
777
+
778
+ return audio_embeds
779
+
780
+ def audio_infer(self, audio, hopsize=None, device=None):
781
+ """Forward one audio and produce the audio embedding
782
+
783
+ Parameters
784
+ ----------
785
+ audio: (audio_length)
786
+ the time-domain audio input, notice that it must be only one input
787
+ hopsize: int
788
+ the overlap hopsize as the sliding window
789
+
790
+ Returns
791
+ ----------
792
+ output_dict: {
793
+ key: [n, (embedding_shape)] if "HTS-AT"
794
+ or
795
+ key: [(embedding_shape)] if "PANN"
796
+ }
797
+ the list of key values of the audio branch
798
+
799
+ """
800
+
801
+ assert not self.training, "the inference mode must be run at eval stage"
802
+ output_dict = {}
803
+ # PANN
804
+ if self.audio_cfg.model_type == "PANN":
805
+ audio_input = audio.unsqueeze(dim=0)
806
+ output_dict[key] = self.encode_audio(audio_input, device=device)[
807
+ key
808
+ ].squeeze(dim=0)
809
+ elif self.audio_cfg.model_type == "HTSAT":
810
+ # repeat
811
+ audio_len = len(audio)
812
+ k = self.audio_cfg.clip_samples // audio_len
813
+ if k > 1:
814
+ audio = audio.repeat(k)
815
+ audio_len = len(audio)
816
+
817
+ if hopsize is None:
818
+ hopsize = min(hopsize, audio_len)
819
+
820
+ if audio_len > self.audio_cfg.clip_samples:
821
+ audio_input = [
822
+ audio[pos : pos + self.audio_cfg.clip_samples].clone()
823
+ for pos in range(
824
+ 0, audio_len - self.audio_cfg.clip_samples, hopsize
825
+ )
826
+ ]
827
+ audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
828
+ audio_input = torch.stack(audio_input)
829
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key]
830
+ else:
831
+ audio_input = audio.unsqueeze(dim=0)
832
+ output_dict[key] = self.encode_audio(audio_input, device=device)[
833
+ key
834
+ ].squeeze(dim=0)
835
+
836
+ return output_dict
837
+
838
+
839
+ def convert_weights_to_fp16(model: nn.Module):
840
+ """Convert applicable model parameters to fp16"""
841
+
842
+ def _convert_weights_to_fp16(l):
843
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
844
+ l.weight.data = l.weight.data.half()
845
+ if l.bias is not None:
846
+ l.bias.data = l.bias.data.half()
847
+
848
+ if isinstance(l, nn.MultiheadAttention):
849
+ for attr in [
850
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
851
+ "in_proj_bias",
852
+ "bias_k",
853
+ "bias_v",
854
+ ]:
855
+ tensor = getattr(l, attr)
856
+ if tensor is not None:
857
+ tensor.data = tensor.data.half()
858
+
859
+ for name in ["text_projection", "proj"]:
860
+ if hasattr(l, name):
861
+ attr = getattr(l, name)
862
+ if attr is not None:
863
+ attr.data = attr.data.half()
864
+
865
+ model.apply(_convert_weights_to_fp16)
866
+
867
+
868
+ # Ignore the state dict of the vision part
869
+ def build_model_from_openai_state_dict(
870
+ state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
871
+ ):
872
+ embed_dim = model_cfg["embed_dim"]
873
+ audio_cfg = model_cfg["audio_cfg"]
874
+ text_cfg = model_cfg["text_cfg"]
875
+ state_dict["positional_embedding"].shape[0]
876
+ state_dict["token_embedding.weight"].shape[0]
877
+ transformer_width = state_dict["ln_final.weight"].shape[0]
878
+ transformer_width // 64
879
+ transformer_layers = len(
880
+ set(
881
+ k.split(".")[2]
882
+ for k in state_dict
883
+ if k.startswith(f"transformer.resblocks")
884
+ )
885
+ )
886
+
887
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
888
+ text_cfg = CLAPTextCfg(**text_cfg)
889
+
890
+ model = CLAP(
891
+ embed_dim,
892
+ audio_cfg=audio_cfg,
893
+ text_cfg=text_cfg,
894
+ quick_gelu=True, # OpenAI models were trained with QuickGELU
895
+ enable_fusion=enable_fusion,
896
+ fusion_type=fusion_type,
897
+ )
898
+ state_dict["logit_scale_a"] = state_dict["logit_scale"]
899
+ state_dict["logit_scale_t"] = state_dict["logit_scale"]
900
+ pop_keys = list(state_dict.keys())[::]
901
+ # pop the visual branch saved weights
902
+ for key in pop_keys:
903
+ if key.startswith("visual."):
904
+ state_dict.pop(key, None)
905
+
906
+ for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
907
+ state_dict.pop(key, None)
908
+
909
+ # not use fp16
910
+ # convert_weights_to_fp16(model)
911
+ model.load_state_dict(state_dict, strict=False)
912
+ return model.eval()
913
+
914
+
915
+ def trace_model(model, batch_size=256, device=torch.device("cpu")):
916
+ model.eval()
917
+ audio_length = model.audio_cfg.audio_length
918
+ example_audio = torch.ones((batch_size, audio_length), device=device)
919
+ example_text = torch.zeros(
920
+ (batch_size, model.context_length), dtype=torch.int, device=device
921
+ )
922
+ model = torch.jit.trace_module(
923
+ model,
924
+ inputs=dict(
925
+ forward=(example_audio, example_text),
926
+ encode_text=(example_text,),
927
+ encode_image=(example_audio,),
928
+ ),
929
+ )
930
+ model.audio_cfg.audio_length = audio_length # Question: what does this do?
931
+ return model
audioldm2/clap/open_clip/model_configs/HTSAT-base.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "base"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/HTSAT-large.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "large"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1536,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "tiny"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "tiny"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/PANN-10.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn10"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 18000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 960000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 360,
10
+ "fmin": 50,
11
+ "fmax": 8000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 4
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1536,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/PANN-14.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/PANN-6.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn6"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 23,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
audioldm2/clap/open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 23,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 6,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
audioldm2/clap/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 6,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
audioldm2/clap/open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 384,
5
+ "layers": [
6
+ 6,
7
+ 8,
8
+ 18,
9
+ 8
10
+ ],
11
+ "width": 96,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 768,
18
+ "heads": 12,
19
+ "layers": 12
20
+ }
21
+ }
audioldm2/clap/open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 288,
5
+ "layers": [
6
+ 4,
7
+ 6,
8
+ 10,
9
+ 6
10
+ ],
11
+ "width": 80,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 640,
18
+ "heads": 10,
19
+ "layers": 12
20
+ }
21
+ }
audioldm2/clap/open_clip/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": 12,
7
+ "width": 768,
8
+ "patch_size": 32
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ }
17
+ }
audioldm2/clap/open_clip/model_configs/ViT-B-32.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
audioldm2/clap/open_clip/model_configs/ViT-L-14.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
audioldm2/clap/open_clip/openai.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import Union, List
9
+
10
+ import torch
11
+
12
+ from .model import build_model_from_openai_state_dict
13
+ from .pretrained import (
14
+ get_pretrained_url,
15
+ list_pretrained_tag_models,
16
+ download_pretrained,
17
+ )
18
+
19
+ __all__ = ["list_openai_models", "load_openai_model"]
20
+
21
+
22
+ def list_openai_models() -> List[str]:
23
+ """Returns the names of available CLIP models"""
24
+ return list_pretrained_tag_models("openai")
25
+
26
+
27
+ def load_openai_model(
28
+ name: str,
29
+ model_cfg,
30
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
31
+ jit=True,
32
+ cache_dir=os.path.expanduser("~/.cache/clip"),
33
+ enable_fusion: bool = False,
34
+ fusion_type: str = "None",
35
+ ):
36
+ """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
37
+
38
+ Parameters
39
+ ----------
40
+ name : str
41
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
42
+ device : Union[str, torch.device]
43
+ The device to put the loaded model
44
+ jit : bool
45
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
46
+
47
+ Returns
48
+ -------
49
+ model : torch.nn.Module
50
+ The CLAP model
51
+ preprocess : Callable[[PIL.Image], torch.Tensor]
52
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
53
+ """
54
+ if get_pretrained_url(name, "openai"):
55
+ model_path = download_pretrained(
56
+ get_pretrained_url(name, "openai"), root=cache_dir
57
+ )
58
+ elif os.path.isfile(name):
59
+ model_path = name
60
+ else:
61
+ raise RuntimeError(
62
+ f"Model {name} not found; available models = {list_openai_models()}"
63
+ )
64
+
65
+ try:
66
+ # loading JIT archive
67
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
68
+ state_dict = None
69
+ except RuntimeError:
70
+ # loading saved state dict
71
+ if jit:
72
+ warnings.warn(
73
+ f"File {model_path} is not a JIT archive. Loading as a state dict instead"
74
+ )
75
+ jit = False
76
+ state_dict = torch.load(model_path, map_location="cpu")
77
+
78
+ if not jit:
79
+ try:
80
+ model = build_model_from_openai_state_dict(
81
+ state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
82
+ ).to(device)
83
+ except KeyError:
84
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
85
+ model = build_model_from_openai_state_dict(
86
+ sd, model_cfg, enable_fusion, fusion_type
87
+ ).to(device)
88
+
89
+ if str(device) == "cpu":
90
+ model.float()
91
+ return model
92
+
93
+ # patch the device names
94
+ device_holder = torch.jit.trace(
95
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
96
+ )
97
+ device_node = [
98
+ n
99
+ for n in device_holder.graph.findAllNodes("prim::Constant")
100
+ if "Device" in repr(n)
101
+ ][-1]
102
+
103
+ def patch_device(module):
104
+ try:
105
+ graphs = [module.graph] if hasattr(module, "graph") else []
106
+ except RuntimeError:
107
+ graphs = []
108
+
109
+ if hasattr(module, "forward1"):
110
+ graphs.append(module.forward1.graph)
111
+
112
+ for graph in graphs:
113
+ for node in graph.findAllNodes("prim::Constant"):
114
+ if "value" in node.attributeNames() and str(node["value"]).startswith(
115
+ "cuda"
116
+ ):
117
+ node.copyAttributes(device_node)
118
+
119
+ model.apply(patch_device)
120
+ patch_device(model.encode_audio)
121
+ patch_device(model.encode_text)
122
+
123
+ # patch dtype to float32 on CPU
124
+ if str(device) == "cpu":
125
+ float_holder = torch.jit.trace(
126
+ lambda: torch.ones([]).float(), example_inputs=[]
127
+ )
128
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
129
+ float_node = float_input.node()
130
+
131
+ def patch_float(module):
132
+ try:
133
+ graphs = [module.graph] if hasattr(module, "graph") else []
134
+ except RuntimeError:
135
+ graphs = []
136
+
137
+ if hasattr(module, "forward1"):
138
+ graphs.append(module.forward1.graph)
139
+
140
+ for graph in graphs:
141
+ for node in graph.findAllNodes("aten::to"):
142
+ inputs = list(node.inputs())
143
+ for i in [
144
+ 1,
145
+ 2,
146
+ ]: # dtype can be the second or third argument to aten::to()
147
+ if inputs[i].node()["value"] == 5:
148
+ inputs[i].node().copyAttributes(float_node)
149
+
150
+ model.apply(patch_float)
151
+ patch_float(model.encode_audio)
152
+ patch_float(model.encode_text)
153
+ model.float()
154
+
155
+ model.audio_branch.audio_length = model.audio_cfg.audio_length
156
+ return model
audioldm2/clap/open_clip/pann_model.py ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
2
+ # Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
3
+ # Some layers are re-designed for CLAP
4
+ import os
5
+
6
+ os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
12
+ from torchlibrosa.augmentation import SpecAugmentation
13
+
14
+ from .utils import do_mixup, interpolate
15
+ from .feature_fusion import iAFF, AFF, DAF
16
+
17
+
18
+ def init_layer(layer):
19
+ """Initialize a Linear or Convolutional layer."""
20
+ nn.init.xavier_uniform_(layer.weight)
21
+
22
+ if hasattr(layer, "bias"):
23
+ if layer.bias is not None:
24
+ layer.bias.data.fill_(0.0)
25
+
26
+
27
+ def init_bn(bn):
28
+ """Initialize a Batchnorm layer."""
29
+ bn.bias.data.fill_(0.0)
30
+ bn.weight.data.fill_(1.0)
31
+
32
+
33
+ class ConvBlock(nn.Module):
34
+ def __init__(self, in_channels, out_channels):
35
+ super(ConvBlock, self).__init__()
36
+
37
+ self.conv1 = nn.Conv2d(
38
+ in_channels=in_channels,
39
+ out_channels=out_channels,
40
+ kernel_size=(3, 3),
41
+ stride=(1, 1),
42
+ padding=(1, 1),
43
+ bias=False,
44
+ )
45
+
46
+ self.conv2 = nn.Conv2d(
47
+ in_channels=out_channels,
48
+ out_channels=out_channels,
49
+ kernel_size=(3, 3),
50
+ stride=(1, 1),
51
+ padding=(1, 1),
52
+ bias=False,
53
+ )
54
+
55
+ self.bn1 = nn.BatchNorm2d(out_channels)
56
+ self.bn2 = nn.BatchNorm2d(out_channels)
57
+
58
+ self.init_weight()
59
+
60
+ def init_weight(self):
61
+ init_layer(self.conv1)
62
+ init_layer(self.conv2)
63
+ init_bn(self.bn1)
64
+ init_bn(self.bn2)
65
+
66
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
67
+ x = input
68
+ x = F.relu_(self.bn1(self.conv1(x)))
69
+ x = F.relu_(self.bn2(self.conv2(x)))
70
+ if pool_type == "max":
71
+ x = F.max_pool2d(x, kernel_size=pool_size)
72
+ elif pool_type == "avg":
73
+ x = F.avg_pool2d(x, kernel_size=pool_size)
74
+ elif pool_type == "avg+max":
75
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
76
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
77
+ x = x1 + x2
78
+ else:
79
+ raise Exception("Incorrect argument!")
80
+
81
+ return x
82
+
83
+
84
+ class ConvBlock5x5(nn.Module):
85
+ def __init__(self, in_channels, out_channels):
86
+ super(ConvBlock5x5, self).__init__()
87
+
88
+ self.conv1 = nn.Conv2d(
89
+ in_channels=in_channels,
90
+ out_channels=out_channels,
91
+ kernel_size=(5, 5),
92
+ stride=(1, 1),
93
+ padding=(2, 2),
94
+ bias=False,
95
+ )
96
+
97
+ self.bn1 = nn.BatchNorm2d(out_channels)
98
+
99
+ self.init_weight()
100
+
101
+ def init_weight(self):
102
+ init_layer(self.conv1)
103
+ init_bn(self.bn1)
104
+
105
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
106
+ x = input
107
+ x = F.relu_(self.bn1(self.conv1(x)))
108
+ if pool_type == "max":
109
+ x = F.max_pool2d(x, kernel_size=pool_size)
110
+ elif pool_type == "avg":
111
+ x = F.avg_pool2d(x, kernel_size=pool_size)
112
+ elif pool_type == "avg+max":
113
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
114
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
115
+ x = x1 + x2
116
+ else:
117
+ raise Exception("Incorrect argument!")
118
+
119
+ return x
120
+
121
+
122
+ class AttBlock(nn.Module):
123
+ def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
124
+ super(AttBlock, self).__init__()
125
+
126
+ self.activation = activation
127
+ self.temperature = temperature
128
+ self.att = nn.Conv1d(
129
+ in_channels=n_in,
130
+ out_channels=n_out,
131
+ kernel_size=1,
132
+ stride=1,
133
+ padding=0,
134
+ bias=True,
135
+ )
136
+ self.cla = nn.Conv1d(
137
+ in_channels=n_in,
138
+ out_channels=n_out,
139
+ kernel_size=1,
140
+ stride=1,
141
+ padding=0,
142
+ bias=True,
143
+ )
144
+
145
+ self.bn_att = nn.BatchNorm1d(n_out)
146
+ self.init_weights()
147
+
148
+ def init_weights(self):
149
+ init_layer(self.att)
150
+ init_layer(self.cla)
151
+ init_bn(self.bn_att)
152
+
153
+ def forward(self, x):
154
+ # x: (n_samples, n_in, n_time)
155
+ norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
156
+ cla = self.nonlinear_transform(self.cla(x))
157
+ x = torch.sum(norm_att * cla, dim=2)
158
+ return x, norm_att, cla
159
+
160
+ def nonlinear_transform(self, x):
161
+ if self.activation == "linear":
162
+ return x
163
+ elif self.activation == "sigmoid":
164
+ return torch.sigmoid(x)
165
+
166
+
167
+ class Cnn14(nn.Module):
168
+ def __init__(
169
+ self,
170
+ sample_rate,
171
+ window_size,
172
+ hop_size,
173
+ mel_bins,
174
+ fmin,
175
+ fmax,
176
+ classes_num,
177
+ enable_fusion=False,
178
+ fusion_type="None",
179
+ ):
180
+ super(Cnn14, self).__init__()
181
+
182
+ window = "hann"
183
+ center = True
184
+ pad_mode = "reflect"
185
+ ref = 1.0
186
+ amin = 1e-10
187
+ top_db = None
188
+
189
+ self.enable_fusion = enable_fusion
190
+ self.fusion_type = fusion_type
191
+
192
+ # Spectrogram extractor
193
+ self.spectrogram_extractor = Spectrogram(
194
+ n_fft=window_size,
195
+ hop_length=hop_size,
196
+ win_length=window_size,
197
+ window=window,
198
+ center=center,
199
+ pad_mode=pad_mode,
200
+ freeze_parameters=True,
201
+ )
202
+
203
+ # Logmel feature extractor
204
+ self.logmel_extractor = LogmelFilterBank(
205
+ sr=sample_rate,
206
+ n_fft=window_size,
207
+ n_mels=mel_bins,
208
+ fmin=fmin,
209
+ fmax=fmax,
210
+ ref=ref,
211
+ amin=amin,
212
+ top_db=top_db,
213
+ freeze_parameters=True,
214
+ )
215
+
216
+ # Spec augmenter
217
+ self.spec_augmenter = SpecAugmentation(
218
+ time_drop_width=64,
219
+ time_stripes_num=2,
220
+ freq_drop_width=8,
221
+ freq_stripes_num=2,
222
+ )
223
+
224
+ self.bn0 = nn.BatchNorm2d(64)
225
+
226
+ if (self.enable_fusion) and (self.fusion_type == "channel_map"):
227
+ self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
228
+ else:
229
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
230
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
231
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
232
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
233
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
234
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
235
+
236
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
237
+ self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
238
+
239
+ if (self.enable_fusion) and (
240
+ self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
241
+ ):
242
+ self.mel_conv1d = nn.Sequential(
243
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
244
+ nn.BatchNorm1d(64), # No Relu
245
+ )
246
+ if self.fusion_type == "daf_1d":
247
+ self.fusion_model = DAF()
248
+ elif self.fusion_type == "aff_1d":
249
+ self.fusion_model = AFF(channels=64, type="1D")
250
+ elif self.fusion_type == "iaff_1d":
251
+ self.fusion_model = iAFF(channels=64, type="1D")
252
+
253
+ if (self.enable_fusion) and (
254
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
255
+ ):
256
+ self.mel_conv2d = nn.Sequential(
257
+ nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
258
+ nn.BatchNorm2d(64),
259
+ nn.ReLU(inplace=True),
260
+ )
261
+
262
+ if self.fusion_type == "daf_2d":
263
+ self.fusion_model = DAF()
264
+ elif self.fusion_type == "aff_2d":
265
+ self.fusion_model = AFF(channels=64, type="2D")
266
+ elif self.fusion_type == "iaff_2d":
267
+ self.fusion_model = iAFF(channels=64, type="2D")
268
+ self.init_weight()
269
+
270
+ def init_weight(self):
271
+ init_bn(self.bn0)
272
+ init_layer(self.fc1)
273
+ init_layer(self.fc_audioset)
274
+
275
+ def forward(self, input, mixup_lambda=None, device=None):
276
+ """
277
+ Input: (batch_size, data_length)"""
278
+
279
+ if self.enable_fusion and input["longer"].sum() == 0:
280
+ # if no audio is longer than 10s, then randomly select one audio to be longer
281
+ input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
282
+
283
+ if not self.enable_fusion:
284
+ x = self.spectrogram_extractor(
285
+ input["waveform"].to(device=device, non_blocking=True)
286
+ ) # (batch_size, 1, time_steps, freq_bins)
287
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
288
+
289
+ x = x.transpose(1, 3)
290
+ x = self.bn0(x)
291
+ x = x.transpose(1, 3)
292
+ else:
293
+ longer_list = input["longer"].to(device=device, non_blocking=True)
294
+ x = input["mel_fusion"].to(device=device, non_blocking=True)
295
+ longer_list_idx = torch.where(longer_list)[0]
296
+ x = x.transpose(1, 3)
297
+ x = self.bn0(x)
298
+ x = x.transpose(1, 3)
299
+ if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
300
+ new_x = x[:, 0:1, :, :].clone().contiguous()
301
+ # local processing
302
+ if len(longer_list_idx) > 0:
303
+ fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
304
+ FB, FC, FT, FF = fusion_x_local.size()
305
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
306
+ fusion_x_local = torch.permute(
307
+ fusion_x_local, (0, 2, 1)
308
+ ).contiguous()
309
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
310
+ fusion_x_local = fusion_x_local.view(
311
+ FB, FC, FF, fusion_x_local.size(-1)
312
+ )
313
+ fusion_x_local = (
314
+ torch.permute(fusion_x_local, (0, 2, 1, 3))
315
+ .contiguous()
316
+ .flatten(2)
317
+ )
318
+ if fusion_x_local.size(-1) < FT:
319
+ fusion_x_local = torch.cat(
320
+ [
321
+ fusion_x_local,
322
+ torch.zeros(
323
+ (FB, FF, FT - fusion_x_local.size(-1)),
324
+ device=device,
325
+ ),
326
+ ],
327
+ dim=-1,
328
+ )
329
+ else:
330
+ fusion_x_local = fusion_x_local[:, :, :FT]
331
+ # 1D fusion
332
+ new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
333
+ new_x[longer_list_idx] = self.fusion_model(
334
+ new_x[longer_list_idx], fusion_x_local
335
+ )
336
+ x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
337
+ else:
338
+ x = new_x
339
+ elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
340
+ x = x # no change
341
+
342
+ if self.training:
343
+ x = self.spec_augmenter(x)
344
+ # Mixup on spectrogram
345
+ if self.training and mixup_lambda is not None:
346
+ x = do_mixup(x, mixup_lambda)
347
+ if (self.enable_fusion) and (
348
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
349
+ ):
350
+ global_x = x[:, 0:1, :, :]
351
+
352
+ # global processing
353
+ B, C, H, W = global_x.shape
354
+ global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
355
+ if len(longer_list_idx) > 0:
356
+ local_x = x[longer_list_idx, 1:, :, :].contiguous()
357
+ TH = global_x.size(-2)
358
+ # local processing
359
+ B, C, H, W = local_x.shape
360
+ local_x = local_x.view(B * C, 1, H, W)
361
+ local_x = self.mel_conv2d(local_x)
362
+ local_x = local_x.view(
363
+ B, C, local_x.size(1), local_x.size(2), local_x.size(3)
364
+ )
365
+ local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
366
+ TB, TC, _, TW = local_x.size()
367
+ if local_x.size(-2) < TH:
368
+ local_x = torch.cat(
369
+ [
370
+ local_x,
371
+ torch.zeros(
372
+ (TB, TC, TH - local_x.size(-2), TW),
373
+ device=global_x.device,
374
+ ),
375
+ ],
376
+ dim=-2,
377
+ )
378
+ else:
379
+ local_x = local_x[:, :, :TH, :]
380
+
381
+ global_x[longer_list_idx] = self.fusion_model(
382
+ global_x[longer_list_idx], local_x
383
+ )
384
+ x = global_x
385
+ else:
386
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
387
+
388
+ x = F.dropout(x, p=0.2, training=self.training)
389
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
390
+ x = F.dropout(x, p=0.2, training=self.training)
391
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
392
+ x = F.dropout(x, p=0.2, training=self.training)
393
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
394
+ x = F.dropout(x, p=0.2, training=self.training)
395
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
396
+ x = F.dropout(x, p=0.2, training=self.training)
397
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
398
+ x = F.dropout(x, p=0.2, training=self.training)
399
+ x = torch.mean(x, dim=3)
400
+
401
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
402
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
403
+ latent_x = latent_x1 + latent_x2
404
+ latent_x = latent_x.transpose(1, 2)
405
+ latent_x = F.relu_(self.fc1(latent_x))
406
+ latent_output = interpolate(latent_x, 32)
407
+
408
+ (x1, _) = torch.max(x, dim=2)
409
+ x2 = torch.mean(x, dim=2)
410
+ x = x1 + x2
411
+ x = F.dropout(x, p=0.5, training=self.training)
412
+ x = F.relu_(self.fc1(x))
413
+ embedding = F.dropout(x, p=0.5, training=self.training)
414
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
415
+
416
+ output_dict = {
417
+ "clipwise_output": clipwise_output,
418
+ "embedding": embedding,
419
+ "fine_grained_embedding": latent_output,
420
+ }
421
+ return output_dict
422
+
423
+
424
+ class Cnn6(nn.Module):
425
+ def __init__(
426
+ self,
427
+ sample_rate,
428
+ window_size,
429
+ hop_size,
430
+ mel_bins,
431
+ fmin,
432
+ fmax,
433
+ classes_num,
434
+ enable_fusion=False,
435
+ fusion_type="None",
436
+ ):
437
+ super(Cnn6, self).__init__()
438
+
439
+ window = "hann"
440
+ center = True
441
+ pad_mode = "reflect"
442
+ ref = 1.0
443
+ amin = 1e-10
444
+ top_db = None
445
+
446
+ self.enable_fusion = enable_fusion
447
+ self.fusion_type = fusion_type
448
+
449
+ # Spectrogram extractor
450
+ self.spectrogram_extractor = Spectrogram(
451
+ n_fft=window_size,
452
+ hop_length=hop_size,
453
+ win_length=window_size,
454
+ window=window,
455
+ center=center,
456
+ pad_mode=pad_mode,
457
+ freeze_parameters=True,
458
+ )
459
+
460
+ # Logmel feature extractor
461
+ self.logmel_extractor = LogmelFilterBank(
462
+ sr=sample_rate,
463
+ n_fft=window_size,
464
+ n_mels=mel_bins,
465
+ fmin=fmin,
466
+ fmax=fmax,
467
+ ref=ref,
468
+ amin=amin,
469
+ top_db=top_db,
470
+ freeze_parameters=True,
471
+ )
472
+
473
+ # Spec augmenter
474
+ self.spec_augmenter = SpecAugmentation(
475
+ time_drop_width=64,
476
+ time_stripes_num=2,
477
+ freq_drop_width=8,
478
+ freq_stripes_num=2,
479
+ )
480
+
481
+ self.bn0 = nn.BatchNorm2d(64)
482
+
483
+ self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
484
+ self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
485
+ self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
486
+ self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
487
+
488
+ self.fc1 = nn.Linear(512, 512, bias=True)
489
+ self.fc_audioset = nn.Linear(512, classes_num, bias=True)
490
+
491
+ self.init_weight()
492
+
493
+ def init_weight(self):
494
+ init_bn(self.bn0)
495
+ init_layer(self.fc1)
496
+ init_layer(self.fc_audioset)
497
+
498
+ def forward(self, input, mixup_lambda=None, device=None):
499
+ """
500
+ Input: (batch_size, data_length)"""
501
+
502
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
503
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
504
+
505
+ x = x.transpose(1, 3)
506
+ x = self.bn0(x)
507
+ x = x.transpose(1, 3)
508
+
509
+ if self.training:
510
+ x = self.spec_augmenter(x)
511
+
512
+ # Mixup on spectrogram
513
+ if self.training and mixup_lambda is not None:
514
+ x = do_mixup(x, mixup_lambda)
515
+
516
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
517
+ x = F.dropout(x, p=0.2, training=self.training)
518
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
519
+ x = F.dropout(x, p=0.2, training=self.training)
520
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
521
+ x = F.dropout(x, p=0.2, training=self.training)
522
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
523
+ x = F.dropout(x, p=0.2, training=self.training)
524
+ x = torch.mean(x, dim=3)
525
+
526
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
527
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
528
+ latent_x = latent_x1 + latent_x2
529
+ latent_x = latent_x.transpose(1, 2)
530
+ latent_x = F.relu_(self.fc1(latent_x))
531
+ latent_output = interpolate(latent_x, 16)
532
+
533
+ (x1, _) = torch.max(x, dim=2)
534
+ x2 = torch.mean(x, dim=2)
535
+ x = x1 + x2
536
+ x = F.dropout(x, p=0.5, training=self.training)
537
+ x = F.relu_(self.fc1(x))
538
+ embedding = F.dropout(x, p=0.5, training=self.training)
539
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
540
+
541
+ output_dict = {
542
+ "clipwise_output": clipwise_output,
543
+ "embedding": embedding,
544
+ "fine_grained_embedding": latent_output,
545
+ }
546
+
547
+ return output_dict
548
+
549
+
550
+ class Cnn10(nn.Module):
551
+ def __init__(
552
+ self,
553
+ sample_rate,
554
+ window_size,
555
+ hop_size,
556
+ mel_bins,
557
+ fmin,
558
+ fmax,
559
+ classes_num,
560
+ enable_fusion=False,
561
+ fusion_type="None",
562
+ ):
563
+ super(Cnn10, self).__init__()
564
+
565
+ window = "hann"
566
+ center = True
567
+ pad_mode = "reflect"
568
+ ref = 1.0
569
+ amin = 1e-10
570
+ top_db = None
571
+
572
+ self.enable_fusion = enable_fusion
573
+ self.fusion_type = fusion_type
574
+
575
+ # Spectrogram extractor
576
+ self.spectrogram_extractor = Spectrogram(
577
+ n_fft=window_size,
578
+ hop_length=hop_size,
579
+ win_length=window_size,
580
+ window=window,
581
+ center=center,
582
+ pad_mode=pad_mode,
583
+ freeze_parameters=True,
584
+ )
585
+
586
+ # Logmel feature extractor
587
+ self.logmel_extractor = LogmelFilterBank(
588
+ sr=sample_rate,
589
+ n_fft=window_size,
590
+ n_mels=mel_bins,
591
+ fmin=fmin,
592
+ fmax=fmax,
593
+ ref=ref,
594
+ amin=amin,
595
+ top_db=top_db,
596
+ freeze_parameters=True,
597
+ )
598
+
599
+ # Spec augmenter
600
+ self.spec_augmenter = SpecAugmentation(
601
+ time_drop_width=64,
602
+ time_stripes_num=2,
603
+ freq_drop_width=8,
604
+ freq_stripes_num=2,
605
+ )
606
+
607
+ self.bn0 = nn.BatchNorm2d(64)
608
+
609
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
610
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
611
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
612
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
613
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
614
+
615
+ self.fc1 = nn.Linear(1024, 1024, bias=True)
616
+ self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
617
+
618
+ self.init_weight()
619
+
620
+ def init_weight(self):
621
+ init_bn(self.bn0)
622
+ init_layer(self.fc1)
623
+ init_layer(self.fc_audioset)
624
+
625
+ def forward(self, input, mixup_lambda=None, device=None):
626
+ """
627
+ Input: (batch_size, data_length)"""
628
+
629
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
630
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
631
+
632
+ x = x.transpose(1, 3)
633
+ x = self.bn0(x)
634
+ x = x.transpose(1, 3)
635
+
636
+ if self.training:
637
+ x = self.spec_augmenter(x)
638
+
639
+ # Mixup on spectrogram
640
+ if self.training and mixup_lambda is not None:
641
+ x = do_mixup(x, mixup_lambda)
642
+
643
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
644
+ x = F.dropout(x, p=0.2, training=self.training)
645
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
646
+ x = F.dropout(x, p=0.2, training=self.training)
647
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
648
+ x = F.dropout(x, p=0.2, training=self.training)
649
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
650
+ x = F.dropout(x, p=0.2, training=self.training)
651
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
652
+ x = F.dropout(x, p=0.2, training=self.training)
653
+ x = torch.mean(x, dim=3)
654
+
655
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
656
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
657
+ latent_x = latent_x1 + latent_x2
658
+ latent_x = latent_x.transpose(1, 2)
659
+ latent_x = F.relu_(self.fc1(latent_x))
660
+ latent_output = interpolate(latent_x, 32)
661
+
662
+ (x1, _) = torch.max(x, dim=2)
663
+ x2 = torch.mean(x, dim=2)
664
+ x = x1 + x2
665
+ x = F.dropout(x, p=0.5, training=self.training)
666
+ x = F.relu_(self.fc1(x))
667
+ embedding = F.dropout(x, p=0.5, training=self.training)
668
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
669
+
670
+ output_dict = {
671
+ "clipwise_output": clipwise_output,
672
+ "embedding": embedding,
673
+ "fine_grained_embedding": latent_output,
674
+ }
675
+
676
+ return output_dict
677
+
678
+
679
+ def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
680
+ try:
681
+ ModelProto = eval(audio_cfg.model_name)
682
+ model = ModelProto(
683
+ sample_rate=audio_cfg.sample_rate,
684
+ window_size=audio_cfg.window_size,
685
+ hop_size=audio_cfg.hop_size,
686
+ mel_bins=audio_cfg.mel_bins,
687
+ fmin=audio_cfg.fmin,
688
+ fmax=audio_cfg.fmax,
689
+ classes_num=audio_cfg.class_num,
690
+ enable_fusion=enable_fusion,
691
+ fusion_type=fusion_type,
692
+ )
693
+ return model
694
+ except:
695
+ raise RuntimeError(
696
+ f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
697
+ )
audioldm2/clap/open_clip/pretrained.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+
6
+ from tqdm import tqdm
7
+
8
+ _RN50 = dict(
9
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
10
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
11
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
12
+ )
13
+
14
+ _RN50_quickgelu = dict(
15
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
16
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
17
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
18
+ )
19
+
20
+ _RN101 = dict(
21
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
23
+ )
24
+
25
+ _RN101_quickgelu = dict(
26
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
27
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
28
+ )
29
+
30
+ _RN50x4 = dict(
31
+ openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32
+ )
33
+
34
+ _RN50x16 = dict(
35
+ openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36
+ )
37
+
38
+ _RN50x64 = dict(
39
+ openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40
+ )
41
+
42
+ _VITB32 = dict(
43
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
44
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
45
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
46
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
47
+ )
48
+
49
+ _VITB32_quickgelu = dict(
50
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
51
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
52
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
53
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
54
+ )
55
+
56
+ _VITB16 = dict(
57
+ openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
58
+ )
59
+
60
+ _VITL14 = dict(
61
+ openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
62
+ )
63
+
64
+ _PRETRAINED = {
65
+ "RN50": _RN50,
66
+ "RN50-quickgelu": _RN50_quickgelu,
67
+ "RN101": _RN101,
68
+ "RN101-quickgelu": _RN101_quickgelu,
69
+ "RN50x4": _RN50x4,
70
+ "RN50x16": _RN50x16,
71
+ "ViT-B-32": _VITB32,
72
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
73
+ "ViT-B-16": _VITB16,
74
+ "ViT-L-14": _VITL14,
75
+ }
76
+
77
+
78
+ def list_pretrained(as_str: bool = False):
79
+ """returns list of pretrained models
80
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
81
+ """
82
+ return [
83
+ ":".join([k, t]) if as_str else (k, t)
84
+ for k in _PRETRAINED.keys()
85
+ for t in _PRETRAINED[k].keys()
86
+ ]
87
+
88
+
89
+ def list_pretrained_tag_models(tag: str):
90
+ """return all models having the specified pretrain tag"""
91
+ models = []
92
+ for k in _PRETRAINED.keys():
93
+ if tag in _PRETRAINED[k]:
94
+ models.append(k)
95
+ return models
96
+
97
+
98
+ def list_pretrained_model_tags(model: str):
99
+ """return all pretrain tags for the specified model architecture"""
100
+ tags = []
101
+ if model in _PRETRAINED:
102
+ tags.extend(_PRETRAINED[model].keys())
103
+ return tags
104
+
105
+
106
+ def get_pretrained_url(model: str, tag: str):
107
+ if model not in _PRETRAINED:
108
+ return ""
109
+ model_pretrained = _PRETRAINED[model]
110
+ if tag not in model_pretrained:
111
+ return ""
112
+ return model_pretrained[tag]
113
+
114
+
115
+ def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
116
+ os.makedirs(root, exist_ok=True)
117
+ filename = os.path.basename(url)
118
+
119
+ if "openaipublic" in url:
120
+ expected_sha256 = url.split("/")[-2]
121
+ else:
122
+ expected_sha256 = ""
123
+
124
+ download_target = os.path.join(root, filename)
125
+
126
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
127
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
128
+
129
+ if os.path.isfile(download_target):
130
+ if expected_sha256:
131
+ if (
132
+ hashlib.sha256(open(download_target, "rb").read()).hexdigest()
133
+ == expected_sha256
134
+ ):
135
+ return download_target
136
+ else:
137
+ warnings.warn(
138
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
139
+ )
140
+ else:
141
+ return download_target
142
+
143
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
144
+ with tqdm(
145
+ total=int(source.info().get("Content-Length")),
146
+ ncols=80,
147
+ unit="iB",
148
+ unit_scale=True,
149
+ ) as loop:
150
+ while True:
151
+ buffer = source.read(8192)
152
+ if not buffer:
153
+ break
154
+
155
+ output.write(buffer)
156
+ loop.update(len(buffer))
157
+
158
+ if (
159
+ expected_sha256
160
+ and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
161
+ != expected_sha256
162
+ ):
163
+ raise RuntimeError(
164
+ f"Model has been downloaded but the SHA256 checksum does not not match"
165
+ )
166
+
167
+ return download_target
audioldm2/clap/open_clip/timm_model.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
+ """
5
+ from collections import OrderedDict
6
+
7
+ import torch.nn as nn
8
+
9
+ try:
10
+ import timm
11
+ from timm.models.layers import Mlp, to_2tuple
12
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
13
+ from timm.models.layers.attention_pool2d import (
14
+ AttentionPool2d as AbsAttentionPool2d,
15
+ )
16
+ except ImportError:
17
+ timm = None
18
+
19
+ from .utils import freeze_batch_norm_2d
20
+
21
+
22
+ class TimmModel(nn.Module):
23
+ """timm model adapter
24
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ model_name,
30
+ embed_dim,
31
+ image_size=224,
32
+ pool="avg",
33
+ proj="linear",
34
+ drop=0.0,
35
+ pretrained=False,
36
+ ):
37
+ super().__init__()
38
+ if timm is None:
39
+ raise RuntimeError("Please `pip install timm` to use timm models.")
40
+
41
+ self.image_size = to_2tuple(image_size)
42
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
43
+ feat_size = self.trunk.default_cfg.get("pool_size", None)
44
+ feature_ndim = 1 if not feat_size else 2
45
+ if pool in ("abs_attn", "rot_attn"):
46
+ assert feature_ndim == 2
47
+ # if attn pooling used, remove both classifier and default pool
48
+ self.trunk.reset_classifier(0, global_pool="")
49
+ else:
50
+ # reset global pool if pool config set, otherwise leave as network default
51
+ reset_kwargs = dict(global_pool=pool) if pool else {}
52
+ self.trunk.reset_classifier(0, **reset_kwargs)
53
+ prev_chs = self.trunk.num_features
54
+
55
+ head_layers = OrderedDict()
56
+ if pool == "abs_attn":
57
+ head_layers["pool"] = AbsAttentionPool2d(
58
+ prev_chs, feat_size=feat_size, out_features=embed_dim
59
+ )
60
+ prev_chs = embed_dim
61
+ elif pool == "rot_attn":
62
+ head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
63
+ prev_chs = embed_dim
64
+ else:
65
+ assert proj, "projection layer needed if non-attention pooling is used."
66
+
67
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
68
+ if proj == "linear":
69
+ head_layers["drop"] = nn.Dropout(drop)
70
+ head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
71
+ elif proj == "mlp":
72
+ head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
73
+
74
+ self.head = nn.Sequential(head_layers)
75
+
76
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
77
+ """lock modules
78
+ Args:
79
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
80
+ """
81
+ if not unlocked_groups:
82
+ # lock full model
83
+ for param in self.trunk.parameters():
84
+ param.requires_grad = False
85
+ if freeze_bn_stats:
86
+ freeze_batch_norm_2d(self.trunk)
87
+ else:
88
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
89
+ try:
90
+ # FIXME import here until API stable and in an official release
91
+ from timm.models.helpers import group_parameters, group_modules
92
+ except ImportError:
93
+ raise RuntimeError(
94
+ "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
95
+ )
96
+ matcher = self.trunk.group_matcher()
97
+ gparams = group_parameters(self.trunk, matcher)
98
+ max_layer_id = max(gparams.keys())
99
+ max_layer_id = max_layer_id - unlocked_groups
100
+ for group_idx in range(max_layer_id + 1):
101
+ group = gparams[group_idx]
102
+ for param in group:
103
+ self.trunk.get_parameter(param).requires_grad = False
104
+ if freeze_bn_stats:
105
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
106
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
107
+ freeze_batch_norm_2d(self.trunk, gmodules)
108
+
109
+ def forward(self, x):
110
+ x = self.trunk(x)
111
+ x = self.head(x)
112
+ return x
audioldm2/clap/open_clip/tokenizer.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ from functools import lru_cache
9
+ from typing import Union, List
10
+
11
+ import ftfy
12
+ import regex as re
13
+ import torch
14
+
15
+
16
+ @lru_cache()
17
+ def default_bpe():
18
+ return os.path.join(
19
+ os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
20
+ )
21
+
22
+
23
+ @lru_cache()
24
+ def bytes_to_unicode():
25
+ """
26
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
27
+ The reversible bpe codes work on unicode strings.
28
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
29
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
30
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
31
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
32
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
33
+ """
34
+ bs = (
35
+ list(range(ord("!"), ord("~") + 1))
36
+ + list(range(ord("¡"), ord("¬") + 1))
37
+ + list(range(ord("®"), ord("ÿ") + 1))
38
+ )
39
+ cs = bs[:]
40
+ n = 0
41
+ for b in range(2**8):
42
+ if b not in bs:
43
+ bs.append(b)
44
+ cs.append(2**8 + n)
45
+ n += 1
46
+ cs = [chr(n) for n in cs]
47
+ return dict(zip(bs, cs))
48
+
49
+
50
+ def get_pairs(word):
51
+ """Return set of symbol pairs in a word.
52
+ Word is represented as tuple of symbols (symbols being variable-length strings).
53
+ """
54
+ pairs = set()
55
+ prev_char = word[0]
56
+ for char in word[1:]:
57
+ pairs.add((prev_char, char))
58
+ prev_char = char
59
+ return pairs
60
+
61
+
62
+ def basic_clean(text):
63
+ text = ftfy.fix_text(text)
64
+ text = html.unescape(html.unescape(text))
65
+ return text.strip()
66
+
67
+
68
+ def whitespace_clean(text):
69
+ text = re.sub(r"\s+", " ", text)
70
+ text = text.strip()
71
+ return text
72
+
73
+
74
+ class SimpleTokenizer(object):
75
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
76
+ self.byte_encoder = bytes_to_unicode()
77
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
78
+ merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
79
+ merges = merges[1 : 49152 - 256 - 2 + 1]
80
+ merges = [tuple(merge.split()) for merge in merges]
81
+ vocab = list(bytes_to_unicode().values())
82
+ vocab = vocab + [v + "</w>" for v in vocab]
83
+ for merge in merges:
84
+ vocab.append("".join(merge))
85
+ if not special_tokens:
86
+ special_tokens = ["<start_of_text>", "<end_of_text>"]
87
+ else:
88
+ special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
89
+ vocab.extend(special_tokens)
90
+ self.encoder = dict(zip(vocab, range(len(vocab))))
91
+ self.decoder = {v: k for k, v in self.encoder.items()}
92
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
93
+ self.cache = {t: t for t in special_tokens}
94
+ special = "|".join(special_tokens)
95
+ self.pat = re.compile(
96
+ special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
97
+ re.IGNORECASE,
98
+ )
99
+
100
+ self.vocab_size = len(self.encoder)
101
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
102
+
103
+ def bpe(self, token):
104
+ if token in self.cache:
105
+ return self.cache[token]
106
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
107
+ pairs = get_pairs(word)
108
+
109
+ if not pairs:
110
+ return token + "</w>"
111
+
112
+ while True:
113
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
114
+ if bigram not in self.bpe_ranks:
115
+ break
116
+ first, second = bigram
117
+ new_word = []
118
+ i = 0
119
+ while i < len(word):
120
+ try:
121
+ j = word.index(first, i)
122
+ new_word.extend(word[i:j])
123
+ i = j
124
+ except:
125
+ new_word.extend(word[i:])
126
+ break
127
+
128
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
129
+ new_word.append(first + second)
130
+ i += 2
131
+ else:
132
+ new_word.append(word[i])
133
+ i += 1
134
+ new_word = tuple(new_word)
135
+ word = new_word
136
+ if len(word) == 1:
137
+ break
138
+ else:
139
+ pairs = get_pairs(word)
140
+ word = " ".join(word)
141
+ self.cache[token] = word
142
+ return word
143
+
144
+ def encode(self, text):
145
+ bpe_tokens = []
146
+ text = whitespace_clean(basic_clean(text)).lower()
147
+ for token in re.findall(self.pat, text):
148
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
149
+ bpe_tokens.extend(
150
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
151
+ )
152
+ return bpe_tokens
153
+
154
+ def decode(self, tokens):
155
+ text = "".join([self.decoder[token] for token in tokens])
156
+ text = (
157
+ bytearray([self.byte_decoder[c] for c in text])
158
+ .decode("utf-8", errors="replace")
159
+ .replace("</w>", " ")
160
+ )
161
+ return text
162
+
163
+
164
+ _tokenizer = SimpleTokenizer()
165
+
166
+
167
+ def tokenize(
168
+ texts: Union[str, List[str]], context_length: int = 77
169
+ ) -> torch.LongTensor:
170
+ """
171
+ Returns the tokenized representation of given input string(s)
172
+
173
+ Parameters
174
+ ----------
175
+ texts : Union[str, List[str]]
176
+ An input string or a list of input strings to tokenize
177
+ context_length : int
178
+ The context length to use; all CLIP models use 77 as the context length
179
+
180
+ Returns
181
+ -------
182
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
183
+ """
184
+ if isinstance(texts, str):
185
+ texts = [texts]
186
+
187
+ sot_token = _tokenizer.encoder["<start_of_text>"]
188
+ eot_token = _tokenizer.encoder["<end_of_text>"]
189
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
190
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
191
+
192
+ for i, tokens in enumerate(all_tokens):
193
+ if len(tokens) > context_length:
194
+ tokens = tokens[:context_length] # Truncate
195
+ result[i, : len(tokens)] = torch.tensor(tokens)
196
+
197
+ return result
audioldm2/clap/open_clip/transform.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.transforms import (
2
+ Normalize,
3
+ Compose,
4
+ RandomResizedCrop,
5
+ InterpolationMode,
6
+ ToTensor,
7
+ Resize,
8
+ CenterCrop,
9
+ )
10
+
11
+
12
+ def _convert_to_rgb(image):
13
+ return image.convert("RGB")
14
+
15
+
16
+ def image_transform(
17
+ image_size: int,
18
+ is_train: bool,
19
+ mean=(0.48145466, 0.4578275, 0.40821073),
20
+ std=(0.26862954, 0.26130258, 0.27577711),
21
+ ):
22
+ normalize = Normalize(mean=mean, std=std)
23
+ if is_train:
24
+ return Compose(
25
+ [
26
+ RandomResizedCrop(
27
+ image_size,
28
+ scale=(0.9, 1.0),
29
+ interpolation=InterpolationMode.BICUBIC,
30
+ ),
31
+ _convert_to_rgb,
32
+ ToTensor(),
33
+ normalize,
34
+ ]
35
+ )
36
+ else:
37
+ return Compose(
38
+ [
39
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
40
+ CenterCrop(image_size),
41
+ _convert_to_rgb,
42
+ ToTensor(),
43
+ normalize,
44
+ ]
45
+ )
audioldm2/clap/open_clip/utils.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn as nn
4
+ from torchvision.ops.misc import FrozenBatchNorm2d
5
+ import logging
6
+ import h5py
7
+ from tqdm import tqdm
8
+ import random
9
+ import json
10
+ import os
11
+ import pathlib
12
+
13
+ # TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later.
14
+ dataset_split = {
15
+ "audiocaps": ["train", "valid", "test"],
16
+ "audioset": ["balanced_train", "unbalanced_train", "eval"],
17
+ "BBCSoundEffects": ["train", "test"],
18
+ "Clotho": ["train", "test", "valid"],
19
+ "free_to_use_sounds": ["train", "test"],
20
+ "paramount_motion": ["train", "test"],
21
+ "sonniss_game_effects": ["train", "test"],
22
+ "wesoundeffects": ["train", "test"],
23
+ "MACS": ["train", "test"],
24
+ "freesound": ["train", "test"],
25
+ "FSD50K": ["train", "test", "valid"],
26
+ "fsd50k_class_label": ["train", "test", "valid"],
27
+ "esc50": ["train", "test"],
28
+ "audiostock": ["train", "test"],
29
+ "freesound_no_overlap_noesc50": ["train", "test"],
30
+ "epidemic_sound_effects": ["train", "test"],
31
+ "VGGSound": ["train", "test"],
32
+ "urbansound8k_class_label": ["train", "test"],
33
+ "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
34
+ "epidemic_sound_effects_t5": ["train", "test"],
35
+ "WavText5K": ["train", "test"],
36
+ "esc50_no_overlap": ["train", "test"],
37
+ "usd8k_no_overlap": ["train", "test"],
38
+ "fsd50k_200_class_label": ["train", "test", "valid"],
39
+ }
40
+
41
+
42
+ def freeze_batch_norm_2d(module, module_match={}, name=""):
43
+ """
44
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
45
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
46
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
47
+
48
+ Args:
49
+ module (torch.nn.Module): Any PyTorch module.
50
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
51
+ name (str): Full module name (prefix)
52
+
53
+ Returns:
54
+ torch.nn.Module: Resulting module
55
+
56
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
57
+ """
58
+ res = module
59
+ is_match = True
60
+ if module_match:
61
+ is_match = name in module_match
62
+ if is_match and isinstance(
63
+ module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
64
+ ):
65
+ res = FrozenBatchNorm2d(module.num_features)
66
+ res.num_features = module.num_features
67
+ res.affine = module.affine
68
+ if module.affine:
69
+ res.weight.data = module.weight.data.clone().detach()
70
+ res.bias.data = module.bias.data.clone().detach()
71
+ res.running_mean.data = module.running_mean.data
72
+ res.running_var.data = module.running_var.data
73
+ res.eps = module.eps
74
+ else:
75
+ for child_name, child in module.named_children():
76
+ full_child_name = ".".join([name, child_name]) if name else child_name
77
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
78
+ if new_child is not child:
79
+ res.add_module(child_name, new_child)
80
+ return res
81
+
82
+
83
+ def exist(dataset_name, dataset_type):
84
+ """
85
+ Check if dataset exists
86
+ """
87
+ if dataset_type in dataset_split[dataset_name]:
88
+ return True
89
+ else:
90
+ return False
91
+
92
+
93
+ def get_tar_path_from_dataset_name(
94
+ dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
95
+ ):
96
+ """
97
+ Get tar path from dataset name and type
98
+ """
99
+ output = []
100
+ for n in dataset_names:
101
+ if full_dataset is not None and n in full_dataset:
102
+ current_dataset_types = dataset_split[n]
103
+ else:
104
+ current_dataset_types = dataset_types
105
+ for s in current_dataset_types:
106
+ tmp = []
107
+ if islocal:
108
+ sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
109
+ if not os.path.exists(sizefilepath_):
110
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
111
+ else:
112
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
113
+ if not os.path.exists(sizefilepath_):
114
+ continue
115
+ sizes = json.load(open(sizefilepath_, "r"))
116
+ for k in sizes.keys():
117
+ if islocal:
118
+ tmp.append(f"{dataset_path}/{n}/{s}/{k}")
119
+ else:
120
+ tmp.append(
121
+ f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
122
+ )
123
+ if proportion != 1:
124
+ tmp = random.sample(tmp, int(proportion * len(tmp)))
125
+ output.append(tmp)
126
+ return sum(output, [])
127
+
128
+
129
+ def get_tar_path_from_txts(txt_path, islocal, proportion=1):
130
+ """
131
+ Get tar path from txt path
132
+ """
133
+ if isinstance(txt_path, (list, tuple)):
134
+ return sum(
135
+ [
136
+ get_tar_path_from_txts(
137
+ txt_path[i], islocal=islocal, proportion=proportion
138
+ )
139
+ for i in range(len(txt_path))
140
+ ],
141
+ [],
142
+ )
143
+ if isinstance(txt_path, str):
144
+ with open(txt_path) as f:
145
+ lines = f.readlines()
146
+ if islocal:
147
+ lines = [
148
+ lines[i]
149
+ .split("\n")[0]
150
+ .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
151
+ for i in range(len(lines))
152
+ ]
153
+ else:
154
+ lines = [
155
+ lines[i].split("\n")[0].replace(".tar", ".tar -")
156
+ for i in range(len(lines))
157
+ ]
158
+ if proportion != 1:
159
+ print("Sampling tars with proportion of {}".format(proportion))
160
+ lines = random.sample(lines, int(proportion * len(lines)))
161
+ return lines
162
+
163
+
164
+ def get_mix_lambda(mixup_alpha, batch_size):
165
+ mixup_lambdas = [
166
+ np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
167
+ ]
168
+ return np.array(mixup_lambdas).astype(np.float32)
169
+
170
+
171
+ def do_mixup(x, mixup_lambda):
172
+ """
173
+ Args:
174
+ x: (batch_size , ...)
175
+ mixup_lambda: (batch_size,)
176
+ Returns:
177
+ out: (batch_size, ...)
178
+ """
179
+ out = (
180
+ x.transpose(0, -1) * mixup_lambda
181
+ + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
182
+ ).transpose(0, -1)
183
+ return out
184
+
185
+
186
+ def interpolate(x, ratio):
187
+ """Interpolate data in time domain. This is used to compensate the
188
+ resolution reduction in downsampling of a CNN.
189
+
190
+ Args:
191
+ x: (batch_size, time_steps, classes_num)
192
+ ratio: int, ratio to interpolate
193
+ Returns:
194
+ upsampled: (batch_size, time_steps * ratio, classes_num)
195
+ """
196
+ (batch_size, time_steps, classes_num) = x.shape
197
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
198
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
199
+ return upsampled
200
+
201
+
202
+ def pad_framewise_output(framewise_output, frames_num):
203
+ """Pad framewise_output to the same length as input frames. The pad value
204
+ is the same as the value of the last frame.
205
+ Args:
206
+ framewise_output: (batch_size, frames_num, classes_num)
207
+ frames_num: int, number of frames to pad
208
+ Outputs:
209
+ output: (batch_size, frames_num, classes_num)
210
+ """
211
+ pad = framewise_output[:, -1:, :].repeat(
212
+ 1, frames_num - framewise_output.shape[1], 1
213
+ )
214
+ """tensor for padding"""
215
+
216
+ output = torch.cat((framewise_output, pad), dim=1)
217
+ """(batch_size, frames_num, classes_num)"""
218
+
219
+
220
+ def process_ipc(index_path, classes_num, filename):
221
+ # load data
222
+ logging.info("Load Data...............")
223
+ ipc = [[] for _ in range(classes_num)]
224
+ with h5py.File(index_path, "r") as f:
225
+ for i in tqdm(range(len(f["target"]))):
226
+ t_class = np.where(f["target"][i])[0]
227
+ for t in t_class:
228
+ ipc[t].append(i)
229
+ print(ipc)
230
+ np.save(filename, ipc)
231
+ logging.info("Load Data Succeed...............")
232
+
233
+
234
+ def save_to_dict(s, o_={}):
235
+ sp = s.split(": ")
236
+ o_.update({sp[0]: float(sp[1])})
237
+ return o_
238
+
239
+
240
+ def get_data_from_log(txt_path):
241
+ """
242
+ Output dictionary from out.txt log file
243
+ """
244
+ with open(txt_path) as f:
245
+ lines = f.readlines()
246
+ val_data = {}
247
+ train_data = {}
248
+ train_losses = []
249
+ train_losses_epoch = []
250
+ for i in range(len(lines)):
251
+ if "| INFO |" in lines[i]:
252
+ if "Eval Epoch" in lines[i]:
253
+ if "val_loss" in lines[i]:
254
+ # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
255
+ line = lines[i].split("Eval Epoch: ")[-1]
256
+ num_epoch = int(line.split(" ")[0].split(" ")[0])
257
+ d = {
258
+ line.split(" ")[0]
259
+ .split(" ")[1]
260
+ .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
261
+ }
262
+ for i in range(1, len(line.split(" "))):
263
+ d = save_to_dict(line.split(" ")[i], d)
264
+ val_data[num_epoch] = d
265
+ elif "Train Epoch" in lines[i]:
266
+ num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
267
+ loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
268
+ train_losses.append(loss)
269
+ train_losses_epoch.append(num_epoch)
270
+ for i in range(len(train_losses)):
271
+ train_data[i] = {
272
+ "num_epoch": train_losses_epoch[i],
273
+ "train_loss": train_losses[i],
274
+ }
275
+ return train_data, val_data
276
+
277
+
278
+ def save_p(obj, filename):
279
+ import pickle
280
+
281
+ try:
282
+ from deepdiff import DeepDiff
283
+ except:
284
+ os.system("pip install deepdiff")
285
+ from deepdiff import DeepDiff
286
+ with open(filename, "wb") as file:
287
+ pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
288
+ with open(filename, "rb") as file:
289
+ z = pickle.load(file)
290
+ assert (
291
+ DeepDiff(obj, z, ignore_string_case=True) == {}
292
+ ), "there is something wrong with the saving process"
293
+ return
294
+
295
+
296
+ def load_p(filename):
297
+ import pickle
298
+
299
+ with open(filename, "rb") as file:
300
+ z = pickle.load(file)
301
+ return z
302
+
303
+
304
+ def save_json(data, name="data.json"):
305
+ import json
306
+
307
+ with open(name, "w") as fp:
308
+ json.dump(data, fp)
309
+ return
310
+
311
+
312
+ def load_json(name):
313
+ import json
314
+
315
+ with open(name, "r") as fp:
316
+ data = json.load(fp)
317
+ return data
318
+
319
+
320
+ def load_class_label(path):
321
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
322
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
323
+ out = None
324
+ if path is not None:
325
+ if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
326
+ out = load_p(path)
327
+ elif pathlib.Path(path).suffix in [".json", ".txt"]:
328
+ out = load_json(path)
329
+ elif pathlib.Path(path).suffix in [".npy", ".npz"]:
330
+ out = np.load(path)
331
+ elif pathlib.Path(path).suffix in [".csv"]:
332
+ import pandas as pd
333
+
334
+ out = pd.read_csv(path)
335
+ return out
336
+ # if out is None:
337
+ # return None
338
+ # else:
339
+ # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
340
+ # val = Array('i', out.values(), lock=False)
341
+ # return (key, val)
342
+
343
+
344
+ from torch import optim
345
+
346
+
347
+ def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
348
+ if optimizer_name.lower() == "adamw":
349
+ optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
350
+ elif optimizer_name.lower() == "sgd":
351
+ optimizer = optim.SGD(params, lr=lr, momentum=momentum)
352
+ elif optimizer_name.lower() == "adam":
353
+ optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
354
+ else:
355
+ raise ValueError("optimizer name is not correct")
356
+ return optimizer
audioldm2/clap/training/__init__.py ADDED
File without changes
audioldm2/clap/training/audioset_textmap.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b
3
+ size 84448
audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
audioldm2/clap/training/data.py ADDED
@@ -0,0 +1,865 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import random
5
+ import h5py
6
+ from dataclasses import dataclass
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import torchvision.datasets as datasets
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ import soundfile as sf
15
+ import io
16
+ from pathlib import Path
17
+ # import wget
18
+
19
+ from audioldm2.clap.open_clip.utils import get_tar_path_from_dataset_name
20
+ from audioldm2.clap.open_clip.utils import load_class_label
21
+
22
+ try:
23
+ import horovod.torch as hvd
24
+ except ImportError:
25
+ hvd = None
26
+
27
+ try:
28
+ import torchaudio
29
+ except ImportError:
30
+ torchaudio = None
31
+
32
+ from audioldm2.clap.open_clip import tokenize
33
+
34
+
35
+ def tokenizer(text):
36
+ return tokenize(text).squeeze(0)
37
+
38
+
39
+ from transformers import RobertaTokenizer
40
+
41
+ tokenize = RobertaTokenizer.from_pretrained("roberta-base")
42
+
43
+
44
+ def tokenizer(text):
45
+ result = tokenize(
46
+ text,
47
+ padding="max_length",
48
+ truncation=True,
49
+ max_length=77,
50
+ return_tensors="pt",
51
+ )
52
+ return {k: v.squeeze(0) for k, v in result.items()}
53
+
54
+
55
+ # initizlied the audioset map
56
+ _AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
57
+ _AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
58
+
59
+
60
+ def int16_to_float32(x):
61
+ return (x / 32767.0).astype(np.float32)
62
+
63
+
64
+ def float32_to_int16(x):
65
+ x = np.clip(x, a_min=-1.0, a_max=1.0)
66
+ return (x * 32767.0).astype(np.int16)
67
+
68
+
69
+ # For Toy Dataset
70
+ class ToyDataset(Dataset):
71
+ def __init__(self, index_path, ipc, config, eval_mode=False):
72
+ """Toy Dataset for testing the audioset input with text labels
73
+ Parameters
74
+ ----------
75
+ index_path: str
76
+ the link to the h5 file of each audio
77
+ idc: str
78
+ the link to the npy file, the number of samples in each class
79
+ config: dict
80
+ the audio cfg file
81
+ eval_model (bool): to indicate if the dataset is a testing dataset
82
+ """
83
+ self.audio_cfg = config["audio_cfg"]
84
+ self.text_cfg = config["text_cfg"]
85
+ self.fp = h5py.File(index_path, "r")
86
+ self.ipc = np.load(ipc, allow_pickle=True)
87
+ self.total_size = len(self.fp["audio_name"])
88
+ self.classes_num = self.audio_cfg["class_num"]
89
+ self.eval_mode = eval_mode
90
+
91
+ if not eval_mode:
92
+ self.generate_queue()
93
+ else:
94
+ self.queue = []
95
+ for i in range(self.total_size):
96
+ target = self.fp["target"][i]
97
+ if np.sum(target) > 0:
98
+ self.queue.append(i)
99
+ self.total_size = len(self.queue)
100
+ logging.info("total dataset size: %d" % (self.total_size))
101
+ logging.info("class num: %d" % (self.classes_num))
102
+
103
+ def time_shifting(self, x):
104
+ frame_num = len(x)
105
+ shift_len = random.randint(0, frame_num - 1)
106
+ new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
107
+ return new_sample
108
+
109
+ def generate_queue(self):
110
+ self.queue = []
111
+ while len(self.queue) < self.total_size:
112
+ class_set = [*range(self.classes_num)]
113
+ random.shuffle(class_set)
114
+ self.queue += [
115
+ self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
116
+ ]
117
+ self.queue = self.queue[: self.total_size]
118
+
119
+ logging.info("queue regenerated:%s" % (self.queue[-5:]))
120
+
121
+ def crop_wav(self, x):
122
+ crop_size = self.audio_cfg["crop_size"]
123
+ crop_pos = random.randint(0, len(x) - crop_size - 1)
124
+ return x[crop_pos : crop_pos + crop_size]
125
+
126
+ def prompt_text(self, target):
127
+ events = _AUDIOSET_MAP[np.where(target > 0)]
128
+ event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
129
+ text = tokenize(event_text)[0]
130
+ return text
131
+
132
+ def __getitem__(self, index):
133
+ """Load waveform, text, and target of an audio clip
134
+
135
+ Parameters
136
+ ----------
137
+ index: int
138
+ the index number
139
+ Return
140
+ ------
141
+ output: dict {
142
+ "hdf5_path": str,
143
+ "index_in_hdf5": int,
144
+ "audio_name": str,
145
+ "waveform": list (audio_length,),
146
+ "target": list (class_num, ),
147
+ "text": torch.tensor (context_length,)
148
+ }
149
+ the output dictionary
150
+ """
151
+ s_index = self.queue[index]
152
+
153
+ audio_name = self.fp["audio_name"][s_index].decode()
154
+ # Hardcode here CHANGE
155
+ hdf5_path = (
156
+ self.fp["hdf5_path"][s_index]
157
+ .decode()
158
+ .replace(
159
+ "../workspace",
160
+ "/home/la/kechen/Research/ke_zsasp/workspace",
161
+ )
162
+ )
163
+ r_idx = self.fp["index_in_hdf5"][s_index]
164
+ target = self.fp["target"][s_index].astype(np.float32)
165
+ text = self.prompt_text(target)
166
+ with h5py.File(hdf5_path, "r") as f:
167
+ waveform = int16_to_float32(f["waveform"][r_idx])[
168
+ : self.audio_cfg["clip_samples"]
169
+ ]
170
+ assert (
171
+ len(waveform) == self.audio_cfg["clip_samples"]
172
+ ), "The sample length is not match"
173
+ # Time shift
174
+ # if (self.config.enable_time_shift) and (not self.eval_mode):
175
+ # waveform = self.time_shifting(waveform)
176
+ # # Label Enhance
177
+ # if (self.config.crop_size is not None) and (not self.eval_mode):
178
+ # waveform = self.crop_wav(waveform)
179
+ # # the label enhance rate is fixed 0.5
180
+ # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
181
+ # kidx = np.where(target)[0]
182
+ # for k in kidx:
183
+ # for add_key in self.class_map[k][1]:
184
+ # target[add_key] = 1.0
185
+ # if len(self.class_map[k][2]) > 0:
186
+ # add_key = random.choice(self.class_map[k][2])
187
+ # target[add_key] = 1.0
188
+
189
+ # missing the text input
190
+ mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
191
+ mel_spec = (
192
+ torch.cat(
193
+ [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
194
+ )
195
+ .cpu()
196
+ .numpy()
197
+ )
198
+ longer = random.choice([True, False])
199
+ if longer == False:
200
+ mel_spec[1:, :, :] = 0.0
201
+ data_dict = {
202
+ "hdf5_path": hdf5_path,
203
+ "index_in_hdf5": r_idx,
204
+ "audio_name": audio_name,
205
+ "waveform": waveform,
206
+ "class_label": target,
207
+ "text": text,
208
+ "longer": longer,
209
+ "mel_fusion": mel_spec,
210
+ }
211
+ return data_dict
212
+
213
+ def __len__(self):
214
+ return self.total_size
215
+
216
+
217
+ class CsvDataset(Dataset):
218
+ def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
219
+ logging.debug(f"Loading csv data from {input_filename}.")
220
+ df = pd.read_csv(input_filename, sep=sep)
221
+
222
+ self.images = df[img_key].tolist()
223
+ self.captions = df[caption_key].tolist()
224
+ self.transforms = transforms
225
+ logging.debug("Done loading data.")
226
+
227
+ def __len__(self):
228
+ return len(self.captions)
229
+
230
+ def __getitem__(self, idx):
231
+ images = self.transforms(Image.open(str(self.images[idx])))
232
+ texts = tokenize([str(self.captions[idx])])[0]
233
+ return images, texts
234
+
235
+
236
+ @dataclass
237
+ class DataInfo:
238
+ dataloader: DataLoader
239
+ sampler: DistributedSampler
240
+
241
+
242
+ def preprocess_txt(text):
243
+ return tokenize([str(text)])[0]
244
+
245
+
246
+ # def get_dataset_size(shards, sizefilepath_=None, is_local=True):
247
+ # if isinstance(shards, list):
248
+ # size_list = []
249
+ # for s in shards:
250
+ # size_list.append(
251
+ # get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
252
+ # )
253
+ # else:
254
+ # if not is_local:
255
+ # for n in dataset_split.keys():
256
+ # if n in shards.split("/"):
257
+ # break
258
+ # for s in dataset_split[n]:
259
+ # if s in shards.split("/"):
260
+ # break
261
+ # sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
262
+ # shards_list = list(braceexpand.braceexpand(shards))
263
+ # dir_path = os.path.dirname(shards)
264
+ # if sizefilepath_ is not None:
265
+ # sizes = json.load(open(sizefilepath_, "r"))
266
+ # total_size = sum(
267
+ # [
268
+ # int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
269
+ # for shard in shards_list
270
+ # ]
271
+ # )
272
+ # else:
273
+ # sizes_filename = os.path.join(dir_path, "sizes.json")
274
+ # len_filename = os.path.join(dir_path, "__len__")
275
+ # if os.path.exists(sizes_filename):
276
+ # sizes = json.load(open(sizes_filename, "r"))
277
+ # total_size = sum(
278
+ # [int(sizes[os.path.basename(shard)]) for shard in shards_list]
279
+ # )
280
+ # elif os.path.exists(len_filename):
281
+ # # FIXME this used to be eval(open(...)) but that seemed rather unsafe
282
+ # total_size = ast.literal_eval(open(len_filename, "r").read())
283
+ # else:
284
+ # raise Exception(
285
+ # "Cannot find sizes file for dataset. Please specify the path to the file."
286
+ # )
287
+ # # total_size = None # num samples undefined
288
+ # # some common dataset sizes (at time of authors last download)
289
+ # # cc3m-train: 2905954
290
+ # # cc12m: 10968539
291
+ # # LAION-400m: 407332084
292
+ # num_shards = len(shards_list)
293
+ # if isinstance(shards, list):
294
+ # return sum(size_list), len(shards)
295
+ # else:
296
+ # return total_size, num_shards
297
+
298
+
299
+ def get_imagenet(args, preprocess_fns, split):
300
+ assert split in ["train", "val", "v2"]
301
+ is_train = split == "train"
302
+ preprocess_train, preprocess_val = preprocess_fns
303
+
304
+ if split == "v2":
305
+ from imagenetv2_pytorch import ImageNetV2Dataset
306
+
307
+ dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
308
+ else:
309
+ if is_train:
310
+ data_path = args.imagenet_train
311
+ preprocess_fn = preprocess_train
312
+ else:
313
+ data_path = args.imagenet_val
314
+ preprocess_fn = preprocess_val
315
+ assert data_path
316
+
317
+ dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
318
+
319
+ if is_train:
320
+ idxs = np.zeros(len(dataset.targets))
321
+ target_array = np.array(dataset.targets)
322
+ k = 50
323
+ for c in range(1000):
324
+ m = target_array == c
325
+ n = len(idxs[m])
326
+ arr = np.zeros(n)
327
+ arr[:k] = 1
328
+ np.random.shuffle(arr)
329
+ idxs[m] = arr
330
+
331
+ idxs = idxs.astype("int")
332
+ sampler = SubsetRandomSampler(np.where(idxs)[0])
333
+ else:
334
+ sampler = None
335
+
336
+ dataloader = torch.utils.data.DataLoader(
337
+ dataset,
338
+ batch_size=args.batch_size,
339
+ num_workers=args.workers,
340
+ sampler=sampler,
341
+ )
342
+
343
+ return DataInfo(dataloader, sampler)
344
+
345
+
346
+ def count_samples(dataloader):
347
+ os.environ["WDS_EPOCH"] = "0"
348
+ n_elements, n_batches = 0, 0
349
+ for images, texts in dataloader:
350
+ n_batches += 1
351
+ n_elements += len(images)
352
+ assert len(images) == len(texts)
353
+ return n_elements, n_batches
354
+
355
+
356
+ def filter_no_caption(sample):
357
+ return "txt" in sample
358
+
359
+
360
+ def log_and_continue(exn):
361
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
362
+ logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
363
+ return True
364
+
365
+
366
+ _SHARD_SHUFFLE_SIZE = 2000
367
+ _SHARD_SHUFFLE_INITIAL = 500
368
+ _SAMPLE_SHUFFLE_SIZE = 5000
369
+ _SAMPLE_SHUFFLE_INITIAL = 1000
370
+
371
+
372
+ # def sample_prop(sizefile, inputs, proportion, is_local=True):
373
+ # """
374
+ # Sample a proportion of the data.
375
+ # """
376
+ # file_path_dict = {
377
+ # os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
378
+ # for i in range(len(inputs))
379
+ # }
380
+ # sampled_filepath_dict = {}
381
+ # sampled_size_dict = {}
382
+ # if not is_local:
383
+ # if os.path.exists("sizes.json"):
384
+ # os.remove("sizes.json")
385
+ # wget.download(sizefile, "sizes.json")
386
+ # sizefile = "sizes.json"
387
+ # with open(sizefile, "r", encoding="UTF-8") as f:
388
+ # load_dict = json.load(f)
389
+ # L = int(len(file_path_dict) * proportion)
390
+ # subkeys = random.sample(file_path_dict.keys(), L)
391
+ # for k in subkeys:
392
+ # sampled_size_dict[k] = load_dict[k]
393
+ # sampled_filepath_dict[k] = file_path_dict[k]
394
+ # return (
395
+ # sum(sampled_size_dict.values()),
396
+ # L,
397
+ # [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
398
+ # sampled_size_dict,
399
+ # )
400
+
401
+
402
+ def get_mel(audio_data, audio_cfg):
403
+ # mel shape: (n_mels, T)
404
+ mel = torchaudio.transforms.MelSpectrogram(
405
+ sample_rate=audio_cfg["sample_rate"],
406
+ n_fft=audio_cfg["window_size"],
407
+ win_length=audio_cfg["window_size"],
408
+ hop_length=audio_cfg["hop_size"],
409
+ center=True,
410
+ pad_mode="reflect",
411
+ power=2.0,
412
+ norm=None,
413
+ onesided=True,
414
+ n_mels=64,
415
+ f_min=audio_cfg["fmin"],
416
+ f_max=audio_cfg["fmax"],
417
+ ).to(audio_data.device)
418
+ mel = mel(audio_data)
419
+ # we use log mel spectrogram as input
420
+ mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
421
+ return mel.T # (T, n_mels)
422
+
423
+
424
+ def get_audio_features(
425
+ audio_data, mel, max_len, data_truncating, data_filling, audio_cfg
426
+ ):
427
+ """
428
+ Calculate and add audio features to sample.
429
+ Sample: a dict containing all the data of current sample.
430
+ audio_data: a tensor of shape (T) containing audio data.
431
+ max_len: the maximum length of audio data.
432
+ data_truncating: the method of truncating data.
433
+ data_filling: the method of filling data.
434
+ audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
435
+ """
436
+ sample = {}
437
+
438
+ # assert audio_data.size(-1) <= max_len, str(audio_data.size())
439
+
440
+ # split to three parts
441
+ chunk_frames = (
442
+ max_len // audio_cfg["hop_size"] + 1
443
+ ) # the +1 related to how the spectrogram is computed
444
+ mel = mel[:chunk_frames]
445
+
446
+ audio_data = audio_data[..., :max_len]
447
+ sample["mel_fusion"] = mel
448
+ longer = torch.tensor([True])
449
+
450
+ sample["longer"] = longer
451
+ sample["waveform"] = audio_data
452
+
453
+ return sample
454
+
455
+
456
+ def preprocess(
457
+ sample,
458
+ audio_ext,
459
+ text_ext,
460
+ max_len,
461
+ audio_cfg,
462
+ class_index_dict=None,
463
+ data_filling="pad",
464
+ data_truncating="rand_trunc",
465
+ text_augment_selection=None,
466
+ ):
467
+ """
468
+ Preprocess a single sample for wdsdataloader.
469
+ """
470
+ audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
471
+ audio_data = int16_to_float32(float32_to_int16(audio_data))
472
+ audio_data = torch.tensor(audio_data).float()
473
+
474
+ # TODO: (yusong) to be include in the future
475
+ # # if torchaudio not installed, use soundfile to load audio
476
+ # if torchaudio is None:
477
+ # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
478
+ # audio_data = torch.tensor(audio_data).float()
479
+ # else:
480
+ # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
481
+ # with tempfile.TemporaryDirectory() as dirname:
482
+ # os.makedirs(dirname, exist_ok=True)
483
+ # fname = os.path.join(dirname, f"file.flac")
484
+ # with open(fname, "wb") as stream:
485
+ # stream.write(sample[audio_ext])
486
+ # audio_data, orig_sr = torchaudio.load(fname)
487
+ # audio_data = audio_data[0, :].float()
488
+
489
+ sample = get_audio_features(
490
+ sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
491
+ )
492
+ del sample[audio_ext]
493
+
494
+ try:
495
+ json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
496
+ except:
497
+ print("sample[__url__]:", sample["__url__"])
498
+
499
+ # For selecting augmented text from dataset
500
+ if text_augment_selection is None or text_augment_selection == "none":
501
+ texts = json_dict_raw["text"]
502
+ elif text_augment_selection == "all":
503
+ if "text_augment_all" in json_dict_raw.keys():
504
+ texts = json_dict_raw["text_augment_all"]
505
+ else:
506
+ texts = json_dict_raw["text"]
507
+ elif text_augment_selection == "augment_only":
508
+ if "text_augment_all" in json_dict_raw.keys():
509
+ if json_dict_raw["text_augment_t5"] is None:
510
+ texts = json_dict_raw["text"]
511
+ else:
512
+ texts = json_dict_raw["text_augment_t5"]
513
+ else:
514
+ texts = json_dict_raw["text"]
515
+ else:
516
+ raise NotImplementedError(
517
+ f"text_augment_selection {text_augment_selection} not implemented"
518
+ )
519
+ sample["full_text"] = texts
520
+
521
+ if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
522
+ texts = random.choice(texts)
523
+ sample["raw_text"] = texts
524
+ sample["text"] = tokenizer(texts) # text shape: [num_token]
525
+ if class_index_dict is not None:
526
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
527
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
528
+ # key, val = class_index_dict
529
+ # key = key[:].split('\n')
530
+ # _dict = {k: v for k, v in zip(key, val)}
531
+ sample["class_label"] = np.zeros(len(class_index_dict.keys()))
532
+ for x in json_dict_raw["tag"]:
533
+ sample["class_label"][class_index_dict[x]] = 1
534
+ sample["class_label"] = torch.tensor(sample["class_label"]).float()
535
+ del sample[text_ext]
536
+ sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
537
+ sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
538
+ sample["audio_orig_sr"] = orig_sr
539
+ return sample
540
+
541
+
542
+ def collate_fn(batch):
543
+ """
544
+ Collate function for wdsdataloader.
545
+ batch: a list of dict, each dict is a sample
546
+ """
547
+ # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend.
548
+ batch_dict = {}
549
+ for k in batch[0].keys():
550
+ if isinstance(batch[0][k], dict): # dealwith bert tokenizer output
551
+ batch_dict[k] = {}
552
+ for kk in batch[0][k].keys():
553
+ tmp = []
554
+ for i in range(len(batch)):
555
+ tmp.append(batch[i][k][kk])
556
+ batch_dict[k][kk] = torch.vstack(tmp)
557
+ elif isinstance(batch[0][k], torch.Tensor):
558
+ batch_dict[k] = torch.stack([sample[k] for sample in batch])
559
+ elif isinstance(batch[0][k], np.ndarray):
560
+ batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
561
+ else:
562
+ batch_dict[k] = [sample[k] for sample in batch]
563
+ return batch_dict
564
+
565
+
566
+ # def get_wds_dataset(
567
+ # args,
568
+ # model_cfg,
569
+ # is_train,
570
+ # audio_ext="flac",
571
+ # text_ext="json",
572
+ # max_len=480000,
573
+ # proportion=1.0,
574
+ # sizefilepath_=None,
575
+ # is_local=None,
576
+ # ):
577
+ # """
578
+ # Get a dataset for wdsdataloader.
579
+ # """
580
+ # if is_local is None and (not args.remotedata is None):
581
+ # is_local = not args.remotedata
582
+
583
+ # input_shards = args.train_data if is_train else args.val_data
584
+ # assert input_shards is not None
585
+
586
+ # if not sizefilepath_ is None:
587
+ # sizefilepath = sizefilepath_
588
+ # else:
589
+ # sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
590
+
591
+ # if proportion != 1.0:
592
+ # num_samples, num_shards, input_shards, _ = sample_prop(
593
+ # sizefilepath, input_shards, proportion, is_local=is_local
594
+ # )
595
+ # else:
596
+ # num_samples, num_shards = get_dataset_size(
597
+ # input_shards, sizefilepath_=sizefilepath_, is_local=is_local
598
+ # )
599
+
600
+ # if not num_samples:
601
+ # if is_train:
602
+ # num_samples = args.train_num_samples
603
+ # if not num_samples:
604
+ # raise RuntimeError(
605
+ # "Currently, number of dataset samples must be specified for training dataset. "
606
+ # "Please specify via `--train-num-samples` if no dataset length info present."
607
+ # )
608
+ # else:
609
+ # num_samples = (
610
+ # args.val_num_samples or 0
611
+ # ) # eval will just exhaust the iterator if not specified
612
+
613
+ # pipeline = [wds.SimpleShardList(input_shards)]
614
+ # # at this point we have an iterator over all the shards
615
+ # # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node
616
+ # if is_train or args.parallel_eval:
617
+ # pipeline.extend(
618
+ # [
619
+ # wds.detshuffle(
620
+ # bufsize=_SHARD_SHUFFLE_SIZE,
621
+ # initial=_SHARD_SHUFFLE_INITIAL,
622
+ # seed=args.seed,
623
+ # ),
624
+ # wds.split_by_node,
625
+ # wds.split_by_worker,
626
+ # # at this point, we have an iterator over the shards assigned to each worker at each node
627
+ # wds.tarfile_to_samples(handler=log_and_continue),
628
+ # wds.shuffle(
629
+ # bufsize=_SAMPLE_SHUFFLE_SIZE,
630
+ # initial=_SAMPLE_SHUFFLE_INITIAL,
631
+ # rng=random.Random(args.seed),
632
+ # ),
633
+ # # wds.repeatedly, # FIXME determine if this is beneficial
634
+ # ]
635
+ # )
636
+ # else:
637
+ # pipeline.extend(
638
+ # [
639
+ # wds.split_by_worker,
640
+ # # at this point, we have an iterator over the shards assigned to each worker
641
+ # wds.tarfile_to_samples(handler=log_and_continue),
642
+ # ]
643
+ # )
644
+ # pipeline.append(
645
+ # wds.map(
646
+ # partial(
647
+ # preprocess,
648
+ # audio_ext=audio_ext,
649
+ # text_ext=text_ext,
650
+ # max_len=max_len,
651
+ # audio_cfg=model_cfg["audio_cfg"],
652
+ # class_index_dict=copy.deepcopy(args.class_index_dict),
653
+ # data_filling=args.data_filling,
654
+ # data_truncating=args.data_truncating,
655
+ # text_augment_selection=args.text_augment_selection,
656
+ # )
657
+ # ),
658
+ # )
659
+
660
+ # pipeline.append(
661
+ # wds.batched(
662
+ # args.batch_size,
663
+ # partial=not (is_train or args.parallel_eval),
664
+ # collation_fn=collate_fn,
665
+ # )
666
+ # )
667
+
668
+ # dataset = wds.DataPipeline(*pipeline)
669
+ # if is_train or args.parallel_eval:
670
+ # # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples.
671
+ # # (yusong): See comments below.
672
+ # # roll over and repeat a few samples to get same number of full batches on each node
673
+ # global_batch_size = args.batch_size * args.world_size
674
+ # num_batches = math.ceil(num_samples / global_batch_size)
675
+ # num_workers = max(1, args.workers)
676
+ # num_worker_batches = math.ceil(
677
+ # num_batches / num_workers
678
+ # ) # per dataloader worker
679
+ # num_batches = num_worker_batches * num_workers
680
+ # num_samples = num_batches * global_batch_size
681
+ # dataset = dataset.with_epoch(
682
+ # num_worker_batches
683
+ # ) # each worker is iterating over this
684
+ # else:
685
+ # # last batches are partial, eval is done on single (master) node
686
+ # num_batches = math.ceil(num_samples / args.batch_size)
687
+
688
+ # kwargs = {}
689
+ # if args.horovod: # multi-node training on summit
690
+ # kwargs["multiprocessing_context"] = "forkserver"
691
+
692
+ # dataloader = wds.WebLoader(
693
+ # dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
694
+ # )
695
+
696
+ # # FIXME not clear which approach is better, with_epoch before vs after dataloader?
697
+ # # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
698
+ # # if is_train:
699
+ # # # roll over and repeat a few samples to get same number of full batches on each node
700
+ # # global_batch_size = args.batch_size * args.world_size
701
+ # # num_batches = math.ceil(num_samples / global_batch_size)
702
+ # # num_workers = max(1, args.workers)
703
+ # # num_batches = math.ceil(num_batches / num_workers) * num_workers
704
+ # # num_samples = num_batches * global_batch_size
705
+ # # dataloader = dataloader.with_epoch(num_batches)
706
+ # # else:
707
+ # # # last batches are partial, eval is done on single (master) node
708
+ # # num_batches = math.ceil(num_samples / args.batch_size)
709
+
710
+ # # add meta-data to dataloader instance for convenience
711
+ # dataloader.num_batches = num_batches
712
+ # dataloader.num_samples = num_samples
713
+
714
+ # return DataInfo(dataloader, None)
715
+
716
+
717
+ def wds_batch_list2dict(
718
+ batch,
719
+ keys=[
720
+ "__url__",
721
+ "__key__",
722
+ "waveform",
723
+ "text",
724
+ "raw_text",
725
+ "audio_name",
726
+ "text_name",
727
+ "audio_orig_sr",
728
+ ],
729
+ ):
730
+ """
731
+ Return a dictionary of the batch, with keys as the names of the fields.
732
+ """
733
+ assert len(keys) == len(
734
+ batch
735
+ ), "batch must have same number of keys as keys argument"
736
+ return {keys[i]: batch[i] for i in range(len(batch))}
737
+
738
+
739
+ def get_csv_dataset(args, preprocess_fn, is_train):
740
+ input_filename = args.train_data if is_train else args.val_data
741
+ assert input_filename
742
+ dataset = CsvDataset(
743
+ input_filename,
744
+ preprocess_fn,
745
+ img_key=args.csv_img_key,
746
+ caption_key=args.csv_caption_key,
747
+ sep=args.csv_separator,
748
+ )
749
+ num_samples = len(dataset)
750
+ sampler = DistributedSampler(dataset) if args.distributed and is_train else None
751
+ shuffle = is_train and sampler is None
752
+
753
+ dataloader = DataLoader(
754
+ dataset,
755
+ batch_size=args.batch_size,
756
+ shuffle=shuffle,
757
+ num_workers=args.workers,
758
+ pin_memory=True,
759
+ sampler=sampler,
760
+ drop_last=is_train,
761
+ )
762
+ dataloader.num_samples = num_samples
763
+ dataloader.num_batches = len(dataloader)
764
+
765
+ return DataInfo(dataloader, sampler)
766
+
767
+
768
+ def get_toy_dataset(args, model_cfg, is_train):
769
+ index_path = args.train_data if is_train else args.val_data
770
+ ipc_path = args.train_ipc if is_train else args.val_ipc
771
+ assert index_path and ipc_path
772
+ eval_mode = not is_train
773
+ dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)
774
+
775
+ num_samples = len(dataset)
776
+ sampler = (
777
+ DistributedSampler(dataset, shuffle=False)
778
+ if args.distributed and is_train
779
+ else None
780
+ )
781
+
782
+ dataloader = DataLoader(
783
+ dataset,
784
+ batch_size=args.batch_size,
785
+ shuffle=False,
786
+ num_workers=args.workers,
787
+ sampler=sampler,
788
+ drop_last=is_train,
789
+ )
790
+ dataloader.num_samples = num_samples
791
+ dataloader.num_batches = len(dataloader)
792
+
793
+ return DataInfo(dataloader, sampler)
794
+
795
+
796
+ def get_dataset_fn(data_path, dataset_type):
797
+ if dataset_type == "webdataset":
798
+ return get_wds_dataset
799
+ elif dataset_type == "csv":
800
+ return get_csv_dataset
801
+ elif dataset_type == "auto":
802
+ ext = data_path.split(".")[-1]
803
+ if ext in ["csv", "tsv"]:
804
+ return get_csv_dataset
805
+ elif ext in ["tar"]:
806
+ return get_wds_dataset
807
+ else:
808
+ raise ValueError(
809
+ f"Tried to figure out dataset type, but failed for extention {ext}."
810
+ )
811
+ elif dataset_type == "toy":
812
+ return get_toy_dataset
813
+ else:
814
+ raise ValueError(f"Unsupported dataset type: {dataset_type}")
815
+
816
+
817
+ def get_data(args, model_cfg):
818
+ data = {}
819
+
820
+ args.class_index_dict = load_class_label(args.class_label_path)
821
+
822
+ if args.datasetinfos is None:
823
+ args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
824
+ if args.dataset_type == "webdataset":
825
+ args.train_data = get_tar_path_from_dataset_name(
826
+ args.datasetnames,
827
+ args.datasetinfos,
828
+ islocal=not args.remotedata,
829
+ proportion=args.dataset_proportion,
830
+ dataset_path=args.datasetpath,
831
+ full_dataset=args.full_train_dataset,
832
+ )
833
+
834
+ if args.full_train_dataset is None:
835
+ args.full_train_dataset = []
836
+ if args.exclude_eval_dataset is None:
837
+ args.exclude_eval_dataset = []
838
+ excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset
839
+
840
+ val_dataset_names = (
841
+ [n for n in args.datasetnames if n not in excluded_eval_datasets]
842
+ if excluded_eval_datasets
843
+ else args.datasetnames
844
+ )
845
+ args.val_dataset_names = val_dataset_names
846
+ args.val_data = get_tar_path_from_dataset_name(
847
+ val_dataset_names,
848
+ ["valid", "test", "eval"],
849
+ islocal=not args.remotedata,
850
+ proportion=1,
851
+ dataset_path=args.datasetpath,
852
+ full_dataset=None,
853
+ )
854
+
855
+ if args.train_data:
856
+ data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
857
+ args, model_cfg, is_train=True
858
+ )
859
+
860
+ if args.val_data:
861
+ data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
862
+ args, model_cfg, is_train=False
863
+ )
864
+
865
+ return data
audioldm2/clap/training/params.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
+ def get_default_params(model_name):
5
+ # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
6
+ model_name = model_name.lower()
7
+ if "vit" in model_name:
8
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
9
+ else:
10
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
11
+
12
+
13
+ def parse_args():
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument(
16
+ "--train-data",
17
+ type=str,
18
+ default=None,
19
+ help="Path to h5 filewith training data",
20
+ )
21
+ parser.add_argument(
22
+ "--val-data",
23
+ type=str,
24
+ default=None,
25
+ help="Path to h5 file with validation data",
26
+ )
27
+ parser.add_argument(
28
+ "--freeze-text",
29
+ default=False,
30
+ action="store_true",
31
+ help="if you need to freeze the text encoder, make this True",
32
+ )
33
+ parser.add_argument(
34
+ "--freeze-text-after",
35
+ type=int,
36
+ default=-1,
37
+ help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it",
38
+ )
39
+ parser.add_argument(
40
+ "--train-ipc",
41
+ type=str,
42
+ default=None,
43
+ help="Path to npy file of the number of instance per class in training data",
44
+ )
45
+ parser.add_argument(
46
+ "--val-ipc",
47
+ type=str,
48
+ default=None,
49
+ help="Path to npy file of the number of instance per class in validation data",
50
+ )
51
+ parser.add_argument(
52
+ "--train-num-samples",
53
+ type=int,
54
+ default=None,
55
+ help="Number of samples in dataset. Required for webdataset if not available in info file.",
56
+ )
57
+ parser.add_argument(
58
+ "--val-num-samples",
59
+ type=int,
60
+ default=None,
61
+ help="Number of samples in dataset. Useful for webdataset if not available in info file.",
62
+ )
63
+ parser.add_argument(
64
+ "--dataset-type",
65
+ choices=["webdataset", "csv", "auto", "toy"],
66
+ default="auto",
67
+ help="Which type of dataset to process.",
68
+ )
69
+ parser.add_argument(
70
+ "--csv-separator",
71
+ type=str,
72
+ default="\t",
73
+ help="For csv-like datasets, which separator to use.",
74
+ )
75
+ parser.add_argument(
76
+ "--csv-img-key",
77
+ type=str,
78
+ default="filepath",
79
+ help="For csv-like datasets, the name of the key for the image paths.",
80
+ )
81
+ parser.add_argument(
82
+ "--csv-caption-key",
83
+ type=str,
84
+ default="title",
85
+ help="For csv-like datasets, the name of the key for the captions.",
86
+ )
87
+ parser.add_argument(
88
+ "--imagenet-val",
89
+ type=str,
90
+ default=None,
91
+ help="Path to imagenet val set for conducting zero shot evaluation.",
92
+ )
93
+ parser.add_argument(
94
+ "--imagenet-v2",
95
+ type=str,
96
+ default=None,
97
+ help="Path to imagenet v2 for conducting zero shot evaluation.",
98
+ )
99
+ parser.add_argument(
100
+ "--datasetnames",
101
+ nargs="+",
102
+ default=None,
103
+ help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
104
+ )
105
+ parser.add_argument(
106
+ "--full-train-dataset",
107
+ nargs="+",
108
+ default=None,
109
+ help="Which dataset will be trained with all the subsets. (train+test)",
110
+ )
111
+ parser.add_argument(
112
+ "--exclude-eval-dataset",
113
+ nargs="+",
114
+ default=None,
115
+ help="Which dataset will be excluded with evaluation",
116
+ )
117
+ parser.add_argument(
118
+ "--datasetinfos",
119
+ nargs="+",
120
+ default=None,
121
+ help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
122
+ )
123
+ parser.add_argument(
124
+ "--dataset-proportion",
125
+ type=float,
126
+ default=1.0,
127
+ help="How much proportion of dataset we want to train.",
128
+ )
129
+ parser.add_argument(
130
+ "--remotedata",
131
+ default=False,
132
+ action="store_true",
133
+ help="if the dataset is remote, set this flag",
134
+ )
135
+ parser.add_argument(
136
+ "--class-label-path",
137
+ type=str,
138
+ default=None,
139
+ help="The path of the class label pickle or csv.",
140
+ )
141
+ parser.add_argument(
142
+ "--datasetpath",
143
+ type=str,
144
+ default="/mnt/audio_clip/webdataset_tar",
145
+ help="The path to the dataset",
146
+ )
147
+ parser.add_argument(
148
+ "--logs",
149
+ type=str,
150
+ default="./logs/",
151
+ help="Where to store tensorboard logs. Use None to avoid storing logs.",
152
+ )
153
+ parser.add_argument(
154
+ "--log-local",
155
+ action="store_true",
156
+ default=False,
157
+ help="log files on local master, otherwise global master only.",
158
+ )
159
+ parser.add_argument(
160
+ "--name",
161
+ type=str,
162
+ default=None,
163
+ help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
164
+ )
165
+ parser.add_argument(
166
+ "--workers", type=int, default=1, help="Number of workers per GPU."
167
+ )
168
+ parser.add_argument(
169
+ "--batch-size", type=int, default=64, help="Batch size per GPU."
170
+ )
171
+ parser.add_argument(
172
+ "--epochs", type=int, default=32, help="Number of epochs to train for."
173
+ )
174
+ parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
175
+ parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
176
+ parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
177
+ parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
178
+ parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.")
179
+ parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
180
+
181
+ parser.add_argument(
182
+ "--split-opt",
183
+ action="store_true",
184
+ default=False,
185
+ help="Use this flag to skip the learning rate decay.",
186
+ )
187
+ parser.add_argument(
188
+ "--lr-pretrained", type=float, default=None, help="Learning rate for text."
189
+ )
190
+ parser.add_argument(
191
+ "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
192
+ )
193
+ parser.add_argument(
194
+ "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
195
+ )
196
+ parser.add_argument(
197
+ "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
198
+ )
199
+ parser.add_argument(
200
+ "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
201
+ )
202
+ parser.add_argument(
203
+ "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
204
+ )
205
+ parser.add_argument(
206
+ "--lr-new", type=float, default=None, help="Learning rate for audio."
207
+ )
208
+ parser.add_argument(
209
+ "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
210
+ )
211
+ parser.add_argument(
212
+ "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
213
+ )
214
+ parser.add_argument(
215
+ "--eps-new", type=float, default=None, help="Adam epsilon for audio."
216
+ )
217
+ parser.add_argument(
218
+ "--wd-new", type=float, default=0.2, help="Weight decay for audio."
219
+ )
220
+ parser.add_argument(
221
+ "--momentum-new", type=float, default=0.9, help="Momentum for audio."
222
+ )
223
+ parser.add_argument(
224
+ "--warmup", type=int, default=10000, help="Number of steps to warmup for."
225
+ )
226
+ parser.add_argument(
227
+ "--use-bn-sync",
228
+ default=False,
229
+ action="store_true",
230
+ help="Whether to use batch norm sync.",
231
+ )
232
+ parser.add_argument(
233
+ "--skip-scheduler",
234
+ action="store_true",
235
+ default=False,
236
+ help="Use this flag to skip the learning rate decay.",
237
+ )
238
+ parser.add_argument(
239
+ "--save-frequency", type=int, default=1, help="How often to save checkpoints."
240
+ )
241
+ parser.add_argument(
242
+ "--save-top-performance",
243
+ type=int,
244
+ default=0,
245
+ help="Save the top x performance weights if the value >0",
246
+ )
247
+ parser.add_argument(
248
+ "--save-most-recent",
249
+ action="store_true",
250
+ default=False,
251
+ help="Always save the most recent model trained to epoch_latest.pt.",
252
+ )
253
+ parser.add_argument(
254
+ "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
255
+ )
256
+ parser.add_argument(
257
+ "--val-frequency",
258
+ type=int,
259
+ default=1,
260
+ help="How often to run evaluation with val data.",
261
+ )
262
+ parser.add_argument(
263
+ "--resume",
264
+ default=None,
265
+ type=str,
266
+ help="path to latest checkpoint (default: none)",
267
+ )
268
+ parser.add_argument(
269
+ "--precision",
270
+ choices=["amp", "fp16", "fp32"],
271
+ default="amp",
272
+ help="Floating point precision.",
273
+ )
274
+ parser.add_argument(
275
+ "--amodel",
276
+ type=str,
277
+ default="RN50",
278
+ help="Name of the audio backbone to use.",
279
+ )
280
+ parser.add_argument(
281
+ "--tmodel",
282
+ type=str,
283
+ default="transformer",
284
+ help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
285
+ )
286
+ parser.add_argument(
287
+ "--pretrained-audio",
288
+ default="",
289
+ type=str,
290
+ help="Use a pretrained audio model weights for the audio encoder of CLAP",
291
+ )
292
+ parser.add_argument(
293
+ "--pretrained-text",
294
+ default="",
295
+ type=str,
296
+ help="Use a pretrained text model weights for the text encoder of CLAP",
297
+ )
298
+ parser.add_argument(
299
+ "--pretrained",
300
+ default="",
301
+ type=str,
302
+ help="Use a pretrained CLIP model weights with the specified tag or file path.",
303
+ )
304
+ parser.add_argument(
305
+ "--pretrained-image",
306
+ default=False,
307
+ action="store_true",
308
+ help="Load imagenet pretrained weights for image tower backbone if available.",
309
+ )
310
+ parser.add_argument(
311
+ "--lock-image",
312
+ default=False,
313
+ action="store_true",
314
+ help="Lock full image tower by disabling gradients.",
315
+ )
316
+ parser.add_argument(
317
+ "--lock-image-unlocked-groups",
318
+ type=int,
319
+ default=0,
320
+ help="Leave last n image tower layer groups unlocked.",
321
+ )
322
+ parser.add_argument(
323
+ "--lock-image-freeze-bn-stats",
324
+ default=False,
325
+ action="store_true",
326
+ help="Freeze BatchNorm running stats in image tower for any locked layers.",
327
+ )
328
+ parser.add_argument(
329
+ "--local-loss",
330
+ default=False,
331
+ action="store_true",
332
+ help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
333
+ )
334
+ parser.add_argument(
335
+ "--gather-with-grad",
336
+ default=False,
337
+ action="store_true",
338
+ help="enable full distributed gradient for feature gather",
339
+ )
340
+ parser.add_argument(
341
+ "--force-quick-gelu",
342
+ default=False,
343
+ action="store_true",
344
+ help="Force use of QuickGELU activation for non-OpenAI transformer models.",
345
+ )
346
+ parser.add_argument(
347
+ "--torchscript",
348
+ default=False,
349
+ action="store_true",
350
+ help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
351
+ )
352
+ parser.add_argument(
353
+ "--trace",
354
+ default=False,
355
+ action="store_true",
356
+ help="torch.jit.trace the model for inference / eval only",
357
+ )
358
+ # arguments for distributed training
359
+ parser.add_argument(
360
+ "--dist-url",
361
+ default="env://",
362
+ type=str,
363
+ help="url used to set up distributed training",
364
+ )
365
+ parser.add_argument(
366
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
367
+ )
368
+ parser.add_argument(
369
+ "--report-to",
370
+ default="",
371
+ type=str,
372
+ help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
373
+ )
374
+ parser.add_argument(
375
+ "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
376
+ )
377
+ parser.add_argument(
378
+ "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
379
+ )
380
+ parser.add_argument(
381
+ "--debug",
382
+ default=False,
383
+ action="store_true",
384
+ help="If true, more information is logged.",
385
+ )
386
+ parser.add_argument(
387
+ "--copy-codebase",
388
+ default=False,
389
+ action="store_true",
390
+ help="If true, we copy the entire base on the log diretory, and execute from there.",
391
+ )
392
+ parser.add_argument(
393
+ "--horovod",
394
+ default=False,
395
+ action="store_true",
396
+ help="Use horovod for distributed training.",
397
+ )
398
+ parser.add_argument(
399
+ "--ddp-static-graph",
400
+ default=False,
401
+ action="store_true",
402
+ help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
403
+ )
404
+ parser.add_argument(
405
+ "--no-set-device-rank",
406
+ default=False,
407
+ action="store_true",
408
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
409
+ )
410
+ parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")
411
+
412
+ parser.add_argument(
413
+ "--top-k-checkpoint-select-dataset",
414
+ type=str,
415
+ default="all",
416
+ help="The dataset of selecting top-k checkpoint.",
417
+ )
418
+
419
+ # @R10, @R@5, @R1, mAP@10
420
+ parser.add_argument(
421
+ "--top-k-checkpoint-select-metric",
422
+ type=str,
423
+ default="_R@10",
424
+ help="The metric for selecting top-k checkpoint.",
425
+ )
426
+ parser.add_argument(
427
+ "--openai-model-cache-dir",
428
+ type=str,
429
+ default="~/.cache/clip",
430
+ help="Directory to download OpenAI models.",
431
+ )
432
+ parser.add_argument(
433
+ "--optimizer",
434
+ type=str,
435
+ default="adamw",
436
+ help="can be AdamW or SGD",
437
+ )
438
+ parser.add_argument(
439
+ "--parallel-eval",
440
+ default=False,
441
+ action="store_true",
442
+ help="Eval in parallel (multi-GPU, multi-node).",
443
+ )
444
+
445
+ parser.add_argument(
446
+ "--no-eval",
447
+ default=False,
448
+ action="store_true",
449
+ help="Training without evaluation.",
450
+ )
451
+
452
+ parser.add_argument(
453
+ "--lp-mlp",
454
+ default=False,
455
+ action="store_true",
456
+ help="Linear Probe using MLP layer or not.",
457
+ )
458
+
459
+ parser.add_argument(
460
+ "--lp-freeze",
461
+ default=False,
462
+ action="store_true",
463
+ help="Linear Probe using Freeze CLAP or not",
464
+ )
465
+
466
+ parser.add_argument(
467
+ "--lp-act",
468
+ default="None",
469
+ type=str,
470
+ help="Options are ['relu','elu','prelu','softmax','sigmoid']",
471
+ )
472
+
473
+ parser.add_argument(
474
+ "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
475
+ )
476
+
477
+ parser.add_argument(
478
+ "--lp-metrics",
479
+ type=str,
480
+ default="map,mauc,acc",
481
+ help="Metrics of Linear Probe.",
482
+ )
483
+
484
+ parser.add_argument(
485
+ "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
486
+ )
487
+ parser.add_argument(
488
+ "--kappa",
489
+ type=float,
490
+ default=0,
491
+ help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss",
492
+ )
493
+
494
+ parser.add_argument(
495
+ "--data-filling",
496
+ type=str,
497
+ default="pad",
498
+ help="type of data filling when the audio length is shorter than the max length."
499
+ "Can be one of the following: repeat, repeatpad, pad",
500
+ )
501
+ parser.add_argument(
502
+ "--data-truncating",
503
+ type=str,
504
+ default="rand_trunc",
505
+ help="type of data truncation when the audio length is longer than the max length."
506
+ "Can be one of the following: rand_trunc, fusion",
507
+ )
508
+
509
+ parser.add_argument(
510
+ "--clap-mlploss",
511
+ default=False,
512
+ action="store_true",
513
+ help="Using MLP loss for CLAP model or not",
514
+ )
515
+
516
+ parser.add_argument(
517
+ "--wandb-id",
518
+ type=str,
519
+ default=None,
520
+ help="the id of wandb experiment to restore.",
521
+ )
522
+
523
+ parser.add_argument(
524
+ "--sleep", type=float, default=0, help="sleep n seconds before start training"
525
+ )
526
+
527
+ # variable length processing
528
+ parser.add_argument(
529
+ "--enable-fusion",
530
+ default=False,
531
+ action="store_true",
532
+ help="Enable feature funsion for variable-length data",
533
+ )
534
+
535
+ parser.add_argument(
536
+ "--fusion-type",
537
+ type=str,
538
+ default="None",
539
+ help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
540
+ )
541
+
542
+ parser.add_argument(
543
+ "--mixup",
544
+ default=False,
545
+ action="store_true",
546
+ help="Enable mixup in finetuning training.",
547
+ )
548
+ parser.add_argument(
549
+ "--text-augment-selection",
550
+ type=str,
551
+ default=None,
552
+ help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
553
+ )
554
+
555
+ args = parser.parse_args()
556
+
557
+ # If some params are not passed, we use the default values based on model name.
558
+ default_params = get_default_params(args.amodel)
559
+ for name, val in default_params.items():
560
+ if getattr(args, name) is None:
561
+ setattr(args, name, val)
562
+
563
+ return args