teticio committed on
Commit
8bf6ce9
1 Parent(s): 1bb3e27

added training notebook for colab

Browse files
notebooks/train_model.ipynb ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "62c5865f",
6
+ "metadata": {
7
+ "id": "62c5865f"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/train_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "6c7800a6",
17
+ "metadata": {
18
+ "colab": {
19
+ "base_uri": "https://localhost:8080/"
20
+ },
21
+ "id": "6c7800a6",
22
+ "outputId": "ed18f4a9-ccea-4d7c-c82b-1749f1041f6c"
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "try:\n",
27
+ " # are we running on Google Colab?\n",
28
+ " import google.colab\n",
29
+ " !git clone -q https://github.com/teticio/audio-diffusion.git\n",
30
+ " %cd audio-diffusion\n",
31
+ " !pip install -q -r requirements.txt .\n",
32
+ "except:\n",
33
+ " pass"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "id": "c2fc0e7a",
40
+ "metadata": {
41
+ "id": "c2fc0e7a"
42
+ },
43
+ "outputs": [],
44
+ "source": [
45
+ "from IPython.display import Audio\n",
46
+ "from audiodiffusion import AudioDiffusion"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "markdown",
51
+ "id": "MqlpL75_mDVv",
52
+ "metadata": {
53
+ "id": "MqlpL75_mDVv"
54
+ },
55
+ "source": [
56
+ "### Upload / specify audio files to train on\n",
57
+ "Provide some MP3 or WAV files that will be split into samples and converted to Mel spectrograms. For a resolution of 256, the samples will be about 5 seconds long."
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "id": "jg1zAHVsmCBG",
64
+ "metadata": {
65
+ "colab": {
66
+ "base_uri": "https://localhost:8080/",
67
+ "height": 73
68
+ },
69
+ "id": "jg1zAHVsmCBG",
70
+ "outputId": "414244c9-02b6-4ccf-cbfd-83f9022a0fc1"
71
+ },
72
+ "outputs": [],
73
+ "source": [
74
+ "try:\n",
75
+ " # are we running on Google Colab?\n",
76
+ " from google.colab import files\n",
77
+ " input_dir = '.'\n",
78
+ " files.upload();\n",
79
+ "except:\n",
80
+ " input_dir = \"/home/teticio/Music/liked\""
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "id": "10v0RCSUu75P",
86
+ "metadata": {
87
+ "id": "10v0RCSUu75P"
88
+ },
89
+ "source": [
90
+ "### Prepare dataset"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "id": "NJNeEU6ftaTM",
97
+ "metadata": {
98
+ "colab": {
99
+ "base_uri": "https://localhost:8080/"
100
+ },
101
+ "id": "NJNeEU6ftaTM",
102
+ "outputId": "6c5bed15-c821-4def-eb90-3ab1a17b3c3d"
103
+ },
104
+ "outputs": [],
105
+ "source": [
106
+ "!python scripts/audio_to_images.py \\\n",
107
+ " --resolution 256,256 \\\n",
108
+ " --input_dir {input_dir} \\\n",
109
+ " --output_dir data"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "markdown",
114
+ "id": "5mGeXyJFvQCO",
115
+ "metadata": {
116
+ "id": "5mGeXyJFvQCO"
117
+ },
118
+ "source": [
119
+ "### Train model\n",
120
+ "The DDIM scheduler generates samples much faster."
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "id": "JGnlePbLvTOH",
127
+ "metadata": {
128
+ "colab": {
129
+ "base_uri": "https://localhost:8080/"
130
+ },
131
+ "id": "JGnlePbLvTOH",
132
+ "outputId": "69b6f53e-25a3-4c59-e205-2eab42889cd8"
133
+ },
134
+ "outputs": [],
135
+ "source": [
136
+ "!python scripts/train_unconditional.py \\\n",
137
+ " --dataset_name data \\\n",
138
+ " --output_dir model \\\n",
139
+ " --num_epochs 10 \\\n",
140
+ " --train_batch_size 2 \\\n",
141
+ " --eval_batch_size 2 \\\n",
142
+ " --gradient_accumulation_steps 8 \\\n",
143
+ " --save_images_epochs 100 \\\n",
144
+ " --save_model_epochs 1 \\\n",
145
+ " --scheduler ddim"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "markdown",
150
+ "id": "nTMAYEtMxtt0",
151
+ "metadata": {
152
+ "id": "nTMAYEtMxtt0"
153
+ },
154
+ "source": [
155
+ "### Generate samples with model"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "id": "b294a94a",
162
+ "metadata": {
163
+ "id": "b294a94a"
164
+ },
165
+ "outputs": [],
166
+ "source": [
167
+ "audio_diffusion = AudioDiffusion('model')"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "id": "k2bKq3aqyAIM",
174
+ "metadata": {
175
+ "colab": {
176
+ "base_uri": "https://localhost:8080/",
177
+ "height": 363,
178
+ "referenced_widgets": [
179
+ "474d4db933d54e0497da4076a7fe135b",
180
+ "a849a3a1b46947db830a6a087411ec68",
181
+ "378f819239274ac88d913714bc27bf06",
182
+ "cc3b33e508744206955b26a417fbbdec",
183
+ "6015e5a9e6774e9abf7db273bca57363",
184
+ "629c21c68d22447185bb961e22bce4a6",
185
+ "2d5abefbc2ed4b72aed8c4f8ddc7a00c",
186
+ "11d1dbae00764a1c9dcc899c0b0f67dc",
187
+ "acdb5ddc7bda411a948689787b18b21e",
188
+ "9c4955f9d0f443a7b28ed827c5cdb37f",
189
+ "f9a1a976d82148f8961e80c357bc2764"
190
+ ]
191
+ },
192
+ "id": "k2bKq3aqyAIM",
193
+ "outputId": "d48238fe-ae36-4736-e67b-b69e3729304a"
194
+ },
195
+ "outputs": [],
196
+ "source": [
197
+ "image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio()\n",
198
+ "display(image)\n",
199
+ "display(Audio(audio, rate=sample_rate))"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "id": "K2qAIJzg2DNK",
206
+ "metadata": {
207
+ "id": "K2qAIJzg2DNK"
208
+ },
209
+ "outputs": [],
210
+ "source": []
211
+ }
212
+ ],
213
+ "metadata": {
214
+ "accelerator": "GPU",
215
+ "colab": {
216
+ "collapsed_sections": [],
217
+ "provenance": []
218
+ },
219
+ "gpuClass": "standard",
220
+ "kernelspec": {
221
+ "display_name": "huggingface",
222
+ "language": "python",
223
+ "name": "huggingface"
224
+ },
225
+ "language_info": {
226
+ "codemirror_mode": {
227
+ "name": "ipython",
228
+ "version": 3
229
+ },
230
+ "file_extension": ".py",
231
+ "mimetype": "text/x-python",
232
+ "name": "python",
233
+ "nbconvert_exporter": "python",
234
+ "pygments_lexer": "ipython3",
235
+ "version": "3.10.6"
236
+ },
237
+ "toc": {
238
+ "base_numbering": 1,
239
+ "nav_menu": {},
240
+ "number_sections": true,
241
+ "sideBar": true,
242
+ "skip_h1_title": false,
243
+ "title_cell": "Table of Contents",
244
+ "title_sidebar": "Contents",
245
+ "toc_cell": false,
246
+ "toc_position": {},
247
+ "toc_section_display": true,
248
+ "toc_window_display": false
249
+ },
250
+ "widgets": {
251
+ "application/vnd.jupyter.widget-state+json": {
252
+ "11d1dbae00764a1c9dcc899c0b0f67dc": {
253
+ "model_module": "@jupyter-widgets/base",
254
+ "model_module_version": "1.2.0",
255
+ "model_name": "LayoutModel",
256
+ "state": {
257
+ "_model_module": "@jupyter-widgets/base",
258
+ "_model_module_version": "1.2.0",
259
+ "_model_name": "LayoutModel",
260
+ "_view_count": null,
261
+ "_view_module": "@jupyter-widgets/base",
262
+ "_view_module_version": "1.2.0",
263
+ "_view_name": "LayoutView",
264
+ "align_content": null,
265
+ "align_items": null,
266
+ "align_self": null,
267
+ "border": null,
268
+ "bottom": null,
269
+ "display": null,
270
+ "flex": null,
271
+ "flex_flow": null,
272
+ "grid_area": null,
273
+ "grid_auto_columns": null,
274
+ "grid_auto_flow": null,
275
+ "grid_auto_rows": null,
276
+ "grid_column": null,
277
+ "grid_gap": null,
278
+ "grid_row": null,
279
+ "grid_template_areas": null,
280
+ "grid_template_columns": null,
281
+ "grid_template_rows": null,
282
+ "height": null,
283
+ "justify_content": null,
284
+ "justify_items": null,
285
+ "left": null,
286
+ "margin": null,
287
+ "max_height": null,
288
+ "max_width": null,
289
+ "min_height": null,
290
+ "min_width": null,
291
+ "object_fit": null,
292
+ "object_position": null,
293
+ "order": null,
294
+ "overflow": null,
295
+ "overflow_x": null,
296
+ "overflow_y": null,
297
+ "padding": null,
298
+ "right": null,
299
+ "top": null,
300
+ "visibility": null,
301
+ "width": null
302
+ }
303
+ },
304
+ "2d5abefbc2ed4b72aed8c4f8ddc7a00c": {
305
+ "model_module": "@jupyter-widgets/controls",
306
+ "model_module_version": "1.5.0",
307
+ "model_name": "DescriptionStyleModel",
308
+ "state": {
309
+ "_model_module": "@jupyter-widgets/controls",
310
+ "_model_module_version": "1.5.0",
311
+ "_model_name": "DescriptionStyleModel",
312
+ "_view_count": null,
313
+ "_view_module": "@jupyter-widgets/base",
314
+ "_view_module_version": "1.2.0",
315
+ "_view_name": "StyleView",
316
+ "description_width": ""
317
+ }
318
+ },
319
+ "378f819239274ac88d913714bc27bf06": {
320
+ "model_module": "@jupyter-widgets/controls",
321
+ "model_module_version": "1.5.0",
322
+ "model_name": "FloatProgressModel",
323
+ "state": {
324
+ "_dom_classes": [],
325
+ "_model_module": "@jupyter-widgets/controls",
326
+ "_model_module_version": "1.5.0",
327
+ "_model_name": "FloatProgressModel",
328
+ "_view_count": null,
329
+ "_view_module": "@jupyter-widgets/controls",
330
+ "_view_module_version": "1.5.0",
331
+ "_view_name": "ProgressView",
332
+ "bar_style": "success",
333
+ "description": "",
334
+ "description_tooltip": null,
335
+ "layout": "IPY_MODEL_11d1dbae00764a1c9dcc899c0b0f67dc",
336
+ "max": 50,
337
+ "min": 0,
338
+ "orientation": "horizontal",
339
+ "style": "IPY_MODEL_acdb5ddc7bda411a948689787b18b21e",
340
+ "value": 50
341
+ }
342
+ },
343
+ "474d4db933d54e0497da4076a7fe135b": {
344
+ "model_module": "@jupyter-widgets/controls",
345
+ "model_module_version": "1.5.0",
346
+ "model_name": "HBoxModel",
347
+ "state": {
348
+ "_dom_classes": [],
349
+ "_model_module": "@jupyter-widgets/controls",
350
+ "_model_module_version": "1.5.0",
351
+ "_model_name": "HBoxModel",
352
+ "_view_count": null,
353
+ "_view_module": "@jupyter-widgets/controls",
354
+ "_view_module_version": "1.5.0",
355
+ "_view_name": "HBoxView",
356
+ "box_style": "",
357
+ "children": [
358
+ "IPY_MODEL_a849a3a1b46947db830a6a087411ec68",
359
+ "IPY_MODEL_378f819239274ac88d913714bc27bf06",
360
+ "IPY_MODEL_cc3b33e508744206955b26a417fbbdec"
361
+ ],
362
+ "layout": "IPY_MODEL_6015e5a9e6774e9abf7db273bca57363"
363
+ }
364
+ },
365
+ "6015e5a9e6774e9abf7db273bca57363": {
366
+ "model_module": "@jupyter-widgets/base",
367
+ "model_module_version": "1.2.0",
368
+ "model_name": "LayoutModel",
369
+ "state": {
370
+ "_model_module": "@jupyter-widgets/base",
371
+ "_model_module_version": "1.2.0",
372
+ "_model_name": "LayoutModel",
373
+ "_view_count": null,
374
+ "_view_module": "@jupyter-widgets/base",
375
+ "_view_module_version": "1.2.0",
376
+ "_view_name": "LayoutView",
377
+ "align_content": null,
378
+ "align_items": null,
379
+ "align_self": null,
380
+ "border": null,
381
+ "bottom": null,
382
+ "display": null,
383
+ "flex": null,
384
+ "flex_flow": null,
385
+ "grid_area": null,
386
+ "grid_auto_columns": null,
387
+ "grid_auto_flow": null,
388
+ "grid_auto_rows": null,
389
+ "grid_column": null,
390
+ "grid_gap": null,
391
+ "grid_row": null,
392
+ "grid_template_areas": null,
393
+ "grid_template_columns": null,
394
+ "grid_template_rows": null,
395
+ "height": null,
396
+ "justify_content": null,
397
+ "justify_items": null,
398
+ "left": null,
399
+ "margin": null,
400
+ "max_height": null,
401
+ "max_width": null,
402
+ "min_height": null,
403
+ "min_width": null,
404
+ "object_fit": null,
405
+ "object_position": null,
406
+ "order": null,
407
+ "overflow": null,
408
+ "overflow_x": null,
409
+ "overflow_y": null,
410
+ "padding": null,
411
+ "right": null,
412
+ "top": null,
413
+ "visibility": null,
414
+ "width": null
415
+ }
416
+ },
417
+ "629c21c68d22447185bb961e22bce4a6": {
418
+ "model_module": "@jupyter-widgets/base",
419
+ "model_module_version": "1.2.0",
420
+ "model_name": "LayoutModel",
421
+ "state": {
422
+ "_model_module": "@jupyter-widgets/base",
423
+ "_model_module_version": "1.2.0",
424
+ "_model_name": "LayoutModel",
425
+ "_view_count": null,
426
+ "_view_module": "@jupyter-widgets/base",
427
+ "_view_module_version": "1.2.0",
428
+ "_view_name": "LayoutView",
429
+ "align_content": null,
430
+ "align_items": null,
431
+ "align_self": null,
432
+ "border": null,
433
+ "bottom": null,
434
+ "display": null,
435
+ "flex": null,
436
+ "flex_flow": null,
437
+ "grid_area": null,
438
+ "grid_auto_columns": null,
439
+ "grid_auto_flow": null,
440
+ "grid_auto_rows": null,
441
+ "grid_column": null,
442
+ "grid_gap": null,
443
+ "grid_row": null,
444
+ "grid_template_areas": null,
445
+ "grid_template_columns": null,
446
+ "grid_template_rows": null,
447
+ "height": null,
448
+ "justify_content": null,
449
+ "justify_items": null,
450
+ "left": null,
451
+ "margin": null,
452
+ "max_height": null,
453
+ "max_width": null,
454
+ "min_height": null,
455
+ "min_width": null,
456
+ "object_fit": null,
457
+ "object_position": null,
458
+ "order": null,
459
+ "overflow": null,
460
+ "overflow_x": null,
461
+ "overflow_y": null,
462
+ "padding": null,
463
+ "right": null,
464
+ "top": null,
465
+ "visibility": null,
466
+ "width": null
467
+ }
468
+ },
469
+ "9c4955f9d0f443a7b28ed827c5cdb37f": {
470
+ "model_module": "@jupyter-widgets/base",
471
+ "model_module_version": "1.2.0",
472
+ "model_name": "LayoutModel",
473
+ "state": {
474
+ "_model_module": "@jupyter-widgets/base",
475
+ "_model_module_version": "1.2.0",
476
+ "_model_name": "LayoutModel",
477
+ "_view_count": null,
478
+ "_view_module": "@jupyter-widgets/base",
479
+ "_view_module_version": "1.2.0",
480
+ "_view_name": "LayoutView",
481
+ "align_content": null,
482
+ "align_items": null,
483
+ "align_self": null,
484
+ "border": null,
485
+ "bottom": null,
486
+ "display": null,
487
+ "flex": null,
488
+ "flex_flow": null,
489
+ "grid_area": null,
490
+ "grid_auto_columns": null,
491
+ "grid_auto_flow": null,
492
+ "grid_auto_rows": null,
493
+ "grid_column": null,
494
+ "grid_gap": null,
495
+ "grid_row": null,
496
+ "grid_template_areas": null,
497
+ "grid_template_columns": null,
498
+ "grid_template_rows": null,
499
+ "height": null,
500
+ "justify_content": null,
501
+ "justify_items": null,
502
+ "left": null,
503
+ "margin": null,
504
+ "max_height": null,
505
+ "max_width": null,
506
+ "min_height": null,
507
+ "min_width": null,
508
+ "object_fit": null,
509
+ "object_position": null,
510
+ "order": null,
511
+ "overflow": null,
512
+ "overflow_x": null,
513
+ "overflow_y": null,
514
+ "padding": null,
515
+ "right": null,
516
+ "top": null,
517
+ "visibility": null,
518
+ "width": null
519
+ }
520
+ },
521
+ "a849a3a1b46947db830a6a087411ec68": {
522
+ "model_module": "@jupyter-widgets/controls",
523
+ "model_module_version": "1.5.0",
524
+ "model_name": "HTMLModel",
525
+ "state": {
526
+ "_dom_classes": [],
527
+ "_model_module": "@jupyter-widgets/controls",
528
+ "_model_module_version": "1.5.0",
529
+ "_model_name": "HTMLModel",
530
+ "_view_count": null,
531
+ "_view_module": "@jupyter-widgets/controls",
532
+ "_view_module_version": "1.5.0",
533
+ "_view_name": "HTMLView",
534
+ "description": "",
535
+ "description_tooltip": null,
536
+ "layout": "IPY_MODEL_629c21c68d22447185bb961e22bce4a6",
537
+ "placeholder": "​",
538
+ "style": "IPY_MODEL_2d5abefbc2ed4b72aed8c4f8ddc7a00c",
539
+ "value": "100%"
540
+ }
541
+ },
542
+ "acdb5ddc7bda411a948689787b18b21e": {
543
+ "model_module": "@jupyter-widgets/controls",
544
+ "model_module_version": "1.5.0",
545
+ "model_name": "ProgressStyleModel",
546
+ "state": {
547
+ "_model_module": "@jupyter-widgets/controls",
548
+ "_model_module_version": "1.5.0",
549
+ "_model_name": "ProgressStyleModel",
550
+ "_view_count": null,
551
+ "_view_module": "@jupyter-widgets/base",
552
+ "_view_module_version": "1.2.0",
553
+ "_view_name": "StyleView",
554
+ "bar_color": null,
555
+ "description_width": ""
556
+ }
557
+ },
558
+ "cc3b33e508744206955b26a417fbbdec": {
559
+ "model_module": "@jupyter-widgets/controls",
560
+ "model_module_version": "1.5.0",
561
+ "model_name": "HTMLModel",
562
+ "state": {
563
+ "_dom_classes": [],
564
+ "_model_module": "@jupyter-widgets/controls",
565
+ "_model_module_version": "1.5.0",
566
+ "_model_name": "HTMLModel",
567
+ "_view_count": null,
568
+ "_view_module": "@jupyter-widgets/controls",
569
+ "_view_module_version": "1.5.0",
570
+ "_view_name": "HTMLView",
571
+ "description": "",
572
+ "description_tooltip": null,
573
+ "layout": "IPY_MODEL_9c4955f9d0f443a7b28ed827c5cdb37f",
574
+ "placeholder": "​",
575
+ "style": "IPY_MODEL_f9a1a976d82148f8961e80c357bc2764",
576
+ "value": " 50/50 [00:07&lt;00:00, 8.13it/s]"
577
+ }
578
+ },
579
+ "f9a1a976d82148f8961e80c357bc2764": {
580
+ "model_module": "@jupyter-widgets/controls",
581
+ "model_module_version": "1.5.0",
582
+ "model_name": "DescriptionStyleModel",
583
+ "state": {
584
+ "_model_module": "@jupyter-widgets/controls",
585
+ "_model_module_version": "1.5.0",
586
+ "_model_name": "DescriptionStyleModel",
587
+ "_view_count": null,
588
+ "_view_module": "@jupyter-widgets/base",
589
+ "_view_module_version": "1.2.0",
590
+ "_view_name": "StyleView",
591
+ "description_width": ""
592
+ }
593
+ }
594
+ }
595
+ }
596
+ },
597
+ "nbformat": 4,
598
+ "nbformat_minor": 5
599
+ }
scripts/train_unconditional.py CHANGED
@@ -277,9 +277,7 @@ def main(args):
277
  else:
278
  pipeline.save_pretrained(output_dir)
279
 
280
- if (
281
- epoch + 1
282
- ) % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
283
  generator = torch.manual_seed(42)
284
  # run pipeline in inference (sample random noise and denoise)
285
  images, (sample_rate, audios) = pipeline(
 
277
  else:
278
  pipeline.save_pretrained(output_dir)
279
 
280
+ if (epoch + 1) % args.save_images_epochs == 0:
 
 
281
  generator = torch.manual_seed(42)
282
  # run pipeline in inference (sample random noise and denoise)
283
  images, (sample_rate, audios) = pipeline(