akhaliq3 committed on
Commit
607ecc1
β€’
1 Parent(s): 3e4e0d7

spaces demo

Files changed (45)
  1. LICENSE +373 -0
  2. app.py +141 -0
  3. checkpoints/nws/fl/data_mean.npy +0 -0
  4. checkpoints/nws/fl/data_std.npy +0 -0
  5. checkpoints/nws/fl/epoch=4992-step=119831.ckpt +0 -0
  6. checkpoints/nws/fl/last.ckpt +0 -0
  7. checkpoints/nws/tpt/data_mean.npy +0 -0
  8. checkpoints/nws/tpt/data_std.npy +0 -0
  9. checkpoints/nws/tpt/epoch=358-step=24052.ckpt +0 -0
  10. checkpoints/nws/tpt/last.ckpt +0 -0
  11. checkpoints/nws/vn/data_mean.npy +0 -0
  12. checkpoints/nws/vn/data_std.npy +0 -0
  13. checkpoints/nws/vn/epoch=526-step=102237.ckpt +0 -0
  14. checkpoints/nws/vn/last.ckpt +0 -0
  15. gin/data/urmp_4second_crepe.gin +27 -0
  16. gin/models/newt.gin +33 -0
  17. gin/train/train_newt.gin +13 -0
  18. neural_waveshaping_synthesis/__init__.py +0 -0
  19. neural_waveshaping_synthesis/data/__init__.py +0 -0
  20. neural_waveshaping_synthesis/data/general.py +97 -0
  21. neural_waveshaping_synthesis/data/urmp.py +23 -0
  22. neural_waveshaping_synthesis/data/utils/__init__.py +0 -0
  23. neural_waveshaping_synthesis/data/utils/create_dataset.py +166 -0
  24. neural_waveshaping_synthesis/data/utils/f0_extraction.py +92 -0
  25. neural_waveshaping_synthesis/data/utils/loudness_extraction.py +89 -0
  26. neural_waveshaping_synthesis/data/utils/mfcc_extraction.py +13 -0
  27. neural_waveshaping_synthesis/data/utils/preprocess_audio.py +237 -0
  28. neural_waveshaping_synthesis/data/utils/upsampling.py +79 -0
  29. neural_waveshaping_synthesis/models/__init__.py +0 -0
  30. neural_waveshaping_synthesis/models/modules/__init__.py +0 -0
  31. neural_waveshaping_synthesis/models/modules/dynamic.py +40 -0
  32. neural_waveshaping_synthesis/models/modules/generators.py +66 -0
  33. neural_waveshaping_synthesis/models/modules/shaping.py +173 -0
  34. neural_waveshaping_synthesis/models/neural_waveshaping.py +165 -0
  35. neural_waveshaping_synthesis/utils/__init__.py +2 -0
  36. neural_waveshaping_synthesis/utils/seed_all.py +12 -0
  37. neural_waveshaping_synthesis/utils/utils.py +23 -0
  38. requirements.txt +13 -0
  39. scripts/create_dataset.py +31 -0
  40. scripts/create_urmp_dataset.py +58 -0
  41. scripts/resynthesise_dataset.py +80 -0
  42. scripts/time_buffer_sizes.py +79 -0
  43. scripts/time_forward_pass.py +62 -0
  44. scripts/train.py +81 -0
  45. setup.py +3 -0
LICENSE ADDED
@@ -0,0 +1,373 @@
1
+ Mozilla Public License Version 2.0
2
+ ==================================
3
+
4
+ 1. Definitions
5
+ --------------
6
+
7
+ 1.1. "Contributor"
8
+ means each individual or legal entity that creates, contributes to
9
+ the creation of, or owns Covered Software.
10
+
11
+ 1.2. "Contributor Version"
12
+ means the combination of the Contributions of others (if any) used
13
+ by a Contributor and that particular Contributor's Contribution.
14
+
15
+ 1.3. "Contribution"
16
+ means Covered Software of a particular Contributor.
17
+
18
+ 1.4. "Covered Software"
19
+ means Source Code Form to which the initial Contributor has attached
20
+ the notice in Exhibit A, the Executable Form of such Source Code
21
+ Form, and Modifications of such Source Code Form, in each case
22
+ including portions thereof.
23
+
24
+ 1.5. "Incompatible With Secondary Licenses"
25
+ means
26
+
27
+ (a) that the initial Contributor has attached the notice described
28
+ in Exhibit B to the Covered Software; or
29
+
30
+ (b) that the Covered Software was made available under the terms of
31
+ version 1.1 or earlier of the License, but not also under the
32
+ terms of a Secondary License.
33
+
34
+ 1.6. "Executable Form"
35
+ means any form of the work other than Source Code Form.
36
+
37
+ 1.7. "Larger Work"
38
+ means a work that combines Covered Software with other material, in
39
+ a separate file or files, that is not Covered Software.
40
+
41
+ 1.8. "License"
42
+ means this document.
43
+
44
+ 1.9. "Licensable"
45
+ means having the right to grant, to the maximum extent possible,
46
+ whether at the time of the initial grant or subsequently, any and
47
+ all of the rights conveyed by this License.
48
+
49
+ 1.10. "Modifications"
50
+ means any of the following:
51
+
52
+ (a) any file in Source Code Form that results from an addition to,
53
+ deletion from, or modification of the contents of Covered
54
+ Software; or
55
+
56
+ (b) any new file in Source Code Form that contains any Covered
57
+ Software.
58
+
59
+ 1.11. "Patent Claims" of a Contributor
60
+ means any patent claim(s), including without limitation, method,
61
+ process, and apparatus claims, in any patent Licensable by such
62
+ Contributor that would be infringed, but for the grant of the
63
+ License, by the making, using, selling, offering for sale, having
64
+ made, import, or transfer of either its Contributions or its
65
+ Contributor Version.
66
+
67
+ 1.12. "Secondary License"
68
+ means either the GNU General Public License, Version 2.0, the GNU
69
+ Lesser General Public License, Version 2.1, the GNU Affero General
70
+ Public License, Version 3.0, or any later versions of those
71
+ licenses.
72
+
73
+ 1.13. "Source Code Form"
74
+ means the form of the work preferred for making modifications.
75
+
76
+ 1.14. "You" (or "Your")
77
+ means an individual or a legal entity exercising rights under this
78
+ License. For legal entities, "You" includes any entity that
79
+ controls, is controlled by, or is under common control with You. For
80
+ purposes of this definition, "control" means (a) the power, direct
81
+ or indirect, to cause the direction or management of such entity,
82
+ whether by contract or otherwise, or (b) ownership of more than
83
+ fifty percent (50%) of the outstanding shares or beneficial
84
+ ownership of such entity.
85
+
86
+ 2. License Grants and Conditions
87
+ --------------------------------
88
+
89
+ 2.1. Grants
90
+
91
+ Each Contributor hereby grants You a world-wide, royalty-free,
92
+ non-exclusive license:
93
+
94
+ (a) under intellectual property rights (other than patent or trademark)
95
+ Licensable by such Contributor to use, reproduce, make available,
96
+ modify, display, perform, distribute, and otherwise exploit its
97
+ Contributions, either on an unmodified basis, with Modifications, or
98
+ as part of a Larger Work; and
99
+
100
+ (b) under Patent Claims of such Contributor to make, use, sell, offer
101
+ for sale, have made, import, and otherwise transfer either its
102
+ Contributions or its Contributor Version.
103
+
104
+ 2.2. Effective Date
105
+
106
+ The licenses granted in Section 2.1 with respect to any Contribution
107
+ become effective for each Contribution on the date the Contributor first
108
+ distributes such Contribution.
109
+
110
+ 2.3. Limitations on Grant Scope
111
+
112
+ The licenses granted in this Section 2 are the only rights granted under
113
+ this License. No additional rights or licenses will be implied from the
114
+ distribution or licensing of Covered Software under this License.
115
+ Notwithstanding Section 2.1(b) above, no patent license is granted by a
116
+ Contributor:
117
+
118
+ (a) for any code that a Contributor has removed from Covered Software;
119
+ or
120
+
121
+ (b) for infringements caused by: (i) Your and any other third party's
122
+ modifications of Covered Software, or (ii) the combination of its
123
+ Contributions with other software (except as part of its Contributor
124
+ Version); or
125
+
126
+ (c) under Patent Claims infringed by Covered Software in the absence of
127
+ its Contributions.
128
+
129
+ This License does not grant any rights in the trademarks, service marks,
130
+ or logos of any Contributor (except as may be necessary to comply with
131
+ the notice requirements in Section 3.4).
132
+
133
+ 2.4. Subsequent Licenses
134
+
135
+ No Contributor makes additional grants as a result of Your choice to
136
+ distribute the Covered Software under a subsequent version of this
137
+ License (see Section 10.2) or under the terms of a Secondary License (if
138
+ permitted under the terms of Section 3.3).
139
+
140
+ 2.5. Representation
141
+
142
+ Each Contributor represents that the Contributor believes its
143
+ Contributions are its original creation(s) or it has sufficient rights
144
+ to grant the rights to its Contributions conveyed by this License.
145
+
146
+ 2.6. Fair Use
147
+
148
+ This License is not intended to limit any rights You have under
149
+ applicable copyright doctrines of fair use, fair dealing, or other
150
+ equivalents.
151
+
152
+ 2.7. Conditions
153
+
154
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
155
+ in Section 2.1.
156
+
157
+ 3. Responsibilities
158
+ -------------------
159
+
160
+ 3.1. Distribution of Source Form
161
+
162
+ All distribution of Covered Software in Source Code Form, including any
163
+ Modifications that You create or to which You contribute, must be under
164
+ the terms of this License. You must inform recipients that the Source
165
+ Code Form of the Covered Software is governed by the terms of this
166
+ License, and how they can obtain a copy of this License. You may not
167
+ attempt to alter or restrict the recipients' rights in the Source Code
168
+ Form.
169
+
170
+ 3.2. Distribution of Executable Form
171
+
172
+ If You distribute Covered Software in Executable Form then:
173
+
174
+ (a) such Covered Software must also be made available in Source Code
175
+ Form, as described in Section 3.1, and You must inform recipients of
176
+ the Executable Form how they can obtain a copy of such Source Code
177
+ Form by reasonable means in a timely manner, at a charge no more
178
+ than the cost of distribution to the recipient; and
179
+
180
+ (b) You may distribute such Executable Form under the terms of this
181
+ License, or sublicense it under different terms, provided that the
182
+ license for the Executable Form does not attempt to limit or alter
183
+ the recipients' rights in the Source Code Form under this License.
184
+
185
+ 3.3. Distribution of a Larger Work
186
+
187
+ You may create and distribute a Larger Work under terms of Your choice,
188
+ provided that You also comply with the requirements of this License for
189
+ the Covered Software. If the Larger Work is a combination of Covered
190
+ Software with a work governed by one or more Secondary Licenses, and the
191
+ Covered Software is not Incompatible With Secondary Licenses, this
192
+ License permits You to additionally distribute such Covered Software
193
+ under the terms of such Secondary License(s), so that the recipient of
194
+ the Larger Work may, at their option, further distribute the Covered
195
+ Software under the terms of either this License or such Secondary
196
+ License(s).
197
+
198
+ 3.4. Notices
199
+
200
+ You may not remove or alter the substance of any license notices
201
+ (including copyright notices, patent notices, disclaimers of warranty,
202
+ or limitations of liability) contained within the Source Code Form of
203
+ the Covered Software, except that You may alter any license notices to
204
+ the extent required to remedy known factual inaccuracies.
205
+
206
+ 3.5. Application of Additional Terms
207
+
208
+ You may choose to offer, and to charge a fee for, warranty, support,
209
+ indemnity or liability obligations to one or more recipients of Covered
210
+ Software. However, You may do so only on Your own behalf, and not on
211
+ behalf of any Contributor. You must make it absolutely clear that any
212
+ such warranty, support, indemnity, or liability obligation is offered by
213
+ You alone, and You hereby agree to indemnify every Contributor for any
214
+ liability incurred by such Contributor as a result of warranty, support,
215
+ indemnity or liability terms You offer. You may include additional
216
+ disclaimers of warranty and limitations of liability specific to any
217
+ jurisdiction.
218
+
219
+ 4. Inability to Comply Due to Statute or Regulation
220
+ ---------------------------------------------------
221
+
222
+ If it is impossible for You to comply with any of the terms of this
223
+ License with respect to some or all of the Covered Software due to
224
+ statute, judicial order, or regulation then You must: (a) comply with
225
+ the terms of this License to the maximum extent possible; and (b)
226
+ describe the limitations and the code they affect. Such description must
227
+ be placed in a text file included with all distributions of the Covered
228
+ Software under this License. Except to the extent prohibited by statute
229
+ or regulation, such description must be sufficiently detailed for a
230
+ recipient of ordinary skill to be able to understand it.
231
+
232
+ 5. Termination
233
+ --------------
234
+
235
+ 5.1. The rights granted under this License will terminate automatically
236
+ if You fail to comply with any of its terms. However, if You become
237
+ compliant, then the rights granted under this License from a particular
238
+ Contributor are reinstated (a) provisionally, unless and until such
239
+ Contributor explicitly and finally terminates Your grants, and (b) on an
240
+ ongoing basis, if such Contributor fails to notify You of the
241
+ non-compliance by some reasonable means prior to 60 days after You have
242
+ come back into compliance. Moreover, Your grants from a particular
243
+ Contributor are reinstated on an ongoing basis if such Contributor
244
+ notifies You of the non-compliance by some reasonable means, this is the
245
+ first time You have received notice of non-compliance with this License
246
+ from such Contributor, and You become compliant prior to 30 days after
247
+ Your receipt of the notice.
248
+
249
+ 5.2. If You initiate litigation against any entity by asserting a patent
250
+ infringement claim (excluding declaratory judgment actions,
251
+ counter-claims, and cross-claims) alleging that a Contributor Version
252
+ directly or indirectly infringes any patent, then the rights granted to
253
+ You by any and all Contributors for the Covered Software under Section
254
+ 2.1 of this License shall terminate.
255
+
256
+ 5.3. In the event of termination under Sections 5.1 or 5.2 above, all
257
+ end user license agreements (excluding distributors and resellers) which
258
+ have been validly granted by You or Your distributors under this License
259
+ prior to termination shall survive termination.
260
+
261
+ ************************************************************************
262
+ * *
263
+ * 6. Disclaimer of Warranty *
264
+ * ------------------------- *
265
+ * *
266
+ * Covered Software is provided under this License on an "as is" *
267
+ * basis, without warranty of any kind, either expressed, implied, or *
268
+ * statutory, including, without limitation, warranties that the *
269
+ * Covered Software is free of defects, merchantable, fit for a *
270
+ * particular purpose or non-infringing. The entire risk as to the *
271
+ * quality and performance of the Covered Software is with You. *
272
+ * Should any Covered Software prove defective in any respect, You *
273
+ * (not any Contributor) assume the cost of any necessary servicing, *
274
+ * repair, or correction. This disclaimer of warranty constitutes an *
275
+ * essential part of this License. No use of any Covered Software is *
276
+ * authorized under this License except under this disclaimer. *
277
+ * *
278
+ ************************************************************************
279
+
280
+ ************************************************************************
281
+ * *
282
+ * 7. Limitation of Liability *
283
+ * -------------------------- *
284
+ * *
285
+ * Under no circumstances and under no legal theory, whether tort *
286
+ * (including negligence), contract, or otherwise, shall any *
287
+ * Contributor, or anyone who distributes Covered Software as *
288
+ * permitted above, be liable to You for any direct, indirect, *
289
+ * special, incidental, or consequential damages of any character *
290
+ * including, without limitation, damages for lost profits, loss of *
291
+ * goodwill, work stoppage, computer failure or malfunction, or any *
292
+ * and all other commercial damages or losses, even if such party *
293
+ * shall have been informed of the possibility of such damages. This *
294
+ * limitation of liability shall not apply to liability for death or *
295
+ * personal injury resulting from such party's negligence to the *
296
+ * extent applicable law prohibits such limitation. Some *
297
+ * jurisdictions do not allow the exclusion or limitation of *
298
+ * incidental or consequential damages, so this exclusion and *
299
+ * limitation may not apply to You. *
300
+ * *
301
+ ************************************************************************
302
+
303
+ 8. Litigation
304
+ -------------
305
+
306
+ Any litigation relating to this License may be brought only in the
307
+ courts of a jurisdiction where the defendant maintains its principal
308
+ place of business and such litigation shall be governed by laws of that
309
+ jurisdiction, without reference to its conflict-of-law provisions.
310
+ Nothing in this Section shall prevent a party's ability to bring
311
+ cross-claims or counter-claims.
312
+
313
+ 9. Miscellaneous
314
+ ----------------
315
+
316
+ This License represents the complete agreement concerning the subject
317
+ matter hereof. If any provision of this License is held to be
318
+ unenforceable, such provision shall be reformed only to the extent
319
+ necessary to make it enforceable. Any law or regulation which provides
320
+ that the language of a contract shall be construed against the drafter
321
+ shall not be used to construe this License against a Contributor.
322
+
323
+ 10. Versions of the License
324
+ ---------------------------
325
+
326
+ 10.1. New Versions
327
+
328
+ Mozilla Foundation is the license steward. Except as provided in Section
329
+ 10.3, no one other than the license steward has the right to modify or
330
+ publish new versions of this License. Each version will be given a
331
+ distinguishing version number.
332
+
333
+ 10.2. Effect of New Versions
334
+
335
+ You may distribute the Covered Software under the terms of the version
336
+ of the License under which You originally received the Covered Software,
337
+ or under the terms of any subsequent version published by the license
338
+ steward.
339
+
340
+ 10.3. Modified Versions
341
+
342
+ If you create software not governed by this License, and you want to
343
+ create a new license for such software, you may create and use a
344
+ modified version of this License if you rename the license and remove
345
+ any references to the name of the license steward (except to note that
346
+ such modified license differs from this License).
347
+
348
+ 10.4. Distributing Source Code Form that is Incompatible With Secondary
349
+ Licenses
350
+
351
+ If You choose to distribute Source Code Form that is Incompatible With
352
+ Secondary Licenses under the terms of this version of the License, the
353
+ notice described in Exhibit B of this License must be attached.
354
+
355
+ Exhibit A - Source Code Form License Notice
356
+ -------------------------------------------
357
+
358
+ This Source Code Form is subject to the terms of the Mozilla Public
359
+ License, v. 2.0. If a copy of the MPL was not distributed with this
360
+ file, You can obtain one at http://mozilla.org/MPL/2.0/.
361
+
362
+ If it is not possible or desirable to put the notice in a particular
363
+ file, then You may include the notice in a location (such as a LICENSE
364
+ file in a relevant directory) where a recipient would be likely to look
365
+ for such a notice.
366
+
367
+ You may add additional accurate notices of copyright ownership.
368
+
369
+ Exhibit B - "Incompatible With Secondary Licenses" Notice
370
+ ---------------------------------------------------------
371
+
372
+ This Source Code Form is "Incompatible With Secondary Licenses", as
373
+ defined by the Mozilla Public License, v. 2.0.
app.py ADDED
@@ -0,0 +1,141 @@
+ import os
+ import time
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ import gin
+ import gradio as gr
+ import numpy as np
+ from scipy.io import wavfile
+ from scipy.io.wavfile import write
+ import torch
+
+ from neural_waveshaping_synthesis.data.utils.loudness_extraction import extract_perceptual_loudness
+ from neural_waveshaping_synthesis.data.utils.mfcc_extraction import extract_mfcc
+ from neural_waveshaping_synthesis.data.utils.f0_extraction import extract_f0_with_crepe
+ from neural_waveshaping_synthesis.data.utils.preprocess_audio import preprocess_audio, convert_to_float32_audio, make_monophonic, resample_audio
+ from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT
+ from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
+
+ torch.hub.download_url_to_file('https://benhayes.net/assets/audio/nws_examples/tt/tt1_in.wav', 'test1.wav')
+ torch.hub.download_url_to_file('https://benhayes.net/assets/audio/nws_examples/tt/tt2_in.wav', 'test2.wav')
+ torch.hub.download_url_to_file('https://benhayes.net/assets/audio/nws_examples/tt/tt3_in.wav', 'test3.wav')
+
+ try:
+     gin.constant("device", "cuda" if torch.cuda.is_available() else "cpu")
+ except ValueError:
+     # the gin constant may already have been registered
+     pass
+
+ gin.parse_config_file("gin/models/newt.gin")
+ gin.parse_config_file("gin/data/urmp_4second_crepe.gin")
+
+ checkpoints = dict(Violin="vn", Flute="fl", Trumpet="tpt")
+
+ use_gpu = False
+ dev_string = "cuda" if use_gpu else "cpu"
+ device = torch.device(dev_string)
+
+
+ def inference(wav, instrument):
+     selected_checkpoint_name = instrument
+     selected_checkpoint = checkpoints[selected_checkpoint_name]
+
+     checkpoint_path = os.path.join("checkpoints/nws", selected_checkpoint)
+     model = NeuralWaveshaping.load_from_checkpoint(
+         os.path.join(checkpoint_path, "last.ckpt")).to(device)
+     original_newt = model.newt
+     model.eval()
+     data_mean = np.load(os.path.join(checkpoint_path, "data_mean.npy"))
+     data_std = np.load(os.path.join(checkpoint_path, "data_std.npy"))
+     rate, audio = wavfile.read(wav.name)
+     audio = convert_to_float32_audio(make_monophonic(audio))
+     audio = resample_audio(audio, rate, model.sample_rate)
+
+     use_full_crepe_model = False
+     with torch.no_grad():
+         f0, confidence = extract_f0_with_crepe(
+             audio,
+             full_model=use_full_crepe_model,
+             maximum_frequency=1000)
+         loudness = extract_perceptual_loudness(audio)
+
+     # control parameters for resynthesis
+     octave_shift = 1
+     loudness_scale = 0.5
+
+     loudness_floor = 0
+     loudness_conf_filter = 0
+     pitch_conf_filter = 0
+
+     pitch_smoothing = 0
+     loudness_smoothing = 0
+
+     with torch.no_grad():
+         f0_filtered = f0 * (confidence > pitch_conf_filter)
+         loudness_filtered = loudness * (confidence > loudness_conf_filter)
+         f0_shifted = f0_filtered * (2 ** octave_shift)
+         loudness_floored = loudness_filtered * (loudness_filtered > loudness_floor) - loudness_floor
+         loudness_scaled = loudness_floored * loudness_scale
+
+         loud_norm = (loudness_scaled - data_mean[1]) / data_std[1]
+
+         f0_t = torch.tensor(f0_shifted, device=device).float()
+         loud_norm_t = torch.tensor(loud_norm, device=device).float()
+
+         if pitch_smoothing != 0:
+             f0_t = torch.nn.functional.conv1d(
+                 f0_t.expand(1, 1, -1),
+                 torch.ones(1, 1, pitch_smoothing * 2 + 1, device=device) /
+                 (pitch_smoothing * 2 + 1),
+                 padding=pitch_smoothing
+             ).squeeze()
+         f0_norm_t = torch.tensor((f0_t.cpu() - data_mean[0]) / data_std[0], device=device).float()
+
+         if loudness_smoothing != 0:
+             loud_norm_t = torch.nn.functional.conv1d(
+                 loud_norm_t.expand(1, 1, -1),
+                 torch.ones(1, 1, loudness_smoothing * 2 + 1, device=device) /
+                 (loudness_smoothing * 2 + 1),
+                 padding=loudness_smoothing
+             ).squeeze()
+             f0_norm_t = torch.tensor((f0_t.cpu() - data_mean[0]) / data_std[0], device=device).float()
+
+         control = torch.stack((f0_norm_t, loud_norm_t), dim=0)
+
+     # swap in the lookup-table FastNEWT for faster inference
+     model.newt = FastNEWT(original_newt)
+
+     with torch.no_grad():
+         start_time = time.time()
+         out = model(f0_t.expand(1, 1, -1), control.unsqueeze(0))
+         run_time = time.time() - start_time
+
+     sample_rates = model.sample_rate
+     rtf = (audio.shape[-1] / model.sample_rate) / run_time  # real-time factor, kept for reference
+     write('test.wav', sample_rates, out.detach().cpu().numpy().T)
+     return 'test.wav'
+
+
+ inputs = [gr.inputs.Audio(label="input audio", type="file"),
+           gr.inputs.Dropdown(["Violin", "Flute", "Trumpet"], type="value", default="Violin", label="Instrument")]
+ outputs = gr.outputs.Audio(label="output audio", type="file")
+
+ title = "neural waveshaping synthesis"
+ description = "Demo for neural waveshaping synthesis: efficient neural audio synthesis in the waveform domain, applied here to timbre transfer. To use it, simply upload your audio or click one of the examples to load it. Input audio should be in WAV format, similar to the example audio below. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.05050'>neural waveshaping synthesis</a> | <a href='https://github.com/ben-hayes/neural-waveshaping-synthesis'>Github Repo</a></p>"
+
+ examples = [
+     ['test1.wav'],
+     ['test2.wav'],
+     ['test3.wav']
+ ]
+
+ gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples).launch()
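For reference, a minimal sketch of driving the same pipeline without the Gradio UI, assuming the checkpoints directory above is present and the final launch() call is skipped; "input.wav" and "Violin" are placeholders.

    from types import SimpleNamespace
    # inference() reads wav.name, matching Gradio's file-type audio input
    out_path = inference(SimpleNamespace(name="input.wav"), "Violin")
    print("resynthesised audio written to", out_path)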
checkpoints/nws/fl/data_mean.npy ADDED
Binary file (204 Bytes).
checkpoints/nws/fl/data_std.npy ADDED
Binary file (280 Bytes).
checkpoints/nws/fl/epoch=4992-step=119831.ckpt ADDED
Binary file (3.26 MB).
checkpoints/nws/fl/last.ckpt ADDED
Binary file (3.26 MB).
checkpoints/nws/tpt/data_mean.npy ADDED
Binary file (204 Bytes).
checkpoints/nws/tpt/data_std.npy ADDED
Binary file (280 Bytes).
checkpoints/nws/tpt/epoch=358-step=24052.ckpt ADDED
Binary file (3.26 MB).
checkpoints/nws/tpt/last.ckpt ADDED
Binary file (3.26 MB).
checkpoints/nws/vn/data_mean.npy ADDED
Binary file (204 Bytes).
checkpoints/nws/vn/data_std.npy ADDED
Binary file (280 Bytes).
checkpoints/nws/vn/epoch=526-step=102237.ckpt ADDED
Binary file (3.26 MB).
checkpoints/nws/vn/last.ckpt ADDED
Binary file (3.26 MB).
gin/data/urmp_4second_crepe.gin ADDED
@@ -0,0 +1,27 @@
+ sample_rate = 16000
+ interpolation = None
+ control_hop = 128
+
+ extract_f0_with_crepe.sample_rate = %sample_rate
+ extract_f0_with_crepe.device = %device
+ extract_f0_with_crepe.full_model = True
+ extract_f0_with_crepe.interpolate_fn = %interpolation
+ extract_f0_with_crepe.hop_length = %control_hop
+
+ extract_perceptual_loudness.sample_rate = %sample_rate
+ extract_perceptual_loudness.interpolate_fn = %interpolation
+ extract_perceptual_loudness.n_fft = 1024
+ extract_perceptual_loudness.hop_length = %control_hop
+
+ extract_mfcc.sample_rate = %sample_rate
+ extract_mfcc.n_fft = 1024
+ extract_mfcc.hop_length = 128
+ extract_mfcc.n_mfcc = 16
+
+ preprocess_audio.target_sr = %sample_rate
+ preprocess_audio.f0_extractor = @extract_f0_with_crepe
+ preprocess_audio.loudness_extractor = @extract_perceptual_loudness
+ preprocess_audio.segment_length_in_seconds = 4
+ preprocess_audio.hop_length_in_seconds = 4
+ preprocess_audio.normalise_audio = True
+ preprocess_audio.control_decimation_factor = %control_hop
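A small sketch of how these bindings are consumed: once the `device` constant is registered and the file is parsed, the @gin.configurable extractors pick up their defaults without repeating keyword arguments. The random input audio is illustrative.

    import gin
    import numpy as np
    import torch
    from neural_waveshaping_synthesis.data.utils.f0_extraction import extract_f0_with_crepe

    gin.constant("device", "cuda" if torch.cuda.is_available() else "cpu")
    gin.parse_config_file("gin/data/urmp_4second_crepe.gin")

    audio = np.random.uniform(-1, 1, 16000).astype(np.float32)  # one second at 16 kHz
    f0, confidence = extract_f0_with_crepe(audio)  # sample_rate, hop_length, etc. come from the gin file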
gin/models/newt.gin ADDED
@@ -0,0 +1,33 @@
+ sample_rate = 16000
+
+ control_embedding_size = 128
+ n_waveshapers = 64
+ control_hop = 128
+
+ HarmonicOscillator.n_harmonics = 101
+ HarmonicOscillator.sample_rate = %sample_rate
+
+ NEWT.n_waveshapers = %n_waveshapers
+ NEWT.control_embedding_size = %control_embedding_size
+ NEWT.shaping_fn_size = 8
+ NEWT.out_channels = 1
+ TrainableNonlinearity.depth = 4
+
+ ControlModule.control_size = 2
+ ControlModule.hidden_size = 128
+ ControlModule.embedding_size = %control_embedding_size
+
+ noise_synth/TimeDistributedMLP.in_size = %control_embedding_size
+ noise_synth/TimeDistributedMLP.hidden_size = %control_embedding_size
+ noise_synth/TimeDistributedMLP.out_size = 129
+ noise_synth/TimeDistributedMLP.depth = 4
+ noise_synth/FIRNoiseSynth.ir_length = 256
+ noise_synth/FIRNoiseSynth.hop_length = %control_hop
+
+ Reverb.length_in_seconds = 2
+ Reverb.sr = %sample_rate
+
+
+ NeuralWaveshaping.n_waveshapers = %n_waveshapers
+ NeuralWaveshaping.control_hop = %control_hop
+ NeuralWaveshaping.sample_rate = %sample_rate
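The `noise_synth/` prefix above is a gin scope: those bindings only apply to objects constructed under that scope (the MLP and FIR filter driving the noise synthesiser). A rough sketch of the idea, assuming the configurables have been registered by importing the model package:

    import gin
    # importing the model registers the @gin.configurable classes referenced in this file
    from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
    from neural_waveshaping_synthesis.models.modules.dynamic import TimeDistributedMLP

    gin.parse_config_file("gin/models/newt.gin")

    with gin.config_scope("noise_synth"):
        mlp = TimeDistributedMLP()  # picks up in_size=128, out_size=129, depth=4 from the scoped bindings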
gin/train/train_newt.gin ADDED
@@ -0,0 +1,13 @@
+ get_model.model = @NeuralWaveshaping
+
+ include 'gin/models/newt.gin'
+
+ URMPDataModule.batch_size = 8
+
+ NeuralWaveshaping.learning_rate = 0.001
+ NeuralWaveshaping.lr_decay = 0.9
+ NeuralWaveshaping.lr_decay_interval = 10000
+
+ trainer_kwargs.max_steps = 120000
+ trainer_kwargs.gradient_clip_val = 2.0
+ trainer_kwargs.accelerator = 'dp'
neural_waveshaping_synthesis/__init__.py ADDED
File without changes
neural_waveshaping_synthesis/data/__init__.py ADDED
File without changes
neural_waveshaping_synthesis/data/general.py ADDED
@@ -0,0 +1,97 @@
+ import os
+
+ import gin
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+
+
+ class GeneralDataset(torch.utils.data.Dataset):
+     def __init__(self, path: str, split: str = "train", load_to_memory: bool = True):
+         super().__init__()
+         # split = "train"
+         self.load_to_memory = load_to_memory
+
+         self.split_path = os.path.join(path, split)
+         self.data_list = [
+             f.replace("audio_", "")
+             for f in os.listdir(os.path.join(self.split_path, "audio"))
+             if f[-4:] == ".npy"
+         ]
+         if load_to_memory:
+             self.audio = [
+                 np.load(os.path.join(self.split_path, "audio", "audio_%s" % name))
+                 for name in self.data_list
+             ]
+             self.control = [
+                 np.load(os.path.join(self.split_path, "control", "control_%s" % name))
+                 for name in self.data_list
+             ]
+
+         self.data_mean = np.load(os.path.join(path, "data_mean.npy"))
+         self.data_std = np.load(os.path.join(path, "data_std.npy"))
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, idx):
+         # idx = 10
+         name = self.data_list[idx]
+         if self.load_to_memory:
+             audio = self.audio[idx]
+             control = self.control[idx]
+         else:
+             audio_name = "audio_%s" % name
+             control_name = "control_%s" % name
+
+             audio = np.load(os.path.join(self.split_path, "audio", audio_name))
+             control = np.load(os.path.join(self.split_path, "control", control_name))
+         denormalised_control = (control * self.data_std) + self.data_mean
+
+         return {
+             "audio": audio,
+             "f0": denormalised_control[0:1, :],
+             "amp": denormalised_control[1:2, :],
+             "control": control,
+             "name": os.path.splitext(os.path.basename(name))[0],
+         }
+
+
+ @gin.configurable
+ class GeneralDataModule(pl.LightningDataModule):
+     def __init__(
+         self,
+         data_root: str,
+         batch_size: int = 16,
+         load_to_memory: bool = True,
+         **dataloader_args
+     ):
+         super().__init__()
+         self.data_dir = data_root
+         self.batch_size = batch_size
+         self.dataloader_args = dataloader_args
+         self.load_to_memory = load_to_memory
+
+     def prepare_data(self):
+         pass
+
+     def setup(self, stage: str = None):
+         if stage == "fit":
+             self.urmp_train = GeneralDataset(self.data_dir, "train", self.load_to_memory)
+             self.urmp_val = GeneralDataset(self.data_dir, "val", self.load_to_memory)
+         elif stage == "test" or stage is None:
+             self.urmp_test = GeneralDataset(self.data_dir, "test", self.load_to_memory)
+
+     def _make_dataloader(self, dataset):
+         return torch.utils.data.DataLoader(
+             dataset, self.batch_size, **self.dataloader_args
+         )
+
+     def train_dataloader(self):
+         return self._make_dataloader(self.urmp_train)
+
+     def val_dataloader(self):
+         return self._make_dataloader(self.urmp_val)
+
+     def test_dataloader(self):
+         return self._make_dataloader(self.urmp_test)
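A short usage sketch, assuming a dataset directory laid out as this module expects (train/val/test subfolders each holding audio/ and control/ .npy files, plus data_mean.npy and data_std.npy at the root, as produced by create_dataset below); the path is illustrative.

    from neural_waveshaping_synthesis.data.general import GeneralDataModule

    dm = GeneralDataModule("data/urmp/vn", batch_size=8, num_workers=2)
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    print(batch["audio"].shape, batch["control"].shape)  # roughly (batch, samples) and (batch, channels, frames)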
neural_waveshaping_synthesis/data/urmp.py ADDED
@@ -0,0 +1,23 @@
+ import os
+
+ import gin
+
+ from .general import GeneralDataModule
+
+
+ @gin.configurable
+ class URMPDataModule(GeneralDataModule):
+     def __init__(
+         self,
+         urmp_root: str,
+         instrument: str,
+         batch_size: int = 16,
+         load_to_memory: bool = True,
+         **dataloader_args
+     ):
+         super().__init__(
+             os.path.join(urmp_root, instrument),
+             batch_size,
+             load_to_memory,
+             **dataloader_args
+         )
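URMPDataModule simply nests the instrument folder under the URMP root; a sketch with illustrative paths:

    from neural_waveshaping_synthesis.data.urmp import URMPDataModule

    dm = URMPDataModule("data/urmp", "vn", batch_size=8, num_workers=2)
    dm.setup("fit")  # reads data/urmp/vn/train and data/urmp/vn/val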
neural_waveshaping_synthesis/data/utils/__init__.py ADDED
File without changes
neural_waveshaping_synthesis/data/utils/create_dataset.py ADDED
@@ -0,0 +1,166 @@
1
+ import os
2
+ import shutil
3
+ from typing import Sequence
4
+
5
+ import gin
6
+ import numpy as np
7
+ from sklearn.model_selection import train_test_split
8
+
9
+ from .preprocess_audio import preprocess_audio
10
+ from ...utils import seed_all
11
+
12
+
13
+ def create_directory(path):
14
+ if not os.path.isdir(path):
15
+ try:
16
+ os.mkdir(path)
17
+ except OSError:
18
+ print("Failed to create directory %s" % path)
19
+ else:
20
+ print("Created directory %s..." % path)
21
+ else:
22
+ print("Directory %s already exists. Skipping..." % path)
23
+
24
+
25
+ def create_directories(target_root, names):
26
+ create_directory(target_root)
27
+ for name in names:
28
+ create_directory(os.path.join(target_root, name))
29
+
30
+
31
+ def make_splits(
32
+ audio_list: Sequence[str],
33
+ control_list: Sequence[str],
34
+ splits: Sequence[str],
35
+ split_proportions: Sequence[float],
36
+ ):
37
+ assert len(splits) == len(
38
+ split_proportions
39
+ ), "Length of splits and split_proportions must be equal"
40
+
41
+ train_size = split_proportions[0] / np.sum(split_proportions)
42
+ audio_0, audio_1, control_0, control_1 = train_test_split(
43
+ audio_list, control_list, train_size=train_size
44
+ )
45
+ if len(splits) == 2:
46
+ return {
47
+ splits[0]: {
48
+ "audio": audio_0,
49
+ "control": control_0,
50
+ },
51
+ splits[1]: {
52
+ "audio": audio_1,
53
+ "control": control_1,
54
+ },
55
+ }
56
+ elif len(splits) > 2:
57
+ return {
58
+ splits[0]: {
59
+ "audio": audio_0,
60
+ "control": control_0,
61
+ },
62
+ **make_splits(audio_1, control_1, splits[1:], split_proportions[1:]),
63
+ }
64
+ elif len(splits) == 1:
65
+ return {
66
+ splits[0]: {
67
+ "audio": audio_list,
68
+ "control": control_list,
69
+ }
70
+ }
71
+
72
+
73
+ def lazy_create_dataset(
74
+ files: Sequence[str],
75
+ output_directory: str,
76
+ splits: Sequence[str],
77
+ split_proportions: Sequence[float],
78
+ ):
79
+ audio_files = []
80
+ control_files = []
81
+ audio_max = 1e-5
82
+ means = []
83
+ stds = []
84
+ lengths = []
85
+ control_mean = 0
86
+ control_std = 1
87
+
88
+ for i, (all_audio, all_f0, all_confidence, all_loudness, all_mfcc) in enumerate(
89
+ preprocess_audio(files)
90
+ ):
91
+ file = os.path.split(files[i])[-1].replace(".wav", "")
92
+ for j, (audio, f0, confidence, loudness, mfcc) in enumerate(
93
+ zip(all_audio, all_f0, all_confidence, all_loudness, all_mfcc)
94
+ ):
95
+ audio_file_name = "audio_%s_%d.npy" % (file, j)
96
+ control_file_name = "control_%s_%d.npy" % (file, j)
97
+
98
+ max_sample = np.abs(audio).max()
99
+ if max_sample > audio_max:
100
+ audio_max = max_sample
101
+
102
+ np.save(
103
+ os.path.join(output_directory, "temp", "audio", audio_file_name),
104
+ audio,
105
+ )
106
+ control = np.stack((f0, loudness, confidence), axis=0)
107
+ control = np.concatenate((control, mfcc), axis=0)
108
+ np.save(
109
+ os.path.join(output_directory, "temp", "control", control_file_name),
110
+ control,
111
+ )
112
+
113
+ audio_files.append(audio_file_name)
114
+ control_files.append(control_file_name)
115
+
116
+ means.append(control.mean(axis=-1))
117
+ stds.append(control.std(axis=-1))
118
+ lengths.append(control.shape[-1])
119
+
120
+ if len(audio_files) == 0:
121
+ print("No datapoints to split. Skipping...")
122
+ return
123
+
124
+ data_mean = np.mean(np.stack(means, axis=-1), axis=-1)[:, np.newaxis]
125
+ lengths = np.stack(lengths)[np.newaxis, :]
126
+ stds = np.stack(stds, axis=-1)
127
+ data_std = np.sqrt(np.sum(lengths * stds ** 2, axis=-1) / np.sum(lengths))[
128
+ :, np.newaxis
129
+ ]
130
+
131
+ print("Saving dataset stats...")
132
+ np.save(os.path.join(output_directory, "data_mean.npy"), data_mean)
133
+ np.save(os.path.join(output_directory, "data_std.npy"), data_std)
134
+
135
+ splits = make_splits(audio_files, control_files, splits, split_proportions)
136
+ for split in splits:
137
+ for audio_file in splits[split]["audio"]:
138
+ audio = np.load(os.path.join(output_directory, "temp", "audio", audio_file))
139
+ audio = audio / audio_max
140
+ np.save(os.path.join(output_directory, split, "audio", audio_file), audio)
141
+ for control_file in splits[split]["control"]:
142
+ control = np.load(
143
+ os.path.join(output_directory, "temp", "control", control_file)
144
+ )
145
+ control = (control - data_mean) / data_std
146
+ np.save(
147
+ os.path.join(output_directory, split, "control", control_file), control
148
+ )
149
+
150
+
151
+ @gin.configurable
152
+ def create_dataset(
153
+ files: Sequence[str],
154
+ output_directory: str,
155
+ splits: Sequence[str] = ("train", "val", "test"),
156
+ split_proportions: Sequence[float] = (0.8, 0.1, 0.1),
157
+ lazy: bool = True,
158
+ ):
159
+ create_directories(output_directory, (*splits, "temp"))
160
+ for split in (*splits, "temp"):
161
+ create_directories(os.path.join(output_directory, split), ("audio", "control"))
162
+
163
+ if lazy:
164
+ lazy_create_dataset(files, output_directory, splits, split_proportions)
165
+
166
+ shutil.rmtree(os.path.join(output_directory, "temp"))
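A sketch of building a dataset from a folder of WAV files with the gin config above; the paths are illustrative, and the `device` constant must be registered before parsing because the CREPE bindings reference it.

    import glob
    import gin
    import torch
    from neural_waveshaping_synthesis.data.utils.create_dataset import create_dataset

    gin.constant("device", "cuda" if torch.cuda.is_available() else "cpu")
    gin.parse_config_file("gin/data/urmp_4second_crepe.gin")

    files = sorted(glob.glob("my_recordings/*.wav"))
    create_dataset(files, "data/my_instrument")  # writes train/val/test splits plus data_mean.npy / data_std.npy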
neural_waveshaping_synthesis/data/utils/f0_extraction.py ADDED
@@ -0,0 +1,92 @@
+ from functools import partial
+ from typing import Callable, Optional, Sequence, Union
+
+ import gin
+ import librosa
+ import numpy as np
+ import torch
+ import torchcrepe
+
+ from .upsampling import linear_interpolation
+ from ...utils import apply
+
+
+ CREPE_WINDOW_LENGTH = 1024
+
+ @gin.configurable
+ def extract_f0_with_crepe(
+     audio: np.ndarray,
+     sample_rate: float,
+     hop_length: int = 128,
+     minimum_frequency: float = 50.0,
+     maximum_frequency: float = 2000.0,
+     full_model: bool = True,
+     batch_size: int = 2048,
+     device: Union[str, torch.device] = "cpu",
+     interpolate_fn: Optional[Callable] = linear_interpolation,
+ ):
+     # convert to torch tensor with channel dimension (necessary for CREPE)
+     audio = torch.tensor(audio).unsqueeze(0)
+     f0, confidence = torchcrepe.predict(
+         audio,
+         sample_rate,
+         hop_length,
+         minimum_frequency,
+         maximum_frequency,
+         "full" if full_model else "tiny",
+         batch_size=batch_size,
+         device=device,
+         decoder=torchcrepe.decode.viterbi,
+         # decoder=torchcrepe.decode.weighted_argmax,
+         return_harmonicity=True,
+     )
+
+     f0, confidence = f0.squeeze().numpy(), confidence.squeeze().numpy()
+
+     if interpolate_fn:
+         f0 = interpolate_fn(
+             f0, CREPE_WINDOW_LENGTH, hop_length, original_length=audio.shape[-1]
+         )
+         confidence = interpolate_fn(
+             confidence,
+             CREPE_WINDOW_LENGTH,
+             hop_length,
+             original_length=audio.shape[-1],
+         )
+
+     return f0, confidence
+
+
+ @gin.configurable
+ def extract_f0_with_pyin(
+     audio: np.ndarray,
+     sample_rate: float,
+     minimum_frequency: float = 65.0,  # recommended minimum freq from librosa docs
+     maximum_frequency: float = 2093.0,  # recommended maximum freq from librosa docs
+     frame_length: int = 1024,
+     hop_length: int = 128,
+     fill_na: Optional[float] = None,
+     interpolate_fn: Optional[Callable] = linear_interpolation,
+ ):
+     f0, _, voiced_prob = librosa.pyin(
+         audio,
+         sr=sample_rate,
+         fmin=minimum_frequency,
+         fmax=maximum_frequency,
+         frame_length=frame_length,
+         hop_length=hop_length,
+         fill_na=fill_na,
+     )
+
+     if interpolate_fn:
+         f0 = interpolate_fn(
+             f0, frame_length, hop_length, original_length=audio.shape[-1]
+         )
+         voiced_prob = interpolate_fn(
+             voiced_prob,
+             frame_length,
+             hop_length,
+             original_length=audio.shape[-1],
+         )
+
+     return f0, voiced_prob
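A quick sketch of CREPE-based f0 extraction on a synthetic tone, without gin (the tiny model keeps the download small); the tone frequency and duration are illustrative.

    import numpy as np
    from neural_waveshaping_synthesis.data.utils.f0_extraction import extract_f0_with_crepe

    sr = 16000
    t = np.arange(sr * 2) / sr
    audio = 0.5 * np.sin(2 * np.pi * 220.0 * t).astype(np.float32)  # 2 s of A3

    # with the default linear_interpolation, f0 and confidence are upsampled back to audio length
    f0, confidence = extract_f0_with_crepe(audio, sample_rate=sr, full_model=False)
    print(f0.shape, float(np.median(f0)))  # ~32000 samples, median near 220 Hz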
neural_waveshaping_synthesis/data/utils/loudness_extraction.py ADDED
@@ -0,0 +1,89 @@
+ from typing import Callable, Optional
+ import warnings
+
+ import gin
+ import librosa
+ import numpy as np
+
+ from .upsampling import linear_interpolation
+
+
+ def compute_power_spectrogram(
+     audio: np.ndarray,
+     n_fft: int,
+     hop_length: int,
+     window: str,
+     epsilon: float,
+ ):
+     spectrogram = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, window=window)
+     magnitude_spectrogram = np.abs(spectrogram)
+     power_spectrogram = librosa.amplitude_to_db(
+         magnitude_spectrogram, ref=np.max, amin=epsilon
+     )
+     return power_spectrogram
+
+
+ def perform_perceptual_weighting(
+     power_spectrogram_in_db: np.ndarray, sample_rate: float, n_fft: int
+ ):
+     centre_frequencies = librosa.fft_frequencies(sample_rate, n_fft)
+
+     # We know that we will get a log(0) warning here due to the DC component -- we can
+     # safely ignore as it is clipped to the default min dB value of -80.0 dB
+     with warnings.catch_warnings():
+         warnings.simplefilter("ignore")
+         weights = librosa.A_weighting(centre_frequencies)
+
+     weights = np.expand_dims(weights, axis=1)
+     weighted_spectrogram = power_spectrogram_in_db  # + weights
+     return weighted_spectrogram
+
+
+ @gin.configurable
+ def extract_perceptual_loudness(
+     audio: np.ndarray,
+     sample_rate: float = 16000,
+     n_fft: int = 2048,
+     hop_length: int = 512,
+     window: str = "hann",
+     epsilon: float = 1e-5,
+     interpolate_fn: Optional[Callable] = linear_interpolation,
+     normalise: bool = True,
+ ):
+     power_spectrogram = compute_power_spectrogram(
+         audio, n_fft=n_fft, hop_length=hop_length, window=window, epsilon=epsilon
+     )
+     perceptually_weighted_spectrogram = perform_perceptual_weighting(
+         power_spectrogram, sample_rate=sample_rate, n_fft=n_fft
+     )
+     loudness = np.mean(perceptually_weighted_spectrogram, axis=0)
+     if interpolate_fn:
+         loudness = interpolate_fn(
+             loudness, n_fft, hop_length, original_length=audio.size
+         )
+
+     if normalise:
+         loudness = (loudness + 80) / 80
+
+     return loudness
+
+
+ @gin.configurable
+ def extract_rms(
+     audio: np.ndarray,
+     window_size: int = 2048,
+     hop_length: int = 512,
+     sample_rate: Optional[float] = 16000.0,
+     interpolate_fn: Optional[Callable] = linear_interpolation,
+ ):
+     # pad audio to centre frames
+     padded_audio = np.pad(audio, (window_size // 2, window_size // 2))
+     frames = librosa.util.frame(padded_audio, window_size, hop_length)
+     squared = frames ** 2
+     mean = np.mean(squared, axis=0)
+     root = np.sqrt(mean)
+     if interpolate_fn:
+         assert sample_rate is not None, "Must provide sample rate if upsampling"
+         root = interpolate_fn(root, window_size, hop_length, original_length=audio.size)
+
+     return root
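The per-frame loudness in dB is mapped into roughly [0, 1] by the `(loudness + 80) / 80` normalisation. A small sketch on random noise, with values illustrative:

    import numpy as np
    from neural_waveshaping_synthesis.data.utils.loudness_extraction import extract_perceptual_loudness

    sr = 16000
    audio = 0.1 * np.random.randn(sr).astype(np.float32)  # one second of noise
    loudness = extract_perceptual_loudness(audio, sample_rate=sr)
    print(loudness.shape, loudness.min(), loudness.max())  # same length as the audio, values roughly in [0, 1]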
neural_waveshaping_synthesis/data/utils/mfcc_extraction.py ADDED
@@ -0,0 +1,13 @@
+ import gin
+ import librosa
+ import numpy as np
+
+
+ @gin.configurable
+ def extract_mfcc(
+     audio: np.ndarray, sample_rate: float, n_fft: int, hop_length: int, n_mfcc: int
+ ):
+     mfcc = librosa.feature.mfcc(
+         audio, sr=sample_rate, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length
+     )
+     return mfcc
neural_waveshaping_synthesis/data/utils/preprocess_audio.py ADDED
@@ -0,0 +1,237 @@
1
+ from functools import partial
2
+ from typing import Callable, Sequence, Union
3
+
4
+ import gin
5
+ import librosa
6
+ import numpy as np
7
+ import resampy
8
+ import scipy.io.wavfile as wavfile
9
+
10
+ from .f0_extraction import extract_f0_with_crepe, extract_f0_with_pyin
11
+ from .loudness_extraction import extract_perceptual_loudness, extract_rms
12
+ from .mfcc_extraction import extract_mfcc
13
+ from ...utils import apply, apply_unpack, unzip
14
+
15
+
16
+ def read_audio_files(files: list):
17
+ rates_and_audios = apply(wavfile.read, files)
18
+ return unzip(rates_and_audios)
19
+
20
+
21
+ def convert_to_float32_audio(audio: np.ndarray):
22
+ if audio.dtype == np.float32:
23
+ return audio
24
+
25
+ max_sample_value = np.iinfo(audio.dtype).max
26
+ floating_point_audio = audio / max_sample_value
27
+ return floating_point_audio.astype(np.float32)
28
+
29
+
30
+ def make_monophonic(audio: np.ndarray, strategy: str = "keep_left"):
31
+ # deal with non stereo array formats
32
+ if len(audio.shape) == 1:
33
+ return audio
34
+ elif len(audio.shape) != 2:
35
+ raise ValueError("Unknown audio array format.")
36
+
37
+ # deal with single audio channel
38
+ if audio.shape[0] == 1:
39
+ return audio[0]
40
+ elif audio.shape[1] == 1:
41
+ return audio[:, 0]
42
+ # deal with more than two channels
43
+ elif audio.shape[0] != 2 and audio.shape[1] != 2:
44
+ raise ValueError("Expected stereo input audio but got too many channels.")
45
+
46
+ # put channel first
47
+ if audio.shape[1] == 2:
48
+ audio = audio.T
49
+
50
+ # make stereo audio monophonic
51
+ if strategy == "keep_left":
52
+ return audio[0]
53
+ elif strategy == "keep_right":
54
+ return audio[1]
55
+ elif strategy == "sum":
56
+ return np.mean(audio, axis=0)
57
+ elif strategy == "diff":
58
+ return audio[0] - audio[1]
59
+
60
+
61
+ def normalise_signal(audio: np.ndarray, factor: float):
62
+ return audio / factor
63
+
64
+
65
+ def resample_audio(audio: np.ndarray, original_sr: float, target_sr: float):
66
+ return resampy.resample(audio, original_sr, target_sr)
67
+
68
+
69
+ def segment_signal(
70
+ signal: np.ndarray,
71
+ sample_rate: float,
72
+ segment_length_in_seconds: float,
73
+ hop_length_in_seconds: float,
74
+ ):
75
+ segment_length_in_samples = int(sample_rate * segment_length_in_seconds)
76
+ hop_length_in_samples = int(sample_rate * hop_length_in_seconds)
77
+ segments = librosa.util.frame(
78
+ signal, segment_length_in_samples, hop_length_in_samples
79
+ )
80
+ return segments
81
+
82
+
83
+ def filter_segments(
84
+ threshold: float,
85
+ key_segments: np.ndarray,
86
+ segments: Sequence[np.ndarray],
87
+ ):
88
+ mean_keys = key_segments.mean(axis=0)
89
+ mask = mean_keys > threshold
90
+ filtered_segments = apply(
91
+ lambda x: x[:, mask] if len(x.shape) == 2 else x[:, :, mask], segments
92
+ )
93
+ return filtered_segments
94
+
95
+
96
+ def preprocess_single_audio_file(
97
+ file: str,
98
+ control_decimation_factor: float,
99
+ target_sr: float = 16000.0,
100
+ segment_length_in_seconds: float = 4.0,
101
+ hop_length_in_seconds: float = 2.0,
102
+ confidence_threshold: float = 0.85,
103
+ f0_extractor: Callable = extract_f0_with_crepe,
104
+ loudness_extractor: Callable = extract_perceptual_loudness,
105
+ mfcc_extractor: Callable = extract_mfcc,
106
+ normalisation_factor: Union[float, None] = None,
107
+ ):
108
+ print("Loading audio file: %s..." % file)
109
+ original_sr, audio = wavfile.read(file)
110
+ audio = convert_to_float32_audio(audio)
111
+ audio = make_monophonic(audio)
112
+
113
+ if normalisation_factor:
114
+ audio = normalise_signal(audio, normalisation_factor)
115
+
116
+ print("Resampling audio file: %s..." % file)
117
+ audio = resample_audio(audio, original_sr, target_sr)
118
+
119
+ print("Extracting f0 with extractor '%s': %s..." % (f0_extractor.__name__, file))
120
+ f0, confidence = f0_extractor(audio)
121
+
122
+ print(
123
+ "Extracting loudness with extractor '%s': %s..."
124
+ % (loudness_extractor.__name__, file)
125
+ )
126
+ loudness = loudness_extractor(audio)
127
+
128
+ print(
129
+ "Extracting MFCC with extractor '%s': %s..." % (mfcc_extractor.__name__, file)
130
+ )
131
+ mfcc = mfcc_extractor(audio)
132
+
133
+ print("Segmenting audio file: %s..." % file)
134
+ segmented_audio = segment_signal(
135
+ audio, target_sr, segment_length_in_seconds, hop_length_in_seconds
136
+ )
137
+
138
+ print("Segmenting control signals: %s..." % file)
139
+ segmented_f0 = segment_signal(
140
+ f0,
141
+ target_sr / (control_decimation_factor or 1),
142
+ segment_length_in_seconds,
143
+ hop_length_in_seconds,
144
+ )
145
+ segmented_confidence = segment_signal(
146
+ confidence,
147
+ target_sr / (control_decimation_factor or 1),
148
+ segment_length_in_seconds,
149
+ hop_length_in_seconds,
150
+ )
151
+ segmented_loudness = segment_signal(
152
+ loudness,
153
+ target_sr / (control_decimation_factor or 1),
154
+ segment_length_in_seconds,
155
+ hop_length_in_seconds,
156
+ )
157
+ segmented_mfcc = segment_signal(
158
+ mfcc,
159
+ target_sr / (control_decimation_factor or 1),
160
+ segment_length_in_seconds,
161
+ hop_length_in_seconds,
162
+ )
163
+
164
+ (
165
+ filtered_audio,
166
+ filtered_f0,
167
+ filtered_confidence,
168
+ filtered_loudness,
169
+ filtered_mfcc,
170
+ ) = filter_segments(
171
+ confidence_threshold,
172
+ segmented_confidence,
173
+ (
174
+ segmented_audio,
175
+ segmented_f0,
176
+ segmented_confidence,
177
+ segmented_loudness,
178
+ segmented_mfcc,
179
+ ),
180
+ )
181
+
182
+ if filtered_audio.shape[-1] == 0:
183
+ print("No segments exceeding confidence threshold...")
184
+ audio_split, f0_split, confidence_split, loudness_split, mfcc_split = (
185
+ [],
186
+ [],
187
+ [],
188
+ [],
189
+ [],
190
+ )
191
+ else:
192
+ split = lambda x: [e.squeeze() for e in np.split(x, x.shape[-1], -1)]
193
+ audio_split = split(filtered_audio)
194
+ f0_split = split(filtered_f0)
195
+ confidence_split = split(filtered_confidence)
196
+ loudness_split = split(filtered_loudness)
197
+ mfcc_split = split(filtered_mfcc)
198
+
199
+ return audio_split, f0_split, confidence_split, loudness_split, mfcc_split
200
+
201
+
202
+ @gin.configurable
203
+ def preprocess_audio(
204
+ files: list,
205
+ control_decimation_factor: float,
206
+ target_sr: float = 16000,
207
+ segment_length_in_seconds: float = 4.0,
208
+ hop_length_in_seconds: float = 2.0,
209
+ confidence_threshold: float = 0.85,
210
+ f0_extractor: Callable = extract_f0_with_crepe,
211
+ loudness_extractor: Callable = extract_perceptual_loudness,
212
+ normalise_audio: bool = False,
213
+ ):
214
+ if normalise_audio:
215
+ print("Finding normalisation factor...")
216
+ normalisation_factor = 0
217
+ for file in files:
218
+ _, audio = wavfile.read(file)
219
+ audio = convert_to_float32_audio(audio)
220
+ audio = make_monophonic(audio)
221
+ max_value = np.abs(audio).max()
222
+ normalisation_factor = (
223
+ max_value if max_value > normalisation_factor else normalisation_factor
224
+ )
225
+
226
+ processor = partial(
227
+ preprocess_single_audio_file,
228
+ control_decimation_factor=control_decimation_factor,
229
+ target_sr=target_sr,
230
+ segment_length_in_seconds=segment_length_in_seconds,
231
+ hop_length_in_seconds=hop_length_in_seconds,
232
+ f0_extractor=f0_extractor,
233
+ loudness_extractor=loudness_extractor,
234
+ normalisation_factor=None if not normalise_audio else normalisation_factor,
235
+ )
236
+ for file in files:
237
+ yield processor(file)
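`preprocess_audio` is a generator: it yields one tuple of segment lists per input file, which is what `create_dataset` above iterates over. A sketch with illustrative paths, assuming gin/data/urmp_4second_crepe.gin has been parsed so `control_decimation_factor` and the extractors are bound:

    from neural_waveshaping_synthesis.data.utils.preprocess_audio import preprocess_audio

    for audio_segs, f0_segs, conf_segs, loud_segs, mfcc_segs in preprocess_audio(["my_recordings/take1.wav"]):
        print(len(audio_segs), "four-second segments kept after the confidence filter")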
neural_waveshaping_synthesis/data/utils/upsampling.py ADDED
@@ -0,0 +1,79 @@
+ from typing import Optional
+
+ import gin
+ import numpy as np
+ import scipy.interpolate
+ import scipy.signal.windows
+
+
+ def get_padded_length(frames: int, window_length: int, hop_length: int):
+     return frames * hop_length + window_length - hop_length
+
+
+ def get_source_target_axes(frames: int, window_length: int, hop_length: int):
+     padded_length = get_padded_length(frames, window_length, hop_length)
+     source_x = np.linspace(0, frames - 1, frames)
+     target_x = np.linspace(0, frames - 1, padded_length)
+     return source_x, target_x
+
+
+ @gin.configurable
+ def linear_interpolation(
+     signal: np.ndarray,
+     window_length: int,
+     hop_length: int,
+     original_length: Optional[int] = None,
+ ):
+     source_x, target_x = get_source_target_axes(signal.size, window_length, hop_length)
+
+     interpolated = np.interp(target_x, source_x, signal)
+     if original_length:
+         interpolated = interpolated[window_length // 2 :]
+         interpolated = interpolated[:original_length]
+
+     return interpolated
+
+
+ @gin.configurable
+ def cubic_spline_interpolation(
+     signal: np.ndarray,
+     window_length: int,
+     hop_length: int,
+     original_length: Optional[int] = None,
+ ):
+     source_x, target_x = get_source_target_axes(signal.size, window_length, hop_length)
+
+     interpolant = scipy.interpolate.interp1d(source_x, signal, kind="cubic")
+     interpolated = interpolant(target_x)
+     if original_length:
+         interpolated = interpolated[window_length // 2 :]
+         interpolated = interpolated[:original_length]
+
+     return interpolated
+
+
+ @gin.configurable
+ def overlap_add_upsample(
+     signal: np.ndarray,
+     window_length: int,
+     hop_length: int,
+     window_fn: str = "hann",
+     window_scale: int = 2,
+     original_length: Optional[int] = None,
+ ):
+     window = scipy.signal.windows.get_window(window_fn, hop_length * window_scale)
+     padded_length = get_padded_length(signal.size, window_length, hop_length)
+     padded_output = np.zeros(padded_length)
+
+     for i, value in enumerate(signal):
+         window_start = i * hop_length
+         window_end = window_start + hop_length * window_scale
+         padded_output[window_start:window_end] += window * value
+
+     if original_length:
+         output = padded_output[(padded_length - original_length) // 2 :]
+         output = output[:original_length]
+     else:
+         output = padded_output
+
+     return output
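A quick sketch (not part of the commit) of how these frame-rate upsamplers line up: with `original_length` supplied, each maps a signal of N frames to N * hop_length samples. The parameter values below are illustrative only.

    import numpy as np

    frames = np.linspace(0.0, 1.0, 16)   # e.g. 16 loudness frames
    window_length, hop_length = 512, 128
    original_length = frames.size * hop_length

    lin = linear_interpolation(frames, window_length, hop_length, original_length)
    ola = overlap_add_upsample(
        frames, window_length, hop_length, original_length=original_length
    )
    assert lin.shape == ola.shape == (original_length,)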
neural_waveshaping_synthesis/models/__init__.py ADDED
File without changes
neural_waveshaping_synthesis/models/modules/__init__.py ADDED
File without changes
neural_waveshaping_synthesis/models/modules/dynamic.py ADDED
@@ -0,0 +1,40 @@
+ import gin
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class FiLM(nn.Module):
+     def forward(self, x, gamma, beta):
+         return gamma * x + beta
+
+
+ class TimeDistributedLayerNorm(nn.Module):
+     def __init__(self, size: int):
+         super().__init__()
+         self.layer_norm = nn.LayerNorm(size)
+
+     def forward(self, x):
+         return self.layer_norm(x.transpose(1, 2)).transpose(1, 2)
+
+
+ @gin.configurable
+ class TimeDistributedMLP(nn.Module):
+     def __init__(self, in_size: int, hidden_size: int, out_size: int, depth: int = 3):
+         super().__init__()
+         assert depth >= 3, "Depth must be at least 3"
+         layers = []
+         for i in range(depth):
+             layers.append(
+                 nn.Conv1d(
+                     in_size if i == 0 else hidden_size,
+                     hidden_size if i < depth - 1 else out_size,
+                     1,
+                 )
+             )
+             if i < depth - 1:
+                 layers.append(TimeDistributedLayerNorm(hidden_size))
+                 layers.append(nn.LeakyReLU())
+         self.net = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.net(x)
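A shape sketch (not part of the commit) for the blocks above, assuming the (batch, channels, frames) layout implied by the 1x1 Conv1d layers; the sizes are illustrative:

    import torch

    mlp = TimeDistributedMLP(in_size=2, hidden_size=8, out_size=4, depth=3)
    x = torch.randn(1, 2, 250)              # (batch, features, frames)
    assert mlp(x).shape == (1, 4, 250)

    film = FiLM()                           # identity when gamma=1, beta=0
    y = mlp(x)
    assert torch.allclose(film(y, torch.ones_like(y), torch.zeros_like(y)), y)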
neural_waveshaping_synthesis/models/modules/generators.py ADDED
@@ -0,0 +1,66 @@
+ import math
+ from typing import Callable
+
+ import gin
+ import torch
+ import torch.fft
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ @gin.configurable
+ class FIRNoiseSynth(nn.Module):
+     def __init__(
+         self, ir_length: int, hop_length: int, window_fn: Callable = torch.hann_window
+     ):
+         super().__init__()
+         self.ir_length = ir_length
+         self.hop_length = hop_length
+         self.register_buffer("window", window_fn(ir_length))
+
+     def forward(self, H_re):
+         H_im = torch.zeros_like(H_re)
+         H_z = torch.complex(H_re, H_im)
+
+         h = torch.fft.irfft(H_z.transpose(1, 2))
+         h = h.roll(self.ir_length // 2, -1)
+         h = h * self.window.view(1, 1, -1)
+         H = torch.fft.rfft(h)
+
+         noise = torch.rand(self.hop_length * H_re.shape[-1] - 1, device=H_re.device)
+         X = torch.stft(noise, self.ir_length, self.hop_length, return_complex=True)
+         X = X.unsqueeze(0)
+         Y = X * H.transpose(1, 2)
+         y = torch.istft(Y, self.ir_length, self.hop_length, center=False)
+         return y.unsqueeze(1)[:, :, : H_re.shape[-1] * self.hop_length]
+
+
+ @gin.configurable
+ class HarmonicOscillator(nn.Module):
+     def __init__(self, n_harmonics, sample_rate):
+         super().__init__()
+         self.sample_rate = sample_rate
+         self.n_harmonics = n_harmonics
+         self.register_buffer("harmonic_axis", self._create_harmonic_axis(n_harmonics))
+         self.register_buffer("rand_phase", torch.ones(1, n_harmonics, 1) * math.tau)
+
+     def _create_harmonic_axis(self, n_harmonics):
+         return torch.arange(1, n_harmonics + 1).view(1, -1, 1)
+
+     def _create_antialias_mask(self, f0):
+         freqs = f0.unsqueeze(1) * self.harmonic_axis
+         return freqs < (self.sample_rate / 2)
+
+     def _create_phase_shift(self, n_harmonics):
+         shift = torch.rand_like(self.rand_phase) * self.rand_phase - math.pi
+         return shift
+
+     def forward(self, f0):
+         phase = math.tau * f0.cumsum(-1) / self.sample_rate
+         harmonic_phase = self.harmonic_axis * phase.unsqueeze(1)
+         harmonic_phase = harmonic_phase + self._create_phase_shift(self.n_harmonics)
+         antialias_mask = self._create_antialias_mask(f0)
+
+         output = torch.sin(harmonic_phase) * antialias_mask
+
+         return output
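A minimal sketch (not part of the commit) of driving the harmonic oscillator with a constant audio-rate f0; the antialias mask zeroes any harmonic that would land above Nyquist. Shapes and values are illustrative:

    import torch

    osc = HarmonicOscillator(n_harmonics=4, sample_rate=16000)
    f0 = torch.full((1, 16000), 440.0)   # one second of constant 440 Hz
    harmonics = osc(f0)                  # (batch, n_harmonics, samples)
    assert harmonics.shape == (1, 4, 16000)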
neural_waveshaping_synthesis/models/modules/shaping.py ADDED
@@ -0,0 +1,173 @@
+ import gin
+ import torch
+ import torch.fft
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .dynamic import FiLM, TimeDistributedMLP
+
+
+ class Sine(nn.Module):
+     def forward(self, x: torch.Tensor):
+         return torch.sin(x)
+
+
+ @gin.configurable
+ class TrainableNonlinearity(nn.Module):
+     def __init__(
+         self, channels, width, nonlinearity=nn.ReLU, final_nonlinearity=Sine, depth=3
+     ):
+         super().__init__()
+         self.input_scale = nn.Parameter(torch.randn(1, channels, 1) * 10)
+         layers = []
+         for i in range(depth):
+             layers.append(
+                 nn.Conv1d(
+                     channels if i == 0 else channels * width,
+                     channels * width if i < depth - 1 else channels,
+                     1,
+                     groups=channels,
+                 )
+             )
+             layers.append(nonlinearity() if i < depth - 1 else final_nonlinearity())
+
+         self.net = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.net(self.input_scale * x)
+
+
+ @gin.configurable
+ class NEWT(nn.Module):
+     def __init__(
+         self,
+         n_waveshapers: int,
+         control_embedding_size: int,
+         shaping_fn_size: int = 16,
+         out_channels: int = 1,
+     ):
+         super().__init__()
+
+         self.n_waveshapers = n_waveshapers
+
+         self.mlp = TimeDistributedMLP(
+             control_embedding_size, control_embedding_size, n_waveshapers * 4, depth=4
+         )
+
+         self.waveshaping_index = FiLM()
+         self.shaping_fn = TrainableNonlinearity(
+             n_waveshapers, shaping_fn_size, nonlinearity=Sine
+         )
+         self.normalising_coeff = FiLM()
+
+         self.mixer = nn.Sequential(
+             nn.Conv1d(n_waveshapers, out_channels, 1),
+         )
+
+     def forward(self, exciter, control_embedding):
+         film_params = self.mlp(control_embedding)
+         film_params = F.upsample(film_params, exciter.shape[-1], mode="linear")
+         gamma_index, beta_index, gamma_norm, beta_norm = torch.split(
+             film_params, self.n_waveshapers, 1
+         )
+
+         x = self.waveshaping_index(exciter, gamma_index, beta_index)
+         x = self.shaping_fn(x)
+         x = self.normalising_coeff(x, gamma_norm, beta_norm)
+
+         # return x
+         return self.mixer(x)
+
+
+ class FastNEWT(NEWT):
+     def __init__(
+         self,
+         newt: NEWT,
+         table_size: int = 4096,
+         table_min: float = -3.0,
+         table_max: float = 3.0,
+     ):
+         super().__init__()
+         self.table_size = table_size
+         self.table_min = table_min
+         self.table_max = table_max
+
+         self.n_waveshapers = newt.n_waveshapers
+         self.mlp = newt.mlp
+
+         self.waveshaping_index = newt.waveshaping_index
+         self.normalising_coeff = newt.normalising_coeff
+         self.mixer = newt.mixer
+
+         self.lookup_table = self._init_lookup_table(
+             newt, table_size, self.n_waveshapers, table_min, table_max
+         )
+         self.to(next(iter(newt.parameters())).device)
+
+     def _init_lookup_table(
+         self,
+         newt: NEWT,
+         table_size: int,
+         n_waveshapers: int,
+         table_min: float,
+         table_max: float,
+     ):
+         sample_values = torch.linspace(
+             table_min,
+             table_max,
+             table_size,
+             device=next(iter(newt.parameters())).device,
+         ).expand(1, n_waveshapers, table_size)
+         lookup_table = newt.shaping_fn(sample_values)[0]
+         return nn.Parameter(lookup_table)
+
+     def _lookup(self, idx):
+         return torch.stack(
+             [
+                 torch.stack(
+                     [
+                         self.lookup_table[shaper, idx[batch, shaper]]
+                         for shaper in range(idx.shape[1])
+                     ],
+                     dim=0,
+                 )
+                 for batch in range(idx.shape[0])
+             ],
+             dim=0,
+         )
+
+     def shaping_fn(self, x):
+         idx = self.table_size * (x - self.table_min) / (self.table_max - self.table_min)
+
+         lower = torch.floor(idx).long()
+         lower[lower < 0] = 0
+         lower[lower >= self.table_size] = self.table_size - 1
+
+         upper = lower + 1
+         upper[upper >= self.table_size] = self.table_size - 1
+
+         fract = idx - lower
+         lower_v = self._lookup(lower)
+         upper_v = self._lookup(upper)
+
+         output = (upper_v - lower_v) * fract + lower_v
+         return output
+
+
+ @gin.configurable
+ class Reverb(nn.Module):
+     def __init__(self, length_in_seconds, sr):
+         super().__init__()
+         self.ir = nn.Parameter(torch.randn(1, sr * length_in_seconds - 1) * 1e-6)
+         self.register_buffer("initial_zero", torch.zeros(1, 1))
+
+     def forward(self, x):
+         ir_ = torch.cat((self.initial_zero, self.ir), dim=-1)
+         if x.shape[-1] > ir_.shape[-1]:
+             ir_ = F.pad(ir_, (0, x.shape[-1] - ir_.shape[-1]))
+             x_ = x
+         else:
+             x_ = F.pad(x, (0, ir_.shape[-1] - x.shape[-1]))
+         return (
+             x
+             + torch.fft.irfft(torch.fft.rfft(x_) * torch.fft.rfft(ir_))[
+                 ..., : x.shape[-1]
+             ]
+         )
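A usage sketch (not part of the commit) for swapping a NEWT for its lookup-table FastNEWT. Because FastNEWT.__init__ calls super().__init__() with no arguments, NEWT's required constructor arguments must be supplied through gin; the bindings and tensor shapes below are assumptions, not the project's own configuration.

    import gin
    import torch

    # Assumed bindings so NEWT() and FastNEWT's super().__init__() can construct.
    gin.parse_config(["NEWT.n_waveshapers = 4", "NEWT.control_embedding_size = 8"])

    newt = NEWT()
    exciter = torch.randn(1, 4, 4096)   # waveshaper-rate exciter signal
    control = torch.randn(1, 8, 32)     # frame-rate control embedding
    fast = FastNEWT(newt)               # learned shaping function -> lookup table
    assert fast(exciter, control).shape == newt(exciter, control).shape == (1, 1, 4096)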
neural_waveshaping_synthesis/models/neural_waveshaping.py ADDED
@@ -0,0 +1,165 @@
+ import auraloss
+ import gin
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import wandb
+
+ from .modules.dynamic import TimeDistributedMLP
+ from .modules.generators import FIRNoiseSynth, HarmonicOscillator
+ from .modules.shaping import NEWT, Reverb
+
+ gin.external_configurable(nn.GRU, module="torch.nn")
+ gin.external_configurable(nn.Conv1d, module="torch.nn")
+
+
+ @gin.configurable
+ class ControlModule(nn.Module):
+     def __init__(self, control_size: int, hidden_size: int, embedding_size: int):
+         super().__init__()
+         self.gru = nn.GRU(control_size, hidden_size, batch_first=True)
+         self.proj = nn.Conv1d(hidden_size, embedding_size, 1)
+
+     def forward(self, x):
+         x, _ = self.gru(x.transpose(1, 2))
+         return self.proj(x.transpose(1, 2))
+
+
+ @gin.configurable
+ class NeuralWaveshaping(pl.LightningModule):
+     def __init__(
+         self,
+         n_waveshapers: int,
+         control_hop: int,
+         sample_rate: float = 16000,
+         learning_rate: float = 1e-3,
+         lr_decay: float = 0.9,
+         lr_decay_interval: int = 10000,
+         log_audio: bool = False,
+     ):
+         super().__init__()
+         self.save_hyperparameters()
+         self.learning_rate = learning_rate
+         self.lr_decay = lr_decay
+         self.lr_decay_interval = lr_decay_interval
+         self.control_hop = control_hop
+         self.log_audio = log_audio
+
+         self.sample_rate = sample_rate
+
+         self.embedding = ControlModule()
+
+         self.osc = HarmonicOscillator()
+         self.harmonic_mixer = nn.Conv1d(self.osc.n_harmonics, n_waveshapers, 1)
+
+         self.newt = NEWT()
+
+         with gin.config_scope("noise_synth"):
+             self.h_generator = TimeDistributedMLP()
+             self.noise_synth = FIRNoiseSynth()
+
+         self.reverb = Reverb()
+
+     def render_exciter(self, f0):
+         sig = self.osc(f0[:, 0])
+         sig = self.harmonic_mixer(sig)
+         return sig
+
+     def get_embedding(self, control):
+         f0, other = control[:, 0:1], control[:, 1:2]
+         control = torch.cat((f0, other), dim=1)
+         return self.embedding(control)
+
+     def forward(self, f0, control):
+         f0_upsampled = F.upsample(f0, f0.shape[-1] * self.control_hop, mode="linear")
+         x = self.render_exciter(f0_upsampled)
+
+         control_embedding = self.get_embedding(control)
+
+         x = self.newt(x, control_embedding)
+
+         H = self.h_generator(control_embedding)
+         noise = self.noise_synth(H)
+
+         x = torch.cat((x, noise), dim=1)
+         x = x.sum(1)
+
+         x = self.reverb(x)
+
+         return x
+
+     def configure_optimizers(self):
+         self.stft_loss = auraloss.freq.MultiResolutionSTFTLoss()
+
+         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+         scheduler = torch.optim.lr_scheduler.StepLR(
+             optimizer, self.lr_decay_interval, self.lr_decay
+         )
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
+         }
+
+     def _run_step(self, batch):
+         audio = batch["audio"].float()
+         f0 = batch["f0"].float()
+         control = batch["control"].float()
+
+         recon = self(f0, control)
+
+         loss = self.stft_loss(recon, audio)
+         return loss, recon, audio
+
+     def _log_audio(self, name, audio):
+         wandb.log(
+             {
+                 "audio/%s"
+                 % name: wandb.Audio(audio, sample_rate=self.sample_rate, caption=name)
+             },
+             commit=False,
+         )
+
+     def training_step(self, batch, batch_idx):
+         loss, _, _ = self._run_step(batch)
+         self.log(
+             "train/loss",
+             loss.item(),
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+             sync_dist=True,
+         )
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         loss, recon, audio = self._run_step(batch)
+         self.log(
+             "val/loss",
+             loss.item(),
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+             sync_dist=True,
+         )
+         if batch_idx == 0 and self.log_audio:
+             self._log_audio("original", audio[0].detach().cpu().squeeze())
+             self._log_audio("recon", recon[0].detach().cpu().squeeze())
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         loss, recon, audio = self._run_step(batch)
+         self.log(
+             "test/loss",
+             loss.item(),
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+             sync_dist=True,
+         )
+         if batch_idx == 0:
+             self._log_audio("original", audio[0].detach().cpu().squeeze())
+             self._log_audio("recon", recon[0].detach().cpu().squeeze())
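The module tree above is assembled entirely through gin, so standalone inference follows the same pattern as scripts/resynthesise_dataset.py later in this commit: parse a model config, restore a checkpoint, then call the model on frame-rate f0 and control tensors. A minimal sketch; the config path, checkpoint path, and tensor contents below are assumptions:

    import gin
    import torch

    gin.parse_config_file("gin/models/newt.gin")  # assumed model config
    model = NeuralWaveshaping.load_from_checkpoint("checkpoints/nws/vn/last.ckpt")
    model.eval()

    with torch.no_grad():
        f0 = torch.full((1, 1, 250), 440.0)   # frame-rate f0 in Hz
        control = torch.rand(1, 2, 250)       # f0 + loudness control channels
        audio = model(f0, control)            # roughly 250 * control_hop samples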
neural_waveshaping_synthesis/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .utils import *
+ from .seed_all import *
neural_waveshaping_synthesis/utils/seed_all.py ADDED
@@ -0,0 +1,12 @@
+ import numpy as np
+ import os
+ import random
+ import torch
+
+ def seed_all(seed):
+     np.random.seed(seed)
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.backends.cudnn.deterministic = True
neural_waveshaping_synthesis/utils/utils.py ADDED
@@ -0,0 +1,23 @@
+ import os
+ from typing import Callable, Sequence
+
+
+ def apply(fn: Callable[[any], any], x: Sequence[any]):
+     if type(x) not in (tuple, list):
+         raise TypeError("x must be a tuple or list.")
+     return type(x)([fn(element) for element in x])
+
+
+ def apply_unpack(fn: Callable[[any], any], x: Sequence[Sequence[any]]):
+     if type(x) not in (tuple, list):
+         raise TypeError("x must be a tuple or list.")
+     return type(x)([fn(*element) for element in x])
+
+
+ def unzip(x: Sequence[any]):
+     return list(zip(*x))
+
+
+ def make_dir_if_not_exists(path):
+     if not os.path.exists(path):
+         os.makedirs(path, exist_ok=True)
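A quick sketch (not part of the commit) of what the small helpers above return:

    pairs = [(1, "a"), (2, "b")]
    assert apply(lambda n: n * 2, [1, 2, 3]) == [2, 4, 6]
    assert apply_unpack(lambda n, s: s * n, pairs) == ["a", "bb"]
    assert unzip(pairs) == [(1, 2), ("a", "b")]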
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ auraloss==0.2.1
+ black==20.8b1
+ click==7.1.2
+ gin-config==0.4.0
+ librosa==0.8.0
+ numpy==1.20.1
+ pytorch_lightning==1.1.2
+ resampy==0.2.2
+ scipy==1.6.1
+ torch==1.7.1
+ torchcrepe==0.0.12
+ gradio
+ wandb
scripts/create_dataset.py ADDED
@@ -0,0 +1,31 @@
+ import os
+
+ import click
+ import gin
+
+ from neural_waveshaping_synthesis.data.utils.create_dataset import create_dataset
+ from neural_waveshaping_synthesis.utils import seed_all
+
+
+ def get_filenames(directory):
+     return [os.path.join(directory, f) for f in os.listdir(directory) if ".wav" in f]
+
+
+ @click.command()
+ @click.option("--gin-file", prompt="Gin config file")
+ @click.option("--data-directory", prompt="Data directory")
+ @click.option("--output-directory", prompt="Output directory")
+ @click.option("--seed", default=0)
+ @click.option("--device", default="cpu")
+ def main(gin_file, data_directory, output_directory, seed=0, device="cpu"):
+     gin.constant("device", device)
+     gin.parse_config_file(gin_file)
+
+     seed_all(seed)
+
+     files = get_filenames(data_directory)
+     create_dataset(files, output_directory)
+
+
+ if __name__ == "__main__":
+     main()
scripts/create_urmp_dataset.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ from pathlib import Path
+
+ import click
+ import gin
+
+ from neural_waveshaping_synthesis.data.utils.create_dataset import create_dataset
+ from neural_waveshaping_synthesis.utils import seed_all
+
+ INSTRUMENTS = (
+     "vn",
+     "vc",
+     "fl",
+     "cl",
+     "tpt",
+     "sax",
+     "tbn",
+     "ob",
+     "va",
+     "bn",
+     "hn",
+     "db",
+ )
+
+
+ def get_instrument_file_list(instrument_string, directory):
+     return [
+         str(f)
+         for f in Path(directory).glob(
+             "**/*_%s_*/AuSep*_%s_*.wav" % (instrument_string, instrument_string)
+         )
+     ]
+
+
+ @click.command()
+ @click.option("--gin-file", prompt="Gin config file")
+ @click.option("--data-directory", prompt="Data directory")
+ @click.option("--output-directory", prompt="Output directory")
+ @click.option("--seed", default=0)
+ @click.option("--device", default="cpu")
+ def main(gin_file, data_directory, output_directory, seed=0, device="cpu"):
+     gin.constant("device", device)
+     gin.parse_config_file(gin_file)
+
+     seed_all(seed)
+
+     file_lists = {
+         instrument: get_instrument_file_list(instrument, data_directory)
+         for instrument in INSTRUMENTS
+     }
+     for instrument in file_lists:
+         create_dataset(
+             file_lists[instrument], os.path.join(output_directory, instrument)
+         )
+
+
+ if __name__ == "__main__":
+     main()
scripts/resynthesise_dataset.py ADDED
@@ -0,0 +1,80 @@
+ import os
+
+ import click
+ import gin
+ from scipy.io import wavfile
+ from tqdm import tqdm
+ import torch
+
+ from neural_waveshaping_synthesis.data.urmp import URMPDataset
+ from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT
+ from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
+ from neural_waveshaping_synthesis.utils import make_dir_if_not_exists
+
+
+ @click.command()
+ @click.option("--model-gin", prompt="Model .gin file")
+ @click.option("--model-checkpoint", prompt="Model checkpoint")
+ @click.option("--dataset-root", prompt="Dataset root directory")
+ @click.option("--dataset-split", default="test")
+ @click.option("--output-path", default="audio_output")
+ @click.option("--load-data-to-memory", default=False)
+ @click.option("--device", default="cuda:0")
+ @click.option("--batch-size", default=8)
+ @click.option("--num_workers", default=16)
+ @click.option("--use-fastnewt", is_flag=True)
+ def main(
+     model_gin,
+     model_checkpoint,
+     dataset_root,
+     dataset_split,
+     output_path,
+     load_data_to_memory,
+     device,
+     batch_size,
+     num_workers,
+     use_fastnewt
+ ):
+     gin.parse_config_file(model_gin)
+     make_dir_if_not_exists(output_path)
+
+     data = URMPDataset(dataset_root, dataset_split, load_data_to_memory)
+     data_loader = torch.utils.data.DataLoader(
+         data, batch_size=batch_size, num_workers=num_workers
+     )
+
+     device = torch.device(device)
+     model = NeuralWaveshaping.load_from_checkpoint(model_checkpoint)
+     model.eval()
+
+     if use_fastnewt:
+         model.newt = FastNEWT(model.newt)
+
+     model = model.to(device)
+
+     for i, batch in enumerate(tqdm(data_loader)):
+         with torch.no_grad():
+             f0 = batch["f0"].float().to(device)
+             control = batch["control"].float().to(device)
+             output = model(f0, control)
+
+         target_audio = batch["audio"].float().numpy()
+         output_audio = output.cpu().numpy()
+         for j in range(output_audio.shape[0]):
+             name = batch["name"][j]
+             target_name = "%s.target.wav" % name
+             output_name = "%s.output.wav" % name
+             wavfile.write(
+                 os.path.join(output_path, target_name),
+                 model.sample_rate,
+                 target_audio[j],
+             )
+             wavfile.write(
+                 os.path.join(output_path, output_name),
+                 model.sample_rate,
+                 output_audio[j],
+             )
+
+
+ if __name__ == "__main__":
+     main()
scripts/time_buffer_sizes.py ADDED
@@ -0,0 +1,79 @@
+ import time
+
+ import click
+ import gin
+ import numpy as np
+ import pandas as pd
+ import torch
+ from tqdm import trange
+
+ from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
+ from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT
+
+ BUFFER_SIZES = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
+
+ @click.command()
+ @click.option("--gin-file", prompt="Model config gin file")
+ @click.option("--output-file", prompt="output file")
+ @click.option("--num-iters", default=100)
+ @click.option("--batch-size", default=1)
+ @click.option("--device", default="cpu")
+ @click.option("--length-in-seconds", default=4)
+ @click.option("--use-fast-newt", is_flag=True)
+ @click.option("--model-name", default="ours")
+ def main(
+     gin_file,
+     output_file,
+     num_iters,
+     batch_size,
+     device,
+     length_in_seconds,
+     use_fast_newt,
+     model_name,
+ ):
+     gin.parse_config_file(gin_file)
+     model = NeuralWaveshaping()
+     if use_fast_newt:
+         model.newt = FastNEWT(model.newt)
+     model.eval()
+     model = model.to(device)
+
+     # eliminate any lazy init costs
+     with torch.no_grad():
+         for i in range(10):
+             model(
+                 torch.rand(4, 1, 250, device=device),
+                 torch.rand(4, 2, 250, device=device),
+             )
+
+     times = []
+     with torch.no_grad():
+         for bs in BUFFER_SIZES:
+             dummy_control = torch.rand(
+                 batch_size,
+                 2,
+                 bs // 128,
+                 device=device,
+                 requires_grad=False,
+             )
+             dummy_f0 = torch.rand(
+                 batch_size,
+                 1,
+                 bs // 128,
+                 device=device,
+                 requires_grad=False,
+             )
+             for i in trange(num_iters):
+                 start_time = time.time()
+                 model(dummy_f0, dummy_control)
+                 time_elapsed = time.time() - start_time
+                 times.append(
+                     [model_name, device if device == "cpu" else "gpu", bs, time_elapsed]
+                 )
+
+     df = pd.DataFrame(times)
+     df.to_csv(output_file)
+
+
+ if __name__ == "__main__":
+     main()
scripts/time_forward_pass.py ADDED
@@ -0,0 +1,62 @@
+ import time
+
+ import click
+ import gin
+ import numpy as np
+ from scipy.stats import describe
+ import torch
+ from tqdm import trange
+
+ from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
+ from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT
+
+
+ @click.command()
+ @click.option("--gin-file", prompt="Model config gin file")
+ @click.option("--num-iters", default=100)
+ @click.option("--batch-size", default=1)
+ @click.option("--device", default="cpu")
+ @click.option("--length-in-seconds", default=4)
+ @click.option("--sample-rate", default=16000)
+ @click.option("--control-hop", default=128)
+ @click.option("--use-fast-newt", is_flag=True)
+ def main(
+     gin_file, num_iters, batch_size, device, length_in_seconds, sample_rate, control_hop, use_fast_newt
+ ):
+     gin.parse_config_file(gin_file)
+     dummy_control = torch.rand(
+         batch_size,
+         2,
+         sample_rate * length_in_seconds // control_hop,
+         device=device,
+         requires_grad=False,
+     )
+     dummy_f0 = torch.rand(
+         batch_size,
+         1,
+         sample_rate * length_in_seconds // control_hop,
+         device=device,
+         requires_grad=False,
+     )
+     model = NeuralWaveshaping()
+     if use_fast_newt:
+         model.newt = FastNEWT(model.newt)
+     model.eval()
+     model = model.to(device)
+
+     times = []
+     with torch.no_grad():
+         for i in trange(num_iters):
+             start_time = time.time()
+             model(dummy_f0, dummy_control)
+             time_elapsed = time.time() - start_time
+             times.append(time_elapsed)
+
+     print(describe(times))
+     rtfs = np.array(times) / length_in_seconds
+     print("Mean RTF: %.4f" % np.mean(rtfs))
+     print("90th percentile RTF: %.4f" % np.percentile(rtfs, 90))
+
+
+ if __name__ == "__main__":
+     main()
scripts/train.py ADDED
@@ -0,0 +1,81 @@
+ import click
+ import gin
+ import pytorch_lightning as pl
+
+ from neural_waveshaping_synthesis.data.general import GeneralDataModule
+ from neural_waveshaping_synthesis.data.urmp import URMPDataModule
+ from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
+
+
+ @gin.configurable
+ def get_model(model, with_wandb):
+     return model(log_audio=with_wandb)
+
+
+ @gin.configurable
+ def trainer_kwargs(**kwargs):
+     return kwargs
+
+
+ @click.command()
+ @click.option("--gin-file", prompt="Gin config file")
+ @click.option("--dataset-path", prompt="Dataset root")
+ @click.option("--urmp", is_flag=True)
+ @click.option("--device", default="0")
+ @click.option("--instrument", default="vn")
+ @click.option("--load-data-to-memory", is_flag=True)
+ @click.option("--with-wandb", is_flag=True)
+ @click.option("--restore-checkpoint", default="")
+ def main(
+     gin_file,
+     dataset_path,
+     urmp,
+     device,
+     instrument,
+     load_data_to_memory,
+     with_wandb,
+     restore_checkpoint,
+ ):
+     gin.parse_config_file(gin_file)
+     model = get_model(with_wandb=with_wandb)
+
+     if urmp:
+         data = URMPDataModule(
+             dataset_path,
+             instrument,
+             load_to_memory=load_data_to_memory,
+             num_workers=16,
+             shuffle=True,
+         )
+     else:
+         data = GeneralDataModule(
+             dataset_path,
+             load_to_memory=load_data_to_memory,
+             num_workers=16,
+             shuffle=True,
+         )
+
+     checkpointing = pl.callbacks.ModelCheckpoint(
+         monitor="val/loss", save_top_k=1, save_last=True
+     )
+     callbacks = [checkpointing]
+     if with_wandb:
+         lr_logger = pl.callbacks.LearningRateMonitor(logging_interval="epoch")
+         callbacks.append(lr_logger)
+         logger = pl.loggers.WandbLogger(project="neural-waveshaping-synthesis")
+         logger.watch(model, log="parameters")
+
+     kwargs = trainer_kwargs()
+     trainer = pl.Trainer(
+         logger=logger if with_wandb else None,
+         callbacks=callbacks,
+         gpus=device,
+         resume_from_checkpoint=restore_checkpoint if restore_checkpoint != "" else None,
+         **kwargs
+     )
+     trainer.fit(model, data)
+
+
+ if __name__ == "__main__":
+     main()
setup.py ADDED
@@ -0,0 +1,3 @@
+ from setuptools import setup, find_packages
+
+ setup(name="neural_waveshaping_synthesis", version="0.0.1", packages=find_packages())