teticio committed on
Commit
d021b1c
1 Parent(s): d73aa64
.gitignore CHANGED
@@ -3,3 +3,4 @@ __pycache__
3
  .ipynb_checkpoints
4
  data*
5
  ddpm-ema-audio-*
 
 
3
  .ipynb_checkpoints
4
  data*
5
  ddpm-ema-audio-*
6
+ flagged
README.md CHANGED
@@ -22,7 +22,7 @@ Audio can be represented as images by transforming to a [mel spectrogram](https:
22
 
23
  A DDPM model is trained on a set of mel spectrograms that have been generated from a directory of audio files. It is then used to synthesize similar mel spectrograms, which are then converted back into audio. See the `test-model.ipynb` notebook for an example.
24
 
25
- You can play around with the model I trained on about 500 songs from my Spotify "liked" playlist [here](https://huggingface.co/spaces/teticio/audio-diffusion)
26
 
27
  ## Generate Mel spectrogram dataset from directory of audio files
28
  #### Training can be run with Mel spectrograms of resolution 64x64 on a single commercial grade GPU (e.g. RTX 2080 Ti). The `hop_length` should be set to 1024 for better results.
 
22
 
23
  A DDPM model is trained on a set of mel spectrograms that have been generated from a directory of audio files. It is then used to synthesize similar mel spectrograms, which are then converted back into audio. See the `test-model.ipynb` notebook for an example.
24
 
25
+ You can play around with the model I trained on about 500 songs from my Spotify "liked" playlist on [Google Colab](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/test-model.ipynb) or [Hugging Face spaces](https://huggingface.co/spaces/teticio/audio-diffusion).
26
 
27
  ## Generate Mel spectrogram dataset from directory of audio files
28
  #### Training can be run with Mel spectrograms of resolution 64x64 on a single commercial grade GPU (e.g. RTX 2080 Ti). The `hop_length` should be set to 1024 for better results.
app.py CHANGED
@@ -28,7 +28,8 @@ if __name__ == "__main__":
28
  demo = gr.Interface(
29
  fn=generate_spectrogram_and_audio,
30
  title="Audio Diffusion",
31
- description=f"Generate audio using Huggingface diffusers",
 
32
  inputs=[],
33
  outputs=[
34
  gr.Image(label="Mel spectrogram", image_mode="L"),
 
28
  demo = gr.Interface(
29
  fn=generate_spectrogram_and_audio,
30
  title="Audio Diffusion",
31
+ description=f"Generate audio using Huggingface diffusers.\
32
+ This takes about 20 minutes without a GPU, so why not make yourself a cup of tea in the meantime?",
33
  inputs=[],
34
  outputs=[
35
  gr.Image(label="Mel spectrogram", image_mode="L"),
notebooks/test-model.ipynb CHANGED
@@ -1,5 +1,30 @@
1
  {
2
  "cells": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  {
4
  "cell_type": "code",
5
  "execution_count": 1,
@@ -256,22 +281,29 @@
256
  "\n",
257
  "\n",
258
  "args = AttributeDict({\n",
259
- " \"hub_model_id\": \"teticio/audio-diffusion-256\",\n",
260
- " \"output_dir\": \"../ddpm-ema-audio-256-repo\",\n",
261
- " \"local_rank\": -1,\n",
262
- " \"hub_token\": \"hf_\",\n",
263
- " \"hub_private_repo\": False,\n",
264
- " \"overwrite_output_dir\": False\n",
 
 
 
 
 
 
265
  "})\n",
266
  "\n",
267
  "repo = init_git_repo(args, at_init=True)\n",
 
268
  "push_to_hub(args, ddpm, repo)"
269
  ]
270
  },
271
  {
272
  "cell_type": "code",
273
  "execution_count": null,
274
- "id": "11b3963b",
275
  "metadata": {},
276
  "outputs": [],
277
  "source": []
 
1
  {
2
  "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "0fd939b0",
6
+ "metadata": {},
7
+ "source": [
8
+ "<a href=\"https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/test-model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 5,
14
+ "id": "6c7800a6",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "try:\n",
19
+ " # are we running on Google Colab?\n",
20
+ " import google.colab\n",
21
+ " !git clone -q https://github.com/teticio/audio-diffusion.git\n",
22
+ " %cd audio-diffusion\n",
23
+ " !pip install -q -r requirements.txt\n",
24
+ "except:\n",
25
+ " pass"
26
+ ]
27
+ },
28
  {
29
  "cell_type": "code",
30
  "execution_count": 1,
 
281
  "\n",
282
  "\n",
283
  "args = AttributeDict({\n",
284
+ " \"hub_model_id\":\n",
285
+ " \"teticio/audio-diffusion-256\",\n",
286
+ " \"output_dir\":\n",
287
+ " \"../ddpm-ema-audio-256-repo\",\n",
288
+ " \"local_rank\":\n",
289
+ " -1,\n",
290
+ " \"hub_token\":\n",
291
+ " open(os.path.join(os.environ['HOME'], '.huggingface/token'), 'rt').read(),\n",
292
+ " \"hub_private_repo\":\n",
293
+ " False,\n",
294
+ " \"overwrite_output_dir\":\n",
295
+ " False\n",
296
  "})\n",
297
  "\n",
298
  "repo = init_git_repo(args, at_init=True)\n",
299
+ "ddpm = DDPMPipeline.from_pretrained('../ddpm-ema-audio-256')\n",
300
  "push_to_hub(args, ddpm, repo)"
301
  ]
302
  },
303
  {
304
  "cell_type": "code",
305
  "execution_count": null,
306
+ "id": "8c8261a0",
307
  "metadata": {},
308
  "outputs": [],
309
  "source": []
requirements-lock.txt CHANGED
@@ -11,7 +11,7 @@ argon2-cffi==21.3.0
11
  argon2-cffi-bindings==21.2.0
12
  async-timeout==4.0.2
13
  attrs==21.4.0
14
- audioread==2.1.9
15
  backcall==0.2.0
16
  backoff==1.10.0
17
  bcrypt==3.2.2
@@ -25,9 +25,12 @@ cachetools==5.2.0
25
  captum==0.5.0
26
  certifi==2022.6.15
27
  cffi==1.15.1
 
28
  charset-normalizer==2.1.0
29
  click==8.1.3
30
  cloudpickle==2.1.0
 
 
31
  cryptography==37.0.4
32
  cycler==0.11.0
33
  datasets==2.4.0
@@ -35,9 +38,11 @@ debugpy==1.6.2
35
  decorator==5.1.1
36
  deepspeed==0.7.0
37
  defusedxml==0.7.1
38
- diffusers==0.1.3
39
  dill==0.3.5.1
 
40
  entrypoints==0.4
 
41
  fastapi==0.79.0
42
  fastjsonschema==2.16.1
43
  ffmpy==0.3.0
@@ -51,12 +56,14 @@ google-pasta==0.2.0
51
  gradio==3.1.4
52
  grpcio==1.47.0
53
  h11==0.12.0
54
- hjson==3.0.2
55
  httpcore==0.15.0
56
  httpx==0.23.0
57
  huggingface-hub==0.8.1
 
58
  idna==3.3
59
  importlib-metadata==4.12.0
 
60
  ipykernel==6.15.1
61
  ipython==7.34.0
62
  ipython-genutils==0.2.0
@@ -78,10 +85,10 @@ lxml==4.9.1
78
  Markdown==3.4.1
79
  markdown-it-py==2.1.0
80
  MarkupSafe==2.1.1
81
- matplotlib==3.5.2
82
  matplotlib-inline==0.1.3
83
  mdit-py-plugins==0.3.0
84
- mdurl==0.1.1
85
  mistune==0.8.4
86
  monotonic==1.6
87
  more-itertools==8.14.0
@@ -90,23 +97,25 @@ multiprocess==0.70.13
90
  munkres==1.1.4
91
  mypy-extensions==0.4.3
92
  nbclient==0.6.6
93
- nbconvert==6.5.1
94
  nbformat==5.4.0
95
  nest-asyncio==1.5.5
96
  networkx==2.8.5
97
  ninja==1.10.2.3
98
  nlp==0.4.0
99
  nltk==3.7
 
100
  notebook==6.4.12
101
  numba==0.56.0
102
  numpy==1.22.4
103
  oauthlib==3.2.0
104
- orjson==3.7.11
105
  packaging==21.3
106
  pandas==1.4.3
107
  pandocfilters==1.5.0
108
  paramiko==2.11.0
109
  parso==0.8.3
 
110
  pathos==0.2.9
111
  pathspec==0.9.0
112
  pexpect==4.8.0
@@ -117,6 +126,7 @@ pluggy==0.13.1
117
  pooch==1.6.0
118
  pox==0.3.1
119
  ppft==1.7.6.5
 
120
  prometheus-client==0.14.1
121
  prompt-toolkit==3.0.30
122
  protobuf==3.19.4
@@ -130,7 +140,7 @@ pyasn1==0.4.8
130
  pyasn1-modules==0.2.8
131
  pycparser==2.21
132
  pycryptodome==3.15.0
133
- pydantic==1.9.1
134
  pydub==0.25.1
135
  Pygments==2.12.0
136
  PyNaCl==1.5.0
@@ -140,15 +150,16 @@ pytest==5.4.3
140
  python-dateutil==2.8.2
141
  python-dotenv==0.20.0
142
  python-multipart==0.0.5
143
- pytz==2022.1
144
  PyYAML==6.0
145
- pyzmq==23.2.0
146
  regex==2022.7.25
147
  requests==2.28.1
148
  requests-oauthlib==1.3.1
149
  resampy==0.4.0
150
  responses==0.18.0
151
  rfc3986==1.5.0
 
152
  rsa==4.9
153
  s3fs==2022.7.1
154
  s3transfer==0.5.2
@@ -167,7 +178,7 @@ snorkel==0.9.9
167
  SoundFile==0.10.3.post1
168
  soupsieve==2.3.2.post1
169
  starlette==0.19.1
170
- tensorboard==2.9.1
171
  tensorboard-data-server==0.6.1
172
  tensorboard-plugin-wit==1.8.1
173
  terminado==0.15.0
@@ -176,8 +187,9 @@ tinycss2==1.1.1
176
  tokenizers==0.12.1
177
  toml==0.10.2
178
  tomli==2.0.1
179
- torch==1.12.1
180
- torchvision==0.13.1
 
181
  tornado==6.2
182
  tqdm==4.64.0
183
  traitlets==5.3.0
@@ -187,6 +199,7 @@ typing_extensions==4.3.0
187
  uc-micro-py==1.0.1
188
  urllib3==1.26.11
189
  uvicorn==0.18.2
 
190
  wcwidth==0.2.5
191
  webencodings==0.5.1
192
  Werkzeug==2.2.2
 
11
  argon2-cffi-bindings==21.2.0
12
  async-timeout==4.0.2
13
  attrs==21.4.0
14
+ audioread==3.0.0
15
  backcall==0.2.0
16
  backoff==1.10.0
17
  bcrypt==3.2.2
 
25
  captum==0.5.0
26
  certifi==2022.6.15
27
  cffi==1.15.1
28
+ cfgv==3.3.1
29
  charset-normalizer==2.1.0
30
  click==8.1.3
31
  cloudpickle==2.1.0
32
+ colossalai==0.1.8
33
+ commonmark==0.9.1
34
  cryptography==37.0.4
35
  cycler==0.11.0
36
  datasets==2.4.0
 
38
  decorator==5.1.1
39
  deepspeed==0.7.0
40
  defusedxml==0.7.1
41
+ diffusers==0.2.2
42
  dill==0.3.5.1
43
+ distlib==0.3.5
44
  entrypoints==0.4
45
+ fabric==2.7.1
46
  fastapi==0.79.0
47
  fastjsonschema==2.16.1
48
  ffmpy==0.3.0
 
56
  gradio==3.1.4
57
  grpcio==1.47.0
58
  h11==0.12.0
59
+ hjson==3.1.0
60
  httpcore==0.15.0
61
  httpx==0.23.0
62
  huggingface-hub==0.8.1
63
+ identify==2.5.3
64
  idna==3.3
65
  importlib-metadata==4.12.0
66
+ invoke==1.7.1
67
  ipykernel==6.15.1
68
  ipython==7.34.0
69
  ipython-genutils==0.2.0
 
85
  Markdown==3.4.1
86
  markdown-it-py==2.1.0
87
  MarkupSafe==2.1.1
88
+ matplotlib==3.5.3
89
  matplotlib-inline==0.1.3
90
  mdit-py-plugins==0.3.0
91
+ mdurl==0.1.2
92
  mistune==0.8.4
93
  monotonic==1.6
94
  more-itertools==8.14.0
 
97
  munkres==1.1.4
98
  mypy-extensions==0.4.3
99
  nbclient==0.6.6
100
+ nbconvert==6.5.3
101
  nbformat==5.4.0
102
  nest-asyncio==1.5.5
103
  networkx==2.8.5
104
  ninja==1.10.2.3
105
  nlp==0.4.0
106
  nltk==3.7
107
+ nodeenv==1.7.0
108
  notebook==6.4.12
109
  numba==0.56.0
110
  numpy==1.22.4
111
  oauthlib==3.2.0
112
+ orjson==3.7.12
113
  packaging==21.3
114
  pandas==1.4.3
115
  pandocfilters==1.5.0
116
  paramiko==2.11.0
117
  parso==0.8.3
118
+ pathlib2==2.3.7.post1
119
  pathos==0.2.9
120
  pathspec==0.9.0
121
  pexpect==4.8.0
 
126
  pooch==1.6.0
127
  pox==0.3.1
128
  ppft==1.7.6.5
129
+ pre-commit==2.20.0
130
  prometheus-client==0.14.1
131
  prompt-toolkit==3.0.30
132
  protobuf==3.19.4
 
140
  pyasn1-modules==0.2.8
141
  pycparser==2.21
142
  pycryptodome==3.15.0
143
+ pydantic==1.9.2
144
  pydub==0.25.1
145
  Pygments==2.12.0
146
  PyNaCl==1.5.0
 
150
  python-dateutil==2.8.2
151
  python-dotenv==0.20.0
152
  python-multipart==0.0.5
153
+ pytz==2022.2.1
154
  PyYAML==6.0
155
+ pyzmq==23.2.1
156
  regex==2022.7.25
157
  requests==2.28.1
158
  requests-oauthlib==1.3.1
159
  resampy==0.4.0
160
  responses==0.18.0
161
  rfc3986==1.5.0
162
+ rich==12.5.1
163
  rsa==4.9
164
  s3fs==2022.7.1
165
  s3transfer==0.5.2
 
178
  SoundFile==0.10.3.post1
179
  soupsieve==2.3.2.post1
180
  starlette==0.19.1
181
+ tensorboard==2.10.0
182
  tensorboard-data-server==0.6.1
183
  tensorboard-plugin-wit==1.8.1
184
  terminado==0.15.0
 
187
  tokenizers==0.12.1
188
  toml==0.10.2
189
  tomli==2.0.1
190
+ torch==1.12.1+cu116
191
+ torchaudio==0.12.1+cu116
192
+ torchvision==0.13.1+cu116
193
  tornado==6.2
194
  tqdm==4.64.0
195
  traitlets==5.3.0
 
199
  uc-micro-py==1.0.1
200
  urllib3==1.26.11
201
  uvicorn==0.18.2
202
+ virtualenv==20.16.3
203
  wcwidth==0.2.5
204
  webencodings==0.5.1
205
  Werkzeug==2.2.2
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
- # for Hugging Face spaces
2
  torch
3
  numpy
4
  Pillow
5
  diffusers
6
  librosa
 
 
1
+ # for Hugging Face Spaces
2
  torch
3
  numpy
4
  Pillow
5
  diffusers
6
  librosa
7
+ datasets
src/train_unconditional.py CHANGED
@@ -241,13 +241,16 @@ def main(args):
241
  if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
242
  # save the model
243
  if args.push_to_hub:
244
- push_to_hub(
245
- args,
246
- pipeline,
247
- repo,
248
- commit_message=f"Epoch {epoch}",
249
- blocking=False,
250
- )
 
 
 
251
  else:
252
  pipeline.save_pretrained(output_dir)
253
  accelerator.wait_for_everyone()
 
241
  if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
242
  # save the model
243
  if args.push_to_hub:
244
+ try:
245
+ push_to_hub(
246
+ args,
247
+ pipeline,
248
+ repo,
249
+ commit_message=f"Epoch {epoch}",
250
+ blocking=False,
251
+ )
252
+ except NameError: # current version of diffusers has a little bug
253
+ pass
254
  else:
255
  pipeline.save_pretrained(output_dir)
256
  accelerator.wait_for_everyone()