teticio committed
Commit ea68dfd
1 Parent(s): 63ee254

add out-painting feature

README.md CHANGED
@@ -15,8 +15,12 @@ license: gpl-3.0
 
 ---
 
-**UPDATE**:
+**UPDATES**:
 
+4/10/2022
+It is now possible to mask parts of the input audio during generation which means you can stitch several samples together (think "out-painting").
+
+27/9/2022
 You can now generate an audio based on a previous one. You can use this to generate variations of the same audio or even to "remix" a track (via a sort of "style transfer"). You can find examples of how to do this in the [`test_model.ipynb`](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/test_model.ipynb) notebook.
 
 ---
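
For readers who want to try the feature straight away, here is a minimal out-painting sketch, adapted from the notebook cell added later in this commit. It assumes `audio_diffusion` is an already constructed `AudioDiffusion` instance and that `audio` / `sample_rate` come from a previous generation; neither of those is part of this diff.

import numpy as np

overlap_secs = 2                               # seconds of context to keep fixed
overlap_samples = overlap_secs * sample_rate   # context length in samples
# Re-generate a full slice, but freeze the first overlap_secs so it matches
# the tail of the existing clip; the rest is "out-painted".
_, (sample_rate, audio2) = audio_diffusion.generate_spectrogram_and_audio_from_audio(
    raw_audio=audio[-overlap_samples:],
    mask_start_secs=overlap_secs)
track = np.concatenate([audio, audio2[overlap_samples:]])  # drop the duplicated overlap
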
audiodiffusion/__init__.py CHANGED
@@ -9,7 +9,7 @@ from diffusers import DDPMPipeline, DDPMScheduler
 
 from .mel import Mel
 
-VERSION = "1.1.3"
+VERSION = "1.1.4"
 
 
 class AudioDiffusion:
@@ -61,7 +61,9 @@ class AudioDiffusion:
             slice: int = 0,
             start_step: int = 0,
             steps: int = None,
-            generator: torch.Generator = None
+            generator: torch.Generator = None,
+            mask_start_secs: float = 0,
+            mask_end_secs: float = 0
     ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
         """Generate random mel spectrogram from audio input and convert to audio.
 
@@ -72,6 +74,8 @@ class AudioDiffusion:
             start_step (int): step to start from
             steps (int): number of de-noising steps to perform (defaults to num_train_timesteps)
             generator (torch.Generator): random number generator or None
+            mask_start_secs (float): number of seconds of audio to mask (not generate) at start
+            mask_end_secs (float): number of seconds of audio to mask (not generate) at end
 
         Returns:
             PIL Image: mel spectrogram
@@ -84,31 +88,51 @@ class AudioDiffusion:
             steps = self.ddpm.scheduler.num_train_timesteps
         scheduler = DDPMScheduler(num_train_timesteps=steps)
         scheduler.set_timesteps(steps)
-        images = torch.randn(
+        mask = None
+        images = noise = torch.randn(
            (1, self.ddpm.unet.in_channels, self.ddpm.unet.sample_size,
             self.ddpm.unet.sample_size),
            generator=generator,
        )
+
         if audio_file is not None or raw_audio is not None:
             self.mel.load_audio(audio_file, raw_audio)
             input_image = self.mel.audio_slice_to_image(slice)
             input_image = np.frombuffer(input_image.tobytes(),
                                         dtype="uint8").reshape(
-                                            (input_image.width,
-                                             input_image.height))
+                                            (input_image.height,
+                                             input_image.width))
             input_image = ((input_image / 255) * 2 - 1)
+
             if start_step > 0:
-                images[0][0] = scheduler.add_noise(
+                images[0, 0] = scheduler.add_noise(
                     torch.tensor(input_image[np.newaxis, np.newaxis, :]),
-                    images, steps - start_step)
+                    noise, steps - start_step)
+
+            mask_start = int(mask_start_secs * self.mel.get_sample_rate() /
+                             self.mel.hop_length)
+            mask_end = int(mask_end_secs * self.mel.get_sample_rate() /
+                           self.mel.hop_length)
+            mask = scheduler.add_noise(
+                torch.tensor(input_image[np.newaxis, np.newaxis, :]), noise,
+                scheduler.timesteps[start_step:])
 
         images = images.to(self.ddpm.device)
-        for t in self.progress_bar(scheduler.timesteps[start_step:]):
+        for step, t in enumerate(
+                self.progress_bar(scheduler.timesteps[start_step:])):
             model_output = self.ddpm.unet(images, t)['sample']
             images = scheduler.step(model_output,
                                     t,
                                     images,
                                     generator=generator)['prev_sample']
+
+            if mask is not None:
+                if mask_start > 0:
+                    images[0, 0, :, :mask_start] = mask[step,
+                                                        0, :, :mask_start]
+                if mask_end > 0:
+                    images[0, 0, :, -mask_end:] = mask[step, 0, :, -mask_end:]
+
         images = (images / 2 + 0.5).clamp(0, 1)
         images = images.cpu().permute(0, 2, 3, 1).numpy()
 
 
audiodiffusion/mel.py CHANGED
@@ -37,7 +37,7 @@ class Mel:
         self.slice_size = self.x_res * self.hop_length - 1
         self.fmax = self.sr / 2
         self.top_db = top_db
-        self.y = None
+        self.audio = None
 
     def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
         """Load audio.
@@ -47,11 +47,16 @@ class Mel:
             raw_audio (np.ndarray): audio as numpy array
         """
         if audio_file is not None:
-            self.y, _ = librosa.load(
-                audio_file,
-                mono=True)
+            self.audio, _ = librosa.load(audio_file, mono=True)
         else:
-            self.y = raw_audio
+            self.audio = raw_audio
+
+        # Pad with silence if necessary.
+        if len(self.audio) < self.x_res * self.hop_length:
+            self.audio = np.concatenate([
+                self.audio,
+                np.zeros((self.x_res * self.hop_length - len(self.audio), ))
+            ])
 
     def get_number_of_slices(self) -> int:
         """Get number of slices in audio.
@@ -59,7 +64,7 @@ class Mel:
         Returns:
             int: number of spectograms audio can be sliced into
         """
-        return len(self.y) // self.slice_size
+        return len(self.audio) // self.slice_size
 
     def get_audio_slice(self, slice: int = 0) -> np.ndarray:
         """Get slice of audio.
@@ -70,7 +75,8 @@ class Mel:
         Returns:
             np.ndarray: audio as numpy array
         """
-        return self.y[self.slice_size * slice:self.slice_size * (slice + 1)]
+        return self.audio[self.slice_size * slice:self.slice_size *
+                          (slice + 1)]
 
     def get_sample_rate(self) -> int:
         """Get sample rate:
notebooks/test_model.ipynb CHANGED
@@ -10,7 +10,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": null,
    "id": "6c7800a6",
    "metadata": {},
    "outputs": [],
@@ -27,7 +27,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": null,
    "id": "b447e2c4",
    "metadata": {},
    "outputs": [],
@@ -39,7 +39,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 41,
+   "execution_count": null,
    "id": "c2fc0e7a",
    "metadata": {},
    "outputs": [],
@@ -63,7 +63,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": null,
    "id": "97f24046",
    "metadata": {},
    "outputs": [],
@@ -79,7 +79,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": null,
    "id": "a3d45c36",
    "metadata": {},
    "outputs": [],
@@ -169,6 +169,39 @@
     "display(Audio(track, rate=sample_rate))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "11f91ad3",
+   "metadata": {},
+   "source": [
+    "### Generate continuations (\"out-painting\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "756d7af5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "overlap_secs = 2 #@param {type:\"integer\"}\n",
+    "start_step = 0 #@param {type:\"slider\", min:0, max:1000, step:10}\n",
+    "overlap_samples = overlap_secs * sample_rate\n",
+    "track = audio\n",
+    "for variation in range(12):\n",
+    "    image2, (\n",
+    "        sample_rate, audio2\n",
+    "    ) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
+    "        raw_audio=audio[-overlap_samples:],\n",
+    "        start_step=start_step,\n",
+    "        mask_start_secs=overlap_secs)\n",
+    "    display(image2)\n",
+    "    display(Audio(audio2, rate=sample_rate))\n",
+    "    track = np.concatenate([track, audio2[overlap_samples:]])\n",
+    "    audio = audio2\n",
+    "display(Audio(track, rate=sample_rate))"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "b6434d3f",
@@ -182,12 +215,12 @@
    "id": "0da030b2",
    "metadata": {},
    "source": [
-    "Alternatively, you can start from another audio altogether, resulting in a kind of style transfer."
+    "Alternatively, you can start from another audio altogether, resulting in a kind of style transfer. Maintaining the same seed during generation fixes the style, while masking helps stitch consecutive segments together more smoothly."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 50,
+   "execution_count": null,
    "id": "fc620a80",
    "metadata": {},
    "outputs": [],
@@ -207,41 +240,31 @@
    "metadata": {
     "scrolled": true
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "6e741e6bd196458fa38f86197bd16378",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "  0%|          | 0/500 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "start_steps = 500 #@param {type:\"slider\", min:0, max:1000, step:10}\n",
-    "audio_diffusion.mel.load_audio(audio_file)\n",
-    "track = np.array([])\n",
+    "start_step = 500 #@param {type:\"slider\", min:0, max:1000, step:10}\n",
+    "overlap_secs = 1 #@param {type:\"integer\"}\n",
+    "overlap_samples = overlap_secs * sample_rate\n",
+    "mel.load_audio(audio_file)\n",
+    "slice_size = audio_diffusion.mel.x_res * audio_diffusion.mel.hop_length\n",
+    "stride = slice_size - overlap_samples\n",
     "generator = torch.Generator()\n",
     "seed = generator.seed()\n",
-    "for slice in range(audio_diffusion.mel.get_number_of_slices()):\n",
+    "track = np.array([])\n",
+    "for sample in range(len(mel.audio) // stride):\n",
     "    generator.manual_seed(seed)\n",
-    "    audio = audio_diffusion.mel.get_audio_slice(slice)\n",
-    "    _, (\n",
-    "        sample_rate, audio2\n",
-    "    ) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
-    "        audio_file=audio_file,\n",
-    "        slice=slice,\n",
-    "        start_step=start_steps,\n",
-    "        generator=generator)\n",
+    "    audio = mel.audio[sample * stride:sample * stride + slice_size]\n",
+    "    if len(track) > 0:\n",
+    "        audio[:overlap_samples] = audio2[-overlap_samples:]\n",
+    "    _, (sample_rate,\n",
+    "        audio2) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
+    "        raw_audio=audio,\n",
+    "        start_step=start_step,\n",
+    "        generator=generator,\n",
+    "        mask_start_secs=1 if len(track) > 0 else 0)\n",
     "    display(Audio(audio, rate=sample_rate))\n",
     "    display(Audio(audio2, rate=sample_rate))\n",
-    "    track = np.concatenate([track, audio2])"
+    "    track = np.concatenate([track, audio2[overlap_samples:]])"
    ]
   },
   {
@@ -307,7 +330,17 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c59bcc0f",
+   "id": "df112a72",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "len(audio) / mel.hop_length"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ad467206",
+   "metadata": {},
+   "outputs": [],
+   "source": []
@@ -334,7 +367,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.4"
+   "version": "3.10.6"
   },
   "toc": {
    "base_numbering": 1,