Files changed (2)
  1. Interacting_with_Jukebox.ipynb +961 -0
  2. train.py +345 -0
Interacting_with_Jukebox.ipynb ADDED
@@ -0,0 +1,961 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "Interacting with Jukebox",
+ "provenance": [],
+ "collapsed_sections": [],
+ "machine_shape": "hm"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "sAdFGF-bqVMY",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "!pip install git+https://github.com/openai/jukebox.git"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uq8uLwZCn0BV",
+ "colab_type": "text"
+ },
+ "source": [
+ "IMPORTANT NOTE ON SYSTEM REQUIREMENTS:\n",
+ "\n",
+ "If you are connecting to a hosted runtime, make sure it has a P100 GPU (optionally run !nvidia-smi to confirm). Go to Edit>Notebook Settings to set this.\n",
+ "\n",
+ "CoLab may first assign you a lower-memory machine if you are using a hosted runtime. If so, the first time you try to load the 5B model, it will run out of memory, and then you'll be prompted to restart with more memory (then return to the top of this CoLab). If you continue to have memory issues after this (or run into issues on your own home setup), switch to the 1B model.\n",
+ "\n",
+ "If you are using a local GPU, we recommend a V100 or P100 with 16GB GPU memory for best performance. For GPUs with less memory, we recommend using the 1B model and a smaller batch size throughout.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "8qEqdj8u0gdN",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "!nvidia-smi"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "taDHgk1WCC_C",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import jukebox\n",
+ "import torch as t\n",
+ "import librosa\n",
+ "import os\n",
+ "from IPython.display import Audio\n",
+ "from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model\n",
+ "from jukebox.hparams import Hyperparams, setup_hparams\n",
+ "from jukebox.sample import sample_single_window, _sample, \\\n",
+ "    sample_partial_window, upsample\n",
+ "from jukebox.utils.dist_utils import setup_dist_from_mpi\n",
+ "from jukebox.utils.torch_utils import empty_cache\n",
+ "rank, local_rank, device = setup_dist_from_mpi()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "89FftI5kc-Az",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Sample from the 5B or 1B Lyrics Model\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "65aR2OZxmfzq",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "model = \"5b_lyrics\" # or \"1b_lyrics\"\n",
+ "hps = Hyperparams()\n",
+ "hps.sr = 44100\n",
+ "hps.n_samples = 3 if model=='5b_lyrics' else 8\n",
+ "hps.name = 'samples'\n",
+ "chunk_size = 16 if model==\"5b_lyrics\" else 32\n",
+ "max_batch_size = 3 if model==\"5b_lyrics\" else 16\n",
+ "hps.levels = 3\n",
+ "hps.hop_fraction = [.5,.5,.125]\n",
+ "\n",
+ "vqvae, *priors = MODELS[model]\n",
+ "vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)\n",
+ "top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)\n",
+ "\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JYKiwkzy0Iyf",
+ "colab_type": "text"
+ },
+ "source": [
+ "Specify your choice of artist, genre, lyrics, and length of musical sample."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "-sY9aGHcZP-u",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "sample_length_in_seconds = 60 # Full length of musical sample to generate - we find songs in the 1 to 4 minute\n",
+ "                              # range work well, with generation time proportional to sample length.\n",
+ "                              # This total length affects how quickly the model\n",
+ "                              # progresses through lyrics (model also generates differently\n",
+ "                              # depending on if it thinks it's in the beginning, middle, or end of sample)\n",
+ "\n",
+ "hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n",
+ "assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, 'Please choose a larger sample length'"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "qD0qxQeLaTR0",
+ "colab": {}
+ },
+ "source": [
+ "metas = [dict(artist = \"Zac Brown Band\",\n",
+ "              genre = \"Country\",\n",
+ "              total_length = hps.sample_length,\n",
+ "              offset = 0,\n",
+ "              lyrics = \"\"\"I met a traveller from an antique land,\n",
+ "              Who said—“Two vast and trunkless legs of stone\n",
+ "              Stand in the desert. . . . Near them, on the sand,\n",
+ "              Half sunk a shattered visage lies, whose frown,\n",
+ "              And wrinkled lip, and sneer of cold command,\n",
+ "              Tell that its sculptor well those passions read\n",
+ "              Which yet survive, stamped on these lifeless things,\n",
+ "              The hand that mocked them, and the heart that fed;\n",
+ "              And on the pedestal, these words appear:\n",
+ "              My name is Ozymandias, King of Kings;\n",
+ "              Look on my Works, ye Mighty, and despair!\n",
+ "              Nothing beside remains. Round the decay\n",
+ "              Of that colossal Wreck, boundless and bare\n",
+ "              The lone and level sands stretch far away\n",
+ "              \"\"\",\n",
+ "              ),\n",
+ "         ] * hps.n_samples\n",
+ "labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6PHC1XnEfV4Y",
+ "colab_type": "text"
+ },
+ "source": [
+ "Optionally adjust the sampling temperature (we've found .98 or .99 to be our favorite).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "eNwKyqYraTR9",
+ "colab": {}
+ },
+ "source": [
+ "sampling_temperature = .98\n",
+ "\n",
+ "lower_batch_size = 16\n",
+ "max_batch_size = 3 if model == \"5b_lyrics\" else 16\n",
+ "lower_level_chunk_size = 32\n",
+ "chunk_size = 16 if model == \"5b_lyrics\" else 32\n",
+ "sampling_kwargs = [dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,\n",
+ "                        chunk_size=lower_level_chunk_size),\n",
+ "                   dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,\n",
+ "                        chunk_size=lower_level_chunk_size),\n",
+ "                   dict(temp=sampling_temperature, fp16=True,\n",
+ "                        max_batch_size=max_batch_size, chunk_size=chunk_size)]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "S3j0gT3HfrRD",
+ "colab_type": "text"
+ },
+ "source": [
+ "Now we're ready to sample from the model. We'll generate the top level (level 2) first, followed by the first upsampling (level 1) and the second upsampling (level 0). In this CoLab we load the top prior separately from the upsamplers because of memory concerns on the hosted runtimes. If you are using a local machine, you can also load all models directly with make_models, and then use sample.py's ancestral_sampling to put this all in one step.\n",
+ "\n",
+ "After each level, we decode to raw audio and save the audio files.\n",
+ "\n",
+ "This next cell will take a while (approximately 10 minutes per 20 seconds of music sample)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "2nET_YBEopyp",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(len(priors))]\n",
+ "zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-gxY9aqHqfLJ",
+ "colab_type": "text"
+ },
+ "source": [
+ "Listen to the results from the top level (note this will sound very noisy until we do the upsampling stage). You may have more generated samples, depending on the batch size you requested."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "TPZENDGZqOOb",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "Audio(f'{hps.name}/level_2/item_0.wav')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EJc3bQxmusc6",
+ "colab_type": "text"
+ },
+ "source": [
+ "We are now done with the large top_prior model; next we load the upsamplers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "W5VLX0zRapIm",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Set this False if you are on a local machine that has enough memory (this allows you to do the\n",
+ "# lyrics alignment visualization during the upsampling stage). For a hosted runtime,\n",
+ "# we'll need to go ahead and delete the top_prior if you are using the 5b_lyrics model.\n",
+ "if True:\n",
+ "  del top_prior\n",
+ "  empty_cache()\n",
+ "  top_prior=None\n",
+ "upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]\n",
+ "labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eH_jUhGDprAt",
+ "colab_type": "text"
+ },
+ "source": [
+ "Please note: this next upsampling step will take several hours. On the free tier, Google CoLab lets you run for 12 hours. As the upsampling completes, samples will appear in the Files tab (at the left of the CoLab), under \"samples\" (or whatever hps.name currently is). Level 1 is the partially upsampled version, and Level 0 is the fully upsampled one."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "9lkJgLolpZ6w",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3SJgBYJPri55",
+ "colab_type": "text"
+ },
+ "source": [
+ "Listen to your final sample!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "2ip2PPE0rgAb",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "Audio(f'{hps.name}/level_0/item_0.wav')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "8JAgFxytwrLG",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "del upsamplers\n",
+ "empty_cache()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LpvvFH85bbBC",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Co-Composing with the 5B or 1B Lyrics Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nFDROuS7gFQY",
+ "colab_type": "text"
+ },
+ "source": [
+ "For more control over the generations, try co-composing with either the 5B or 1B Lyrics model. Again, specify your artist, genre, and lyrics. However, now instead of generating the entire sample, the model will return 3 short options for the opening of the piece (or up to 16 options if you use the 1B model instead). Choose your favorite, and then continue the loop for as long as you like. Throughout these steps, you'll be listening to the audio at the top prior level, which means it will sound quite noisy. When you are satisfied with your co-creation, continue on through the upsampling section. This will render the piece in higher audio quality.\n",
+ "\n",
+ "NOTE: CoLab may first assign you a lower-memory machine if you are using a hosted runtime. The next cell will run out of memory, and then you'll be prompted to restart with more memory (then return to the top of this CoLab). If you continue to have memory issues after this (or run into issues on your own home setup), switch to the 1B model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "3y-q8ifhGBlU",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "model = \"5b_lyrics\" # or \"1b_lyrics\"\n",
+ "hps = Hyperparams()\n",
+ "hps.sr = 44100\n",
+ "hps.n_samples = 3 if model=='5b_lyrics' else 16\n",
+ "hps.name = 'co_composer'\n",
+ "hps.sample_length = 1048576 if model==\"5b_lyrics\" else 786432\n",
+ "chunk_size = 16 if model==\"5b_lyrics\" else 32\n",
+ "max_batch_size = 3 if model==\"5b_lyrics\" else 16\n",
+ "hps.hop_fraction = [.5, .5, .125]\n",
+ "hps.levels = 3\n",
+ "\n",
+ "vqvae, *priors = MODELS[model]\n",
+ "vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = hps.sample_length)), device)\n",
+ "top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "68hz4x7igq0c",
+ "colab_type": "text"
+ },
+ "source": [
+ "Choose your artist, genre, and lyrics here!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "QDMvH_1zUHo6",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "total_sample_length_in_seconds = 120\n",
+ "metas = [dict(artist = \"Zac Brown Band\",\n",
+ "              genre = \"Country\",\n",
+ "              total_length = total_sample_length_in_seconds * hps.sr,\n",
+ "              offset = 0,\n",
+ "              lyrics = \"\"\"I met a traveller from an antique land,\n",
+ "              Who said—“Two vast and trunkless legs of stone\n",
+ "              Stand in the desert. . . . Near them, on the sand,\n",
+ "              Half sunk a shattered visage lies, whose frown,\n",
+ "              And wrinkled lip, and sneer of cold command,\n",
+ "              Tell that its sculptor well those passions read\n",
+ "              Which yet survive, stamped on these lifeless things,\n",
+ "              The hand that mocked them, and the heart that fed;\n",
+ "              And on the pedestal, these words appear:\n",
+ "              My name is Ozymandias, King of Kings;\n",
+ "              Look on my Works, ye Mighty, and despair!\n",
+ "              Nothing beside remains. Round the decay\n",
+ "              Of that colossal Wreck, boundless and bare\n",
+ "              The lone and level sands stretch far away\n",
+ "              \"\"\",\n",
+ "              ),\n",
+ "         ] * hps.n_samples\n",
+ "labels = top_prior.labeller.get_batch_labels(metas, 'cuda')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "B9onZMEXh34f",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Generate 3 options for the start of the song\n",
+ "\n",
+ "Initial generation is set to be 4 seconds long, but feel free to change this."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "c6peEj8I_HHO",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def seconds_to_tokens(sec, sr, prior, chunk_size):\n",
+ "  tokens = sec * sr // prior.raw_to_tokens\n",
+ "  tokens = ((tokens // chunk_size) + 1) * chunk_size\n",
+ "  assert tokens <= prior.n_ctx, 'Choose a shorter generation length to stay within the top prior context'\n",
+ "  return tokens"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "2gn2GXt3zt3y",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "initial_generation_in_seconds = 4\n",
+ "tokens_to_sample = seconds_to_tokens(initial_generation_in_seconds, hps.sr, top_prior, chunk_size)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "U0zcWcMoiigl",
+ "colab_type": "text"
+ },
+ "source": [
+ "Change the sampling temperature if you like (higher is more random). Our favorite is in the range .98 to .995."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "NHbH68H7VMeO",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "sampling_temperature = .98\n",
+ "sampling_kwargs = dict(temp=sampling_temperature, fp16=True,\n",
+ "                       max_batch_size=max_batch_size, chunk_size=chunk_size)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "JGZEPe-WTt4g",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "zs=[t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(3)]\n",
+ "zs=sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n",
+ "x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mveN4Be8jK2J",
+ "colab_type": "text"
+ },
+ "source": [
+ "Listen to your generated samples, and then pick a favorite. If you don't like any, go back and rerun the cell above.\n",
+ "\n",
+ "**NOTE:** this is at the noisy top level; upsample fully (in the next section) to hear the final audio version."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "LrJSGMhUOhZg",
+ "colab": {}
+ },
+ "source": [
+ "for i in range(hps.n_samples):\n",
+ "  librosa.output.write_wav(f'noisy_top_level_generation_{i}.wav', x[i], sr=44100)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "rQ4ersQ5OhZr",
+ "colab": {}
+ },
+ "source": [
+ "Audio('noisy_top_level_generation_0.wav')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "-GdqzrGkOhZv",
+ "colab": {}
+ },
+ "source": [
+ "Audio('noisy_top_level_generation_1.wav')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "gE5S8hyZOhZy",
+ "colab": {}
+ },
+ "source": [
+ "Audio('noisy_top_level_generation_2.wav')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "t2-mEJaqZfuS",
+ "colab_type": "text"
+ },
+ "source": [
+ "If you don't like any of the options, return a few cells back to \"Generate 3 options for the start of the song\" and rerun from there."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o7CzSiv0MmFP",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Choose your favorite sample and request longer generation\n",
+ "\n",
+ "---\n",
+ "\n",
+ "(Repeat from here)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "j_XFtVi99CIY",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "my_choice=0"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Pgk3sHHBLYoq",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "zs[2]=zs[2][my_choice].repeat(hps.n_samples,1)\n",
+ "t.save(zs, 'zs-checkpoint2.t')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "W8Rd9xxm565S",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Set to True to load the previous checkpoint:\n",
+ "if False:\n",
+ "  zs=t.load('zs-checkpoint2.t')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "k12xjMgHkRGP",
+ "colab_type": "text"
+ },
+ "source": [
+ "Choose the length of the continuation. The 1B model can generate up to 17-second samples and the 5B up to 23 seconds, but you'll want to pick a shorter continuation length so the model can still look back at what you've generated already. Here we've chosen 4 seconds."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "h3_-0a07kHHG",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "continue_generation_in_seconds=4\n",
+ "tokens_to_sample = seconds_to_tokens(continue_generation_in_seconds, hps.sr, top_prior, chunk_size)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GpPG3Ifqk8ue",
+ "colab_type": "text"
+ },
+ "source": [
+ "The next step asks the top prior to generate more of the sample. It'll take up to a few minutes, depending on the sample length you request."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "YoHkeSTaEyLj",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n",
+ "x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ymhUqEdhleEi",
+ "colab_type": "text"
+ },
+ "source": [
+ "Now listen to the longer versions of the sample you selected, and again choose a favorite sample. If you don't like any, return to the cell where you can load the checkpoint and continue again from there.\n",
+ "\n",
+ "When the samples start getting long, you might not always want to listen from the start, so change the playback start time later on if you like."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "2H1LNLTa_R6a",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "playback_start_time_in_seconds = 0"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "r4SBGAmsnJtH",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "for i in range(hps.n_samples):\n",
+ "  librosa.output.write_wav(f'top_level_continuation_{i}.wav', x[i][playback_start_time_in_seconds*44100:], sr=44100)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "2WeyE5Qtnmeo",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "Audio('top_level_continuation_0.wav')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "BKtfEtcaazXE",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "Audio('top_level_continuation_1.wav')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "7yrlS0XwK2S0",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "Audio('top_level_continuation_2.wav')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-OJT704dvnGv",
+ "colab_type": "text"
+ },
+ "source": [
+ "To make a longer song, return to \"Choose your favorite sample\" and loop through that again."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RzCrkCZJvUcQ",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Upsample Co-Composition to Higher Audio Quality"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4MPgukwMmB0p",
+ "colab_type": "text"
+ },
+ "source": [
+ "Choose your favorite sample from your latest group of generations. (If you haven't already gone through the Co-Composition block, make sure to do that first so you have a generation to upsample.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "yv-pNNPHBQYC",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "choice = 0\n",
+ "select_best_sample = True # Set False if you want to upsample all your samples;\n",
+ "                          # upsampling sometimes yields subtly different results on multiple runs,\n",
+ "                          # so this way you can choose your favorite upsampling"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "v17cEAqyCgfo",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "if select_best_sample:\n",
+ "  zs[2]=zs[2][choice].repeat(zs[2].shape[0],1)\n",
+ "\n",
+ "t.save(zs, 'zs-top-level-final.t')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0YjK-Ac0tBfu",
+ "colab_type": "text"
+ },
+ "source": [
+ "Note: If you are using a CoLab hosted runtime on the free tier, you may want to download this zs-top-level-final.t file, and then restart an instance and load it in the next cell. The free tier will last a maximum of 12 hours, and the upsampling stage can take many hours, depending on how long a sample you have generated."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "qqlR9368s3jJ",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "if False:\n",
+ "  zs = t.load('zs-top-level-final.t')\n",
+ "\n",
+ "assert zs[2].shape[1]>=2048, f'Please first generate at least 2048 tokens at the top level, currently you have {zs[2].shape[1]}'\n",
+ "hps.sample_length = zs[2].shape[1]*top_prior.raw_to_tokens"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "jzHwF_iqgIWM",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Set this False if you are on a local machine that has enough memory (this allows you to do the\n",
+ "# lyrics alignment visualization). For a hosted runtime, we'll need to go ahead and delete the top_prior\n",
+ "# if you are using the 5b_lyrics model.\n",
+ "if True:\n",
+ "  del top_prior\n",
+ "  empty_cache()\n",
+ "  top_prior=None\n",
+ "\n",
+ "upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "q22Ier6YSkKS",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "sampling_kwargs = [dict(temp=0.99, fp16=True, max_batch_size=16, chunk_size=32),\n",
+ "                   dict(temp=0.99, fp16=True, max_batch_size=16, chunk_size=32),\n",
+ "                   None]\n",
+ "\n",
+ "if type(labels)==dict:\n",
+ "  labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers] + [labels]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "T1MCa9_jnjpf",
+ "colab_type": "text"
+ },
+ "source": [
+ "This next step upsamples 2 levels. The level_1 samples will be available after around one hour (depending on the length of your sample) and are saved under {hps.name}/level_1/item_0.wav, while the fully upsampled level_0 will likely take 4-12 hours. You can access the wav files below, or via the \"Files\" panel at the left of this CoLab.\n",
+ "\n",
+ "(Please note, if you are using this CoLab on Google's free tier, you may want to download intermediate steps as the connection will last for a maximum of 12 hours.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "NcNT5qIRMmHq",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "W2jTYLPBc29M",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "Audio(f'{hps.name}/level_0/item_0.wav')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ }
+ ]
+ }
train.py ADDED
@@ -0,0 +1,345 @@
+ """
+ Ability to train vq-vae and prior
+ First try for random inputs
+ Then from maestros
+ """
+ import sys
+ import fire
+ import warnings
+ import numpy as np
+ import torch as t
+ import jukebox.utils.dist_adapter as dist
+ from torch.nn.parallel import DistributedDataParallel
+
+ from jukebox.hparams import setup_hparams
+ from jukebox.make_models import make_vqvae, make_prior, restore_opt, save_checkpoint
+ from jukebox.utils.logger import init_logging
+ from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
+ from jukebox.utils.torch_utils import zero_grad, count_parameters
+ from jukebox.utils.dist_utils import print_once, allreduce, allgather
+ from jukebox.utils.ema import CPUEMA, FusedEMA, EMA
+ from jukebox.utils.fp16 import FP16FusedAdam, FusedAdam, LossScalar, clipped_grad_scale, backward
+ from jukebox.data.data_processor import DataProcessor
+
+ def prepare_aud(x, hps):
+     x = audio_postprocess(x.detach().contiguous(), hps)
+     return allgather(x)
+
+ def log_aud(logger, tag, x, hps):
+     logger.add_audios(tag, prepare_aud(x, hps), hps.sr, max_len=hps.max_len, max_log=hps.max_log)
+     logger.flush()
+
+ def log_labels(logger, labeller, tag, y, hps):
+     y = y.cpu().numpy()
+     txt = ''
+     for item in range(y.shape[0]):
+         description = labeller.describe_label(y[item])
+         artist, genre, lyrics = description['artist'], description['genre'], description['lyrics']
+         txt += f'{item} artist:{artist}, genre:{genre}, lyrics:{lyrics}\n'
+     logger.add_text(tag, txt)
+     logger.flush()
+
+ def get_ddp(model, hps):
+     rank = dist.get_rank()
+     local_rank = rank % 8
+     ddp = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False, bucket_cap_mb=hps.bucket)
+     return ddp
+
+ def get_ema(model, hps):
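+     # EMA decay defaults to 1 - (bs * ngpus / 8) / 1000, e.g. mu = 0.999 when
+     # bs * ngpus = 8; set hps.mu to override.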
+     mu = hps.mu or (1. - (hps.bs * hps.ngpus/8.)/1000)
+     ema = None
+     if hps.ema and hps.train:
+         if hps.cpu_ema:
+             if dist.get_rank() == 0:
+                 print("Using CPU EMA")
+             ema = CPUEMA(model.parameters(), mu=mu, freq=hps.cpu_ema_freq)
+         elif hps.ema_fused:
+             ema = FusedEMA(model.parameters(), mu=mu)
+         else:
+             ema = EMA(model.parameters(), mu=mu)
+     return ema
+
+ def get_lr_scheduler(opt, hps):
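+     # LambdaLR multiplies the base lr by lr_lambda(step): linear warmup over
+     # lr_warmup steps, then either linear decay to zero or staircase decay by
+     # lr_gamma every lr_decay steps.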
+     def lr_lambda(step):
+         if hps.lr_use_linear_decay:
+             lr_scale = hps.lr_scale * min(1.0, step / hps.lr_warmup)
+             decay = max(0.0, 1.0 - max(0.0, step - hps.lr_start_linear_decay) / hps.lr_decay)
+             if decay == 0.0:
+                 if dist.get_rank() == 0:
+                     print("Reached end of training")
+             return lr_scale * decay
+         else:
+             return hps.lr_scale * (hps.lr_gamma ** (step // hps.lr_decay)) * min(1.0, step / hps.lr_warmup)
+
+     shd = t.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
+     return shd
+
+ def get_optimizer(model, hps):
+     # Optimizer
+     betas = (hps.beta1, hps.beta2)
+     if hps.fp16_opt:
+         opt = FP16FusedAdam(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay, betas=betas, eps=hps.eps)
+     else:
+         opt = FusedAdam(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay, betas=betas, eps=hps.eps)
+
+     # lr scheduler
+     shd = get_lr_scheduler(opt, hps)
+
+     restore_path = hps.restore_prior if hps.prior else hps.restore_vqvae
+     restore_opt(opt, shd, restore_path)
+
+     # fp16 dynamic loss scaler
+     scalar = None
+     if hps.fp16:
+         rank = dist.get_rank()
+         local_rank = rank % 8
+         scalar = LossScalar(hps.fp16_loss_scale, scale_factor=2 ** (1./hps.fp16_scale_window))
+         if local_rank == 0: print(scalar.__dict__)
+
+     zero_grad(model)
+     return opt, shd, scalar
+
+ def log_inputs(orig_model, logger, x_in, y, x_out, hps, tag="train"):
+     print(f"Logging {tag} inputs/outputs")
+     log_aud(logger, f'{tag}_x_in', x_in, hps)
+     log_aud(logger, f'{tag}_x_out', x_out, hps)
+     bs = x_in.shape[0]
+     if hps.prior:
+         if hps.labels:
+             log_labels(logger, orig_model.labeller, f'{tag}_y_in', allgather(y.cuda()), hps)
+     else:
+         zs_in = orig_model.encode(x_in, start_level=0, bs_chunks=bs)
+         x_ds = [orig_model.decode(zs_in[level:], start_level=level, bs_chunks=bs) for level in range(0, hps.levels)]
+         for i in range(len(x_ds)):
+             log_aud(logger, f'{tag}_x_ds_start_{i}', x_ds[i], hps)
+     logger.flush()
+
+ def sample_prior(orig_model, ema, logger, x_in, y, hps):
+     if ema is not None: ema.swap()
+     orig_model.eval()
+
+     x_in = x_in[:hps.bs_sample]
+     bs = x_in.shape[0]
+     zs_in = orig_model.encode(x_in, start_level=0, bs_chunks=bs)
+     assert len(zs_in) == hps.levels
+     x_ds = [orig_model.decode(zs_in[level:], start_level=level, bs_chunks=bs) for level in range(0, hps.levels)]
+
+     if not hps.labels:
+         y = None
+     elif hps.level == (hps.levels - 1):
+         # Topmost level labels in order
+         y = y[:hps.bs_sample] # t.ones((hps.bs_sample, 1), device=y.device, dtype=t.long) * dist.get_rank()
+     else:
+         # Other levels keep labels to match x_cond
+         y = y[:hps.bs_sample]
+
+     # Temp 1.0
+     _, *z_conds = orig_model.encode(x_in, bs_chunks=bs)
+     z = orig_model.sample(hps.bs_sample, z_conds=z_conds, y=y, fp16=False, temp=1.0)
+     x_sample = orig_model.decode([z, *z_conds], bs_chunks=bs)
+
+     log_aud(logger, 'sample_x_T1', x_sample, hps)
+     if hps.prior and hps.labels:
+         log_labels(logger, orig_model.labeller, 'sample_x_T1', allgather(y.cuda()), hps)
+
+     # Recons
+     for i in range(len(x_ds)):
+         log_aud(logger, f'x_ds_start_{i}', x_ds[i], hps)
+     orig_model.train()
+     if ema is not None: ema.swap()
+     logger.flush()
+
+ def evaluate(model, orig_model, logger, metrics, data_processor, hps):
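+     # One pass over the test split. "model" is the (possibly DDP-wrapped) module
+     # used for the forward pass; "orig_model" is the bare module used for
+     # decoding and label logging.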
+     model.eval()
+     orig_model.eval()
+     if hps.prior:
+         _print_keys = dict(l="loss", bpd="bpd")
+     else:
+         _print_keys = dict(l="loss", rl="recons_loss", sl="spectral_loss")
+
+     with t.no_grad():
+         for i, x in logger.get_range(data_processor.test_loader):
+             if isinstance(x, (tuple, list)):
+                 x, y = x
+             else:
+                 y = None
+
+             x = x.to('cuda', non_blocking=True)
+             if y is not None:
+                 y = y.to('cuda', non_blocking=True)
+
+             x_in = x = audio_preprocess(x, hps)
+             log_input_output = (i==0)
+
+             if hps.prior:
+                 forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
+             else:
+                 forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)
+
+             x_out, loss, _metrics = model(x, **forw_kwargs)
+
+             # Logging
+             for key, val in _metrics.items():
+                 _metrics[key] = val.item()
+             _metrics["loss"] = loss = loss.item() # Make sure to call to free graph
+
+             # Average and log
+             for key, val in _metrics.items():
+                 _metrics[key] = metrics.update(f"test_{key}", val, x.shape[0])
+
+             with t.no_grad():
+                 if log_input_output:
+                     log_inputs(orig_model, logger, x_in, y, x_out, hps)
+
+             logger.set_postfix(**{print_key:_metrics[key] for print_key, key in _print_keys.items()})
+
+     for key, val in _metrics.items():
+         logger.add_scalar(f"test_{key}", metrics.avg(f"test_{key}"))
+
+     logger.close_range()
+     return {key: metrics.avg(f"test_{key}") for key in _metrics.keys()}
+
+ def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
+     model.train()
+     orig_model.train()
+     if hps.prior:
+         _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
+     else:
+         _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")
+
+     for i, x in logger.get_range(data_processor.train_loader):
+         if isinstance(x, (tuple, list)):
+             x, y = x
+         else:
+             y = None
+
+         x = x.to('cuda', non_blocking=True)
+         if y is not None:
+             y = y.to('cuda', non_blocking=True)
+
+         x_in = x = audio_preprocess(x, hps)
+         log_input_output = (logger.iters % hps.save_iters == 0)
+
+         if hps.prior:
+             forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
+         else:
+             forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)
+
+         # Forward
+         x_out, loss, _metrics = model(x, **forw_kwargs)
+
+         # Backward
+         loss, scale, grad_norm, overflow_loss, overflow_grad = backward(loss=loss, params=list(model.parameters()),
+                                                                         scalar=scalar, fp16=hps.fp16, logger=logger)
+         # Skip step if overflow
+         grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
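+         # If fp16 loss/grad overflowed, or the max grad norm across ranks exceeds
+         # the ignore threshold, drop this batch's gradients and skip the step.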
+         if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
+             zero_grad(orig_model)
+             continue
+
+         # Step opt. Divide by scale to include clipping and fp16 scaling
+         logger.step()
+         opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
+         zero_grad(orig_model)
+         lr = hps.lr if shd is None else shd.get_lr()[0]
+         if shd is not None: shd.step()
+         if ema is not None: ema.step()
+         next_lr = hps.lr if shd is None else shd.get_lr()[0]
+         finished_training = (next_lr == 0.0)
+
+         # Logging
+         for key, val in _metrics.items():
+             _metrics[key] = val.item()
+         _metrics["loss"] = loss = loss.item() * hps.iters_before_update # Make sure to call to free graph
+         _metrics["gn"] = grad_norm
+         _metrics["lr"] = lr
+         _metrics["lg_loss_scale"] = np.log2(scale)
+
+         # Average and log
+         for key, val in _metrics.items():
+             _metrics[key] = metrics.update(key, val, x.shape[0])
+             if logger.iters % hps.log_steps == 0:
+                 logger.add_scalar(key, _metrics[key])
+
+         # Save checkpoint
+         with t.no_grad():
+             if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
+                 if ema is not None: ema.swap()
+                 orig_model.eval()
+                 name = 'latest' if hps.prior else f'step_{logger.iters}'
+                 if dist.get_rank() % 8 == 0:
+                     save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
+                 orig_model.train()
+                 if ema is not None: ema.swap()
+
+         # Sample
+         with t.no_grad():
+             if (logger.iters % 12000) in list(range(1, 1 + hps.iters_before_update)) or finished_training:
+                 if hps.prior:
+                     sample_prior(orig_model, ema, logger, x_in, y, hps)
+
+         # Input/Output
+         with t.no_grad():
+             if log_input_output:
+                 log_inputs(orig_model, logger, x_in, y, x_out, hps)
+
+         logger.set_postfix(**{print_key:_metrics[key] for print_key, key in _print_keys.items()})
+         if finished_training:
+             dist.barrier()
+             exit()
+     logger.close_range()
+     return {key: metrics.avg(key) for key in _metrics.keys()}
+
+ def run(hps="teeny", port=29500, **kwargs):
+     from jukebox.utils.dist_utils import setup_dist_from_mpi
+     rank, local_rank, device = setup_dist_from_mpi(port=port)
+     hps = setup_hparams(hps, kwargs)
+     hps.ngpus = dist.get_world_size()
+     hps.argv = " ".join(sys.argv)
+     hps.bs_sample = hps.nworkers = hps.bs
+
+     # Setup dataset
+     data_processor = DataProcessor(hps)
+
+     # Setup models
+     vqvae = make_vqvae(hps, device)
+     print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
+     if hps.prior:
+         prior = make_prior(hps, vqvae, device)
+         print_once(f"Parameters Prior:{count_parameters(prior)}")
+         model = prior
+     else:
+         model = vqvae
+
+     # Setup opt, ema and distributed_model.
+     opt, shd, scalar = get_optimizer(model, hps)
+     ema = get_ema(model, hps)
+     distributed_model = get_ddp(model, hps)
+
+     logger, metrics = init_logging(hps, local_rank, rank)
+     logger.iters = model.step
+
+     # Run training, eval, sample
+     for epoch in range(hps.curr_epoch, hps.epochs):
+         metrics.reset()
+         data_processor.set_epoch(epoch)
+         if hps.train:
+             train_metrics = train(distributed_model, model, opt, shd, scalar, ema, logger, metrics, data_processor, hps)
+             train_metrics['epoch'] = epoch
+             if rank == 0:
+                 print('Train', ' '.join([f'{key}: {val:0.4f}' for key,val in train_metrics.items()]))
+             dist.barrier()
+
+         if hps.test:
+             if ema: ema.swap()
+             test_metrics = evaluate(distributed_model, model, logger, metrics, data_processor, hps)
+             test_metrics['epoch'] = epoch
+             if rank == 0:
+                 print('Ema', ' '.join([f'{key}: {val:0.4f}' for key,val in test_metrics.items()]))
+             dist.barrier()
+             if ema: ema.swap()
+         dist.barrier()
+
+ if __name__ == '__main__':
+     fire.Fire(run)
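+
+ # Example invocation (a sketch only; preset and flag names are defined in
+ # jukebox.hparams and may differ in your checkout):
+ #   mpiexec -n 8 python jukebox/train.py --hps=small_vqvae --name=small_vqvae \
+ #     --sample_length=262144 --bs=4 --audio_files_dir=<dir> --labels=False --train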