Hugo Flores Garcia commited on
Commit
881d56d
·
1 Parent(s): 6f55a79

demo cleanup, onset masks, pitch shifting

Browse files
conf/interface/xeno-canto.yml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Interface.coarse_ckpt: ./runs/xeno-canto-2/coarse/best/vampnet/weights.pth
2
+ Interface.coarse2fine_ckpt: ./runs/xeno-canto-2/c2f/best/vampnet/weights.pth
3
+ Interface.codec_ckpt: ./models/spotdl/codec.pth
4
+ Interface.coarse_chunk_size_s: 10
5
+ Interface.coarse2fine_chunk_size_s: 3
6
+ # Interface.wavebeat_ckpt: ./models/wavebeat.pth
7
+
8
+
9
+ AudioLoader.sources:
10
+ - /media/CHONK/hugo/xeno-canto-2
11
+ - /media/CHONK/hugo/xeno-canto-2
conf/lora/lora.yml CHANGED
@@ -8,9 +8,11 @@ train/AudioDataset.n_examples: 10000000
8
  val/AudioDataset.n_examples: 10
9
 
10
 
11
- NoamScheduler.warmup: 400
12
 
 
 
13
  epoch_length: 100
14
- save_audio_epochs: 2
15
 
16
  AdamW.lr: 0.0001
 
8
  val/AudioDataset.n_examples: 10
9
 
10
 
11
+ NoamScheduler.warmup: 500
12
 
13
+ batch_size: 7
14
+ num_workers: 7
15
  epoch_length: 100
16
+ save_audio_epochs: 4
17
 
18
  AdamW.lr: 0.0001
conf/lora/xeno-canto/c2f.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/lora/lora.yml
3
+
4
+ fine_tune: True
5
+
6
+ train/AudioLoader.sources:
7
+ - /media/CHONK/hugo/xeno-canto-2
8
+
9
+ val/AudioLoader.sources:
10
+ - /media/CHONK/hugo/xeno-canto-2
11
+
12
+
13
+ VampNet.n_codebooks: 14
14
+ VampNet.n_conditioning_codebooks: 4
15
+
16
+ VampNet.embedding_dim: 1280
17
+ VampNet.n_layers: 16
18
+ VampNet.n_heads: 20
19
+
20
+ AudioDataset.duration: 3.0
21
+ AudioDataset.loudness_cutoff: -40.0
conf/lora/xeno-canto/coarse.yml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/lora/lora.yml
3
+
4
+ fine_tune: True
5
+
6
+ train/AudioLoader.sources:
7
+ - /media/CHONK/hugo/xeno-canto-2
8
+
9
+ val/AudioLoader.sources:
10
+ - /media/CHONK/hugo/xeno-canto-2
conf/vampnet.yml CHANGED
@@ -25,6 +25,9 @@ AdamW.lr: 0.001
25
  NoamScheduler.factor: 2.0
26
  NoamScheduler.warmup: 10000
27
 
 
 
 
28
  VampNet.vocab_size: 1024
29
  VampNet.n_codebooks: 4
30
  VampNet.n_conditioning_codebooks: 0
 
25
  NoamScheduler.factor: 2.0
26
  NoamScheduler.warmup: 10000
27
 
28
+ PitchShift.shift_amount: [const, 0]
29
+ PitchShift.prob: 0.0
30
+
31
  VampNet.vocab_size: 1024
32
  VampNet.n_codebooks: 4
33
  VampNet.n_conditioning_codebooks: 0
demo.py CHANGED
@@ -62,10 +62,13 @@ def load_random_audio():
62
  return sig.path_to_file
63
 
64
 
65
- def vamp(data):
 
66
  print(data[input_audio])
67
  sig = at.AudioSignal(data[input_audio])
68
 
 
 
69
  z = interface.encode(sig)
70
 
71
  ncc = data[n_conditioning_codebooks]
@@ -87,6 +90,11 @@ def vamp(data):
87
  random_roll=True
88
  )
89
  )
 
 
 
 
 
90
  mask = pmask.dropout(mask, data[dropout])
91
  mask = pmask.codebook_unmask(mask, ncc)
92
 
@@ -103,9 +111,6 @@ def vamp(data):
103
  if use_coarse2fine:
104
  zv = interface.coarse_to_fine(zv)
105
 
106
-
107
- mask = interface.to_signal(mask_z).cpu()
108
-
109
  sig = interface.to_signal(zv).cpu()
110
  print("done")
111
 
@@ -113,8 +118,19 @@ def vamp(data):
113
  out_dir.mkdir()
114
 
115
  sig.write(out_dir / "output.wav")
116
- mask.write(out_dir / "mask.wav")
117
- return sig.path_to_file, mask.path_to_file
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  def save_vamp(data):
120
  out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
@@ -198,6 +214,14 @@ with gr.Blocks() as demo:
198
  # mask settings
199
  with gr.Column():
200
 
 
 
 
 
 
 
 
 
201
  rand_mask_intensity = gr.Slider(
202
  label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
203
  minimum=0.0,
@@ -220,6 +244,14 @@ with gr.Blocks() as demo:
220
  value=1,
221
  )
222
 
 
 
 
 
 
 
 
 
223
  with gr.Accordion("extras ", open=False):
224
  n_conditioning_codebooks = gr.Number(
225
  label="number of conditioning codebooks. probably 0",
@@ -322,11 +354,9 @@ with gr.Blocks() as demo:
322
  )
323
 
324
  thank_you = gr.Markdown("")
325
-
326
- # connect widgets
327
- vamp_button.click(
328
- fn=vamp,
329
- inputs={
330
  input_audio,
331
  num_steps,
332
  init_temp, final_temp,
@@ -336,27 +366,29 @@ with gr.Blocks() as demo:
336
  n_conditioning_codebooks,
337
  dropout,
338
  use_coarse2fine,
339
- stretch_factor
340
- },
 
 
 
 
 
 
 
341
  outputs=[output_audio, audio_mask],
 
 
 
 
 
 
 
342
  api_name="vamp"
343
  )
344
 
345
  save_button.click(
346
  fn=save_vamp,
347
- inputs={
348
- input_audio,
349
- num_steps,
350
- init_temp, final_temp,
351
- prefix_s, suffix_s,
352
- rand_mask_intensity,
353
- periodic_p, periodic_w,
354
- n_conditioning_codebooks,
355
- dropout,
356
- use_coarse2fine,
357
- stretch_factor,
358
- notes_text
359
- },
360
  outputs=[thank_you, download_file]
361
  )
362
 
 
62
  return sig.path_to_file
63
 
64
 
65
+ def _vamp(data, return_mask=False):
66
+ print(data)
67
  print(data[input_audio])
68
  sig = at.AudioSignal(data[input_audio])
69
 
70
+ # TODO: random pitch shift of segments in the signal to prompt! window size should be a parameter, pitch shift width should be a parameter
71
+
72
  z = interface.encode(sig)
73
 
74
  ncc = data[n_conditioning_codebooks]
 
90
  random_roll=True
91
  )
92
  )
93
+ if data[onset_mask_width] > 0:
94
+ mask = pmask.mask_or(
95
+ mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
96
+ )
97
+ # these should be the last two mask ops
98
  mask = pmask.dropout(mask, data[dropout])
99
  mask = pmask.codebook_unmask(mask, ncc)
100
 
 
111
  if use_coarse2fine:
112
  zv = interface.coarse_to_fine(zv)
113
 
 
 
 
114
  sig = interface.to_signal(zv).cpu()
115
  print("done")
116
 
 
118
  out_dir.mkdir()
119
 
120
  sig.write(out_dir / "output.wav")
121
+
122
+ if return_mask:
123
+ mask = interface.to_signal(mask_z).cpu()
124
+ mask.write(out_dir / "mask.wav")
125
+ return sig.path_to_file, mask.path_to_file
126
+ else:
127
+ return sig.path_to_file
128
+
129
+ def vamp(data):
130
+ return _vamp(data, return_mask=True)
131
+
132
+ def api_vamp(data):
133
+ return _vamp(data, return_mask=False)
134
 
135
  def save_vamp(data):
136
  out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
 
214
  # mask settings
215
  with gr.Column():
216
 
217
+ input_pitch_shift = gr.Slider(
218
+ label="input pitch shift (semitones)",
219
+ minimum=-12,
220
+ maximum=12,
221
+ step=1,
222
+ value=0,
223
+ )
224
+
225
  rand_mask_intensity = gr.Slider(
226
  label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
227
  minimum=0.0,
 
244
  value=1,
245
  )
246
 
247
+ onset_mask_width = gr.Slider(
248
+ label="onset mask width (steps, 1 step ~= 10milliseconds)",
249
+ minimum=0,
250
+ maximum=20,
251
+ step=1,
252
+ value=0,
253
+ )
254
+
255
  with gr.Accordion("extras ", open=False):
256
  n_conditioning_codebooks = gr.Number(
257
  label="number of conditioning codebooks. probably 0",
 
354
  )
355
 
356
  thank_you = gr.Markdown("")
357
+
358
+
359
+ _inputs = {
 
 
360
  input_audio,
361
  num_steps,
362
  init_temp, final_temp,
 
366
  n_conditioning_codebooks,
367
  dropout,
368
  use_coarse2fine,
369
+ stretch_factor,
370
+ onset_mask_width,
371
+ input_pitch_shift
372
+ }
373
+
374
+ # connect widgets
375
+ vamp_button.click(
376
+ fn=vamp,
377
+ inputs=_inputs,
378
  outputs=[output_audio, audio_mask],
379
+ )
380
+
381
+ api_vamp_button = gr.Button("api vamp")
382
+ api_vamp_button.click(
383
+ fn=api_vamp,
384
+ inputs=_inputs,
385
+ outputs=[output_audio],
386
  api_name="vamp"
387
  )
388
 
389
  save_button.click(
390
  fn=save_vamp,
391
+ inputs=_inputs | {notes_text},
 
 
 
 
 
 
 
 
 
 
 
 
392
  outputs=[thank_you, download_file]
393
  )
394
 
scripts/exp/train.py CHANGED
@@ -62,6 +62,7 @@ IGNORE_INDEX = -100
62
  def build_transform():
63
  transform = transforms.Compose(
64
  tfm.VolumeNorm(("const", -24)),
 
65
  tfm.RescaleAudio(),
66
  )
67
  return transform
 
62
  def build_transform():
63
  transform = transforms.Compose(
64
  tfm.VolumeNorm(("const", -24)),
65
+ # tfm.PitchShift(),
66
  tfm.RescaleAudio(),
67
  )
68
  return transform
vampnet/mask.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import Optional
2
 
3
  import torch
 
4
 
5
  from .util import scalar_to_batch_tensor
6
 
@@ -150,7 +151,9 @@ def dropout(
150
  mask: torch.Tensor,
151
  p: float,
152
  ):
153
- return torch.bernoulli((torch.ones_like(mask) * (1-p)).float()).long() * mask
 
 
154
 
155
  def mask_or(
156
  mask1: torch.Tensor,
@@ -166,7 +169,6 @@ def mask_or(
166
  def time_stretch_mask(
167
  x: torch.Tensor,
168
  stretch_factor: int,
169
- mask_token: int
170
  ):
171
  assert stretch_factor >= 1, "stretch factor must be >= 1"
172
  c_seq_len = x.shape[-1]
@@ -176,7 +178,35 @@ def time_stretch_mask(
176
  x = x[:, :, :c_seq_len]
177
 
178
  mask = periodic_mask(x, stretch_factor, width=1)
179
- return apply_mask(x, mask, mask_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
 
182
  if __name__ == "__main__":
 
1
  from typing import Optional
2
 
3
  import torch
4
+ from audiotools import AudioSignal
5
 
6
  from .util import scalar_to_batch_tensor
7
 
 
151
  mask: torch.Tensor,
152
  p: float,
153
  ):
154
+ # negate the mask (we want the 0s to be 1s, since we want to drop the prompt, not the mask)
155
+ mask = (~(mask.bool())).long()
156
+ return torch.nn.functional.dropout(mask.float(), p=p, training=True).long().bool().long()
157
 
158
  def mask_or(
159
  mask1: torch.Tensor,
 
169
  def time_stretch_mask(
170
  x: torch.Tensor,
171
  stretch_factor: int,
 
172
  ):
173
  assert stretch_factor >= 1, "stretch factor must be >= 1"
174
  c_seq_len = x.shape[-1]
 
178
  x = x[:, :, :c_seq_len]
179
 
180
  mask = periodic_mask(x, stretch_factor, width=1)
181
+ return mask
182
+
183
+ def onset_mask(
184
+ sig: AudioSignal,
185
+ z: torch.Tensor,
186
+ interface,
187
+ width: int = 1
188
+ ):
189
+ import librosa
190
+
191
+ onset_indices = librosa.onset.onset_detect(
192
+ y=sig.clone().to_mono().samples.cpu().numpy()[0, 0],
193
+ sr=sig.sample_rate,
194
+ hop_length=interface.codec.hop_length
195
+ )
196
+
197
+ # create a mask, set onset
198
+ mask = torch.ones_like(z)
199
+ n_timesteps = z.shape[-1]
200
+
201
+ for onset_index in onset_indices:
202
+ onset_index = min(onset_index, n_timesteps - 1)
203
+ onset_index = max(onset_index, 0)
204
+ mask[:, :, onset_index - width:onset_index + width] = 0.0
205
+
206
+ print(mask)
207
+
208
+ return mask
209
+
210
 
211
 
212
  if __name__ == "__main__":
vampnet/modules/transformer.py CHANGED
@@ -62,8 +62,8 @@ class FeedForward(nn.Module):
62
  ):
63
  super().__init__()
64
  factor = 2 if activation == "geglu" else 1
65
- self.w_1 = nn.Linear(d_model, d_model * 4, bias=False)
66
- self.w_2 = nn.Linear(d_model * 4 // factor, d_model, bias=False)
67
  self.drop = nn.Dropout(dropout)
68
  self.act = get_activation(activation)()
69
 
@@ -109,7 +109,7 @@ class MultiHeadRelativeAttention(nn.Module):
109
  self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
110
 
111
  # Create linear final output projection
112
- self.fc = nn.Linear(d_model, d_model, bias=False)
113
 
114
  # Dropout for attention output weights
115
  self.dropout = nn.Dropout(dropout)
 
62
  ):
63
  super().__init__()
64
  factor = 2 if activation == "geglu" else 1
65
+ self.w_1 = lora.Linear(d_model, d_model * 4, bias=False, r=LORA_R)
66
+ self.w_2 = lora.Linear(d_model * 4 // factor, d_model, bias=False, r=LORA_R)
67
  self.drop = nn.Dropout(dropout)
68
  self.act = get_activation(activation)()
69
 
 
109
  self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
110
 
111
  # Create linear final output projection
112
+ self.fc = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
113
 
114
  # Dropout for attention output weights
115
  self.dropout = nn.Dropout(dropout)