Hugo Flores Garcia committed
Commit e288b6f
2 Parent(s): 5235095 ebb1173

Merge branch 'ismir' into pr/9
.gitignore CHANGED
@@ -179,9 +179,6 @@ models/
 samples*/
 models-all/
 models.zip
-audiotools/
-descript-audio-codec/
-# *.pth
 .git-old
 conf/generated/*
 runs*/
README.md CHANGED
@@ -18,7 +18,15 @@ you can try vampnet in a co-creative looper called unloop. see this link: https:
 
 # Setting up
 
-Requires Python 3.9 or later.
+**Requires Python 3.9**.
+
+you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
+
+(for example, using conda)
+```bash
+conda create -n vampnet python=3.9
+conda activate vampnet
+```
 
 
 install VampNet
@@ -91,7 +99,7 @@ python scripts/exp/train.py --args.load conf/<fine_tune_name>/c2f.yml
 
 launch the interface:
 ```bash
-python demo.py --args.load conf/generated/<fine_tune_name>/interface.yml
+python app.py --args.load conf/generated/<fine_tune_name>/interface.yml
 ```
 
 
app.py CHANGED
@@ -18,16 +18,37 @@ from vampnet import mask as pmask
 # Interface = argbind.bind(Interface)
 # AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
 
-interface = Interface(
-    coarse_ckpt="./models/vampnet/coarse.pth",
-    coarse2fine_ckpt="./models/vampnet/c2f.pth",
-    codec_ckpt="./models/vampnet/codec.pth",
-    wavebeat_ckpt="./models/wavebeat.pth",
-    device="cuda" if torch.cuda.is_available() else "cpu",
-)
+
 
 # loader = AudioLoader()
 print(f"interface device is {interface.device}")
+Interface = argbind.bind(Interface)
+# AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
+
+conf = argbind.parse_args()
+
+
+from torch_pitch_shift import pitch_shift, get_fast_shifts
+def shift_pitch(signal, interval: int):
+    signal.samples = pitch_shift(
+        signal.samples,
+        shift=interval,
+        sample_rate=signal.sample_rate
+    )
+    return signal
+
+def load_interface():
+    interface = Interface(
+        coarse_ckpt="./models/vampnet/coarse.pth",
+        coarse2fine_ckpt="./models/vampnet/c2f.pth",
+        codec_ckpt="./models/vampnet/codec.pth",
+        wavebeat_ckpt="./models/wavebeat.pth",
+        device="cuda" if torch.cuda.is_available() else "cpu",
+    )
+    return interface
+
+
+interface = load_interface()
 
 # dataset = at.data.datasets.AudioDataset(
 #     loader,
@@ -50,7 +71,7 @@ def load_audio(file):
     )
     sig = interface.preprocess(sig)
 
-    out_dir = OUT_DIR / str(uuid.uuid4())
+    out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
     out_dir.mkdir(parents=True, exist_ok=True)
     sig.write(out_dir / "input.wav")
     return sig.path_to_file
@@ -68,6 +89,10 @@ def _vamp(data, return_mask=False):
     out_dir = OUT_DIR / str(uuid.uuid4())
     out_dir.mkdir()
     sig = at.AudioSignal(data[input_audio])
+    sig = interface.preprocess(sig)
+
+    if data[pitch_shift_amt] != 0:
+        sig = shift_pitch(sig, data[pitch_shift_amt])
 
     z = interface.encode(sig)
 
@@ -107,7 +132,27 @@ def _vamp(data, return_mask=False):
     mask = pmask.codebook_unmask(mask, ncc)
 
 
-    print(data)
+    print(f"dropout {data[dropout]}")
+    print(f"masktemp {data[masktemp]}")
+    print(f"sampletemp {data[sampletemp]}")
+    print(f"top_p {data[top_p]}")
+    print(f"prefix_s {data[prefix_s]}")
+    print(f"suffix_s {data[suffix_s]}")
+    print(f"rand_mask_intensity {data[rand_mask_intensity]}")
+    print(f"num_steps {data[num_steps]}")
+    print(f"periodic_p {data[periodic_p]}")
+    print(f"periodic_w {data[periodic_w]}")
+    print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
+    print(f"use_coarse2fine {data[use_coarse2fine]}")
+    print(f"onset_mask_width {data[onset_mask_width]}")
+    print(f"beat_mask_width {data[beat_mask_width]}")
+    print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
+    print(f"stretch_factor {data[stretch_factor]}")
+    print(f"seed {data[seed]}")
+    print(f"pitch_shift_amt {data[pitch_shift_amt]}")
+    print(f"sample_cutoff {data[sample_cutoff]}")
+
+
     _top_p = data[top_p] if data[top_p] > 0 else None
     # save the mask as a txt file
     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
@@ -126,6 +171,7 @@ def _vamp(data, return_mask=False):
         top_p=_top_p,
         gen_fn=interface.coarse.generate,
         seed=_seed,
+        sample_cutoff=data[sample_cutoff],
     )
 
     if use_coarse2fine:
@@ -134,7 +180,8 @@ def _vamp(data, return_mask=False):
             mask_temperature=data[masktemp]*10,
             sampling_temperature=data[sampletemp],
             mask=mask,
-            sampling_steps=data[num_steps],
+            sampling_steps=data[num_steps],
+            sample_cutoff=data[sample_cutoff],
             seed=_seed,
         )
 
@@ -183,6 +230,7 @@ def save_vamp(data):
         "use_coarse2fine": data[use_coarse2fine],
         "stretch_factor": data[stretch_factor],
         "seed": data[seed],
+        "samplecutoff": data[sample_cutoff],
     }
 
     # save with yaml
@@ -265,29 +313,41 @@ with gr.Blocks() as demo:
                 "beat_mask_downbeats": False,
             },
            "slight periodic variation": {
-                "periodic_p": 7,
-                "onset_mask_width": 0,
+                "periodic_p": 5,
+                "onset_mask_width": 5,
                 "beat_mask_width": 0,
                 "beat_mask_downbeats": False,
             },
-            "strong periodic variation": {
+            "moderate periodic variation": {
                 "periodic_p": 13,
                 "onset_mask_width": 5,
                 "beat_mask_width": 0,
                 "beat_mask_downbeats": False,
             },
-            "very strong periodic variation": {
+            "strong periodic variation": {
                 "periodic_p": 17,
                 "onset_mask_width": 5,
                 "beat_mask_width": 0,
                 "beat_mask_downbeats": False,
             },
+            "very strong periodic variation": {
+                "periodic_p": 21,
+                "onset_mask_width": 5,
+                "beat_mask_width": 0,
+                "beat_mask_downbeats": False,
+            },
             "beat-driven variation": {
                 "periodic_p": 0,
                 "onset_mask_width": 0,
-                "beat_mask_width": 20,
+                "beat_mask_width": 50,
                 "beat_mask_downbeats": False,
             },
+            "beat-driven variation (downbeats only)": {
+                "periodic_p": 0,
+                "onset_mask_width": 0,
+                "beat_mask_width": 50,
+                "beat_mask_downbeats": True,
+            },
             "beat-driven variation (downbeats only, strong)": {
                 "periodic_p": 0,
                 "onset_mask_width": 0,
@@ -309,14 +369,14 @@ with gr.Blocks() as demo:
                     minimum=0,
                     maximum=128,
                     step=1,
-                    value=13,
+                    value=5,
                 )
 
 
                 onset_mask_width = gr.Slider(
                     label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
                     minimum=0,
-                    maximum=20,
+                    maximum=100,
                     step=1,
                     value=5,
                 )
@@ -334,6 +394,14 @@ with gr.Blocks() as demo:
 
 
             with gr.Accordion("extras ", open=False):
+                pitch_shift_amt = gr.Slider(
+                    label="pitch shift amount (semitones)",
+                    minimum=-12,
+                    maximum=12,
+                    step=1,
+                    value=0,
+                )
+
                 rand_mask_intensity = gr.Slider(
                     label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
                     minimum=0.0,
@@ -396,14 +464,15 @@ with gr.Blocks() as demo:
                 masktemp = gr.Slider(
                     label="mask temperature",
                     minimum=0.0,
-                    maximum=10.0,
+                    maximum=100.0,
                     value=1.5
                 )
                 sampletemp = gr.Slider(
                     label="sample temperature",
                     minimum=0.1,
-                    maximum=2.0,
-                    value=1.0
+                    maximum=10.0,
+                    value=1.0,
+                    step=0.001
                 )
 
 
@@ -419,7 +488,7 @@ with gr.Blocks() as demo:
                     label="typical filtering ",
                     value=False
                 )
-                typical_mass = gr.Slider(
+                typical_mass = gr.Slider(
                     label="typical mass (should probably stay between 0.1 and 0.5)",
                     minimum=0.01,
                     maximum=0.99,
@@ -432,6 +501,13 @@ with gr.Blocks() as demo:
                     step=1,
                     value=64
                 )
+                sample_cutoff = gr.Slider(
+                    label="sample cutoff",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.5,
+                    step=0.01
+                )
 
                 use_coarse2fine = gr.Checkbox(
                     label="use coarse2fine",
@@ -506,8 +582,11 @@ with gr.Blocks() as demo:
             typical_mass,
             typical_min_tokens,
             beat_mask_width,
+            beat_mask_downbeats,
             seed,
-            beat_mask_downbeats
+            # lora_choice,
+            pitch_shift_amt,
+            sample_cutoff
         }
 
     # connect widgets
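A note on the new `shift_pitch` helper above: `torch_pitch_shift.pitch_shift` operates on the raw sample tensor of shape `(batch, channels, samples)` and, as called here, takes the shift as a semitone count. A minimal standalone sketch (the file paths are placeholders, not part of the commit):

```python
import audiotools as at
from torch_pitch_shift import pitch_shift

sig = at.AudioSignal("input.wav")  # placeholder path
# shift the raw sample tensor (batch, channels, samples) up 3 semitones
sig.samples = pitch_shift(sig.samples, shift=3, sample_rate=sig.sample_rate)
sig.write("shifted.wav")  # placeholder path
```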
conf/lora/lora.yml CHANGED
@@ -4,14 +4,16 @@ $include:
 fine_tune: True
 
 train/AudioDataset.n_examples: 100000000
-val/AudioDataset.n_examples: 100
+val/AudioDataset.n_examples: 500
 
 
 NoamScheduler.warmup: 500
 
 batch_size: 7
 num_workers: 7
-save_iters: [100000, 200000, 300000, 4000000, 500000]
+save_iters: [10000, 20000, 30000, 40000, 50000]
+sample_freq: 1000
+val_freq: 500
 
 AdamW.lr: 0.0001
 
scripts/exp/fine_tune.py CHANGED
@@ -48,11 +48,9 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
     }
 
     interface_conf = {
-        "Interface.coarse_ckpt": f"./models/vampnet/coarse.pth",
-        "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
+        "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
 
-        "Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
-        "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
+        "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
 
         "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
 
         "Interface.codec_ckpt": "./models/vampnet/codec.pth",
scripts/exp/train.py CHANGED
@@ -14,7 +14,7 @@ from audiotools.data import transforms
 from einops import rearrange
 from rich import pretty
 from rich.traceback import install
-from tensorboardX import SummaryWriter
+from torch.utils.tensorboard import SummaryWriter
 
 import vampnet
 from vampnet.modules.transformer import VampNet
@@ -29,6 +29,9 @@ from audiotools.ml.decorators import (
 
 import loralib as lora
 
+import torch._dynamo
+torch._dynamo.config.verbose=True
+
 
 # Enable cudnn autotuner to speed up training
 # (can be altered by the funcs.seed function)
@@ -601,7 +604,7 @@ def train(
         accel=accel,
         tracker=tracker,
         save_path=save_path)
-
+    print("initialized state.")
 
     train_dataloader = accel.prepare_dataloader(
         state.train_data,
@@ -616,13 +619,15 @@ def train(
         num_workers=num_workers,
         batch_size=batch_size,
         collate_fn=state.val_data.collate,
-        persistent_workers=True,
+        persistent_workers=num_workers > 0,
     )
+    print("initialized dataloader.")
 
 
 
     if fine_tune:
         lora.mark_only_lora_as_trainable(state.model)
+        print("marked only lora as trainable.")
 
     # Wrap the functions so that they neatly track in TensorBoard + progress bars
    # and only run when specific conditions are met.
@@ -637,6 +642,7 @@ def train(
     save_samples = when(lambda: accel.local_rank == 0)(save_samples)
     checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
 
+    print("starting training loop.")
     with tracker.live:
         for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
             train_loop(state, batch, accel)
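One small but load-bearing change above is the `persistent_workers` guard: PyTorch's `DataLoader` raises `ValueError: persistent_workers option needs num_workers > 0` when `persistent_workers=True` is combined with `num_workers=0`, so gating it on the worker count keeps single-process debugging runs alive. A minimal sketch of the guard in isolation:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(8, 2))
num_workers = 0  # e.g. a single-process debugging run

# with persistent_workers=True this would raise a ValueError at num_workers=0
loader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,  # the guard from the diff
)
print(next(iter(loader))[0].shape)  # torch.Size([4, 2])
```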
scripts/utils/augment.py CHANGED
@@ -5,34 +5,19 @@ from audiotools import AudioSignal
 
 import argbind
 import tqdm
+import torch
 
 
-from pedalboard import (
-    Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard
-)
-from pedalboard.io import AudioFile
+from torch_pitch_shift import pitch_shift, get_fast_shifts
+from torch_time_stretch import time_stretch, get_fast_stretches
 
-# Read in a whole file, resampling to our desired sample rate:
-samplerate = 44100.0
-with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
-    audio = f.read(f.frames)
-
-# Make a pretty interesting sounding guitar pedalboard:
-board = Pedalboard([
-    Compressor(threshold_db=-50, ratio=25),
-    Gain(gain_db=30),
-    Chorus(),
-    LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
-    Phaser(),
-    Convolution("./guitar_amp.wav", 1.0),
-    Reverb(room_size=0.25),
-])
+from audiotools.core.util import sample_from_dist
 
 
 @argbind.bind(without_prefix=True)
 def augment(
-    audio_folder: Path,
-    dest_folder: Path,
+    audio_folder: Path = None,
+    dest_folder: Path = None,
     n_augmentations: int = 10,
 ):
     """
@@ -41,7 +26,8 @@ def augment(
     The dest folder will contain a folder for each of the clean dataset's files.
     Under each of these folders, there will be a clean file and many augmented files.
     """
-
+    assert audio_folder is not None
+    assert dest_folder is not None
     audio_files = at.util.find_audio(audio_folder)
 
     for audio_file in tqdm.tqdm(audio_files):
@@ -49,5 +35,33 @@ def augment(
         subdir = subtree / audio_file.stem
         subdir.mkdir(parents=True, exist_ok=True)
 
-        # apply pedalboard transforms
-        for i in range(n_augmentations):
+        src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
+
+
+        for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
+            # apply pedalboard transforms
+            for j in range(n_augmentations):
+                # pitch shift between -7 and 7 semitones
+                import random
+                dst = chunk.clone()
+                dst.samples = pitch_shift(
+                    dst.samples,
+                    shift=random.choice(get_fast_shifts(src.sample_rate,
+                        condition=lambda x: x >= 0.25 and x <= 1.0)),
+                    sample_rate=src.sample_rate
+                )
+                dst.samples = time_stretch(
+                    dst.samples,
+                    stretch=random.choice(get_fast_stretches(src.sample_rate,
+                        condition=lambda x: x >= 0.667 and x <= 1.5, )),
+                    sample_rate=src.sample_rate,
+                )
+
+                dst.cpu().write(subdir / f"{i}-{j}.wav")
+
+
+if __name__ == "__main__":
+    args = argbind.parse_args()
+
+    with argbind.scope(args):
+        augment()
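Worth noting when reading the hunk above: `get_fast_shifts` and `get_fast_stretches` return resampling *ratios* (as `Fraction`s), not semitone counts, so the `0.25 <= x <= 1.0` condition admits shifts of up to two octaves downward rather than the ±7 semitones the inline comment suggests. The conversion is `12 * log2(ratio)`:

```python
import math

def ratio_to_semitones(r: float) -> float:
    # a pitch ratio r corresponds to 12 * log2(r) semitones
    return 12 * math.log2(r)

print(ratio_to_semitones(0.25))  # -24.0: two octaves down
print(ratio_to_semitones(0.5))   # -12.0: one octave down
print(ratio_to_semitones(1.0))   #   0.0: unchanged
```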
scripts/utils/remove_quiet_files.py ADDED
@@ -0,0 +1,29 @@
+# removes files with loudness below min_loudness (default -30 dB)
+
+from pathlib import Path
+import shutil
+import audiotools as at
+import argbind
+
+@argbind.bind(without_prefix=True)
+def remove_quiet_files(
+    src_dir: Path = None,
+    dest_dir: Path = None,
+    min_loudness: float = -30,
+):
+    # copy src to dest
+    dest_dir.mkdir(parents=True, exist_ok=True)
+    shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
+
+    audio_files = at.util.find_audio(dest_dir)
+    for audio_file in audio_files:
+        sig = at.AudioSignal(audio_file)
+        if sig.loudness() < min_loudness:
+            audio_file.unlink()
+            print(f"removed {audio_file}")
+
+if __name__ == "__main__":
+    args = argbind.parse_args()
+
+    with argbind.scope(args):
+        remove_quiet_files()
scripts/utils/split_long_audio_file.py ADDED
@@ -0,0 +1,34 @@
+from pathlib import Path
+import argbind
+
+import audiotools as at
+import tqdm
+
+
+@argbind.bind(without_prefix=True)
+def split_long_audio_file(
+    file: str = None,
+    max_chunk_size_s: int = 60*10
+):
+    file = Path(file)
+    output_dir = file.parent / file.stem
+    output_dir.mkdir()
+
+    sig = at.AudioSignal(file)
+
+    # split into chunks
+    for i, sig in tqdm.tqdm(enumerate(sig.windows(
+        window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
+        preprocess=True))
+    ):
+        sig.write(output_dir / f"{i}.wav")
+
+    print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
+
+    return output_dir
+
+if __name__ == "__main__":
+    args = argbind.parse_args()
+
+    with argbind.scope(args):
+        split_long_audio_file()
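Note that the default window/hop choice above (`hop_duration = max_chunk_size_s / 2`) produces 50%-overlapping chunks, so adjacent output files share half their audio. Illustrative arithmetic only, not part of the commit:

```python
# chunk start times under the script's defaults
max_chunk_size_s = 60 * 10           # 10-minute windows
hop = max_chunk_size_s / 2           # 5-minute hop => 50% overlap

starts = [k * hop for k in range(4)]
print(starts)  # [0.0, 300.0, 600.0, 900.0]; chunk k spans [300*k, 300*k + 600) seconds
```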
scripts/utils/xeno-canto-dl.py ADDED
@@ -0,0 +1,234 @@
+from xenopy import Query
+
+
+SPECIES = [
+    "American Robin",
+    "Northern Cardinal",
+    "Mourning Dove",
+    "American Crow",
+    "Baltimore Oriole",
+    "Blue Jay",
+    "Eastern Bluebird",
+    "House Finch",
+    "American Goldfinch",
+    "House Sparrow",
+    "Song Sparrow",
+    "Tufted Titmouse",
+    "White-breasted Nuthatch",
+    "European Starling",
+    "American Redstart",
+    "Red-winged Blackbird",
+    "Brown-headed Cowbird",
+    "Common Grackle",
+    "Boat-tailed Grackle",
+    "Common Yellowthroat",
+    "Northern Mockingbird",
+    "Carolina Wren",
+    "Eastern Meadowlark",
+    "Chipping Sparrow",
+    "Tree Swallow",
+    "Barn Swallow",
+    "Cliff Swallow",
+    "Pine Siskin",
+    "Indigo Bunting",
+    "Eastern Towhee",
+    "Carolina Chickadee",
+    "Great Crested Flycatcher",
+    "Eastern Wood-Pewee",
+    "Ovenbird",
+    "Northern Flicker",
+    "Red-eyed Vireo",
+    "American Woodcock",
+    "Eastern Phoebe",
+    "Downy Woodpecker",
+    "Scarlet Tanager",
+    "Yellow Warbler",
+    "White-eyed Vireo",
+    "Common Loon",
+    "White-throated Sparrow",
+    "Yellow-throated Vireo",
+    "Great Blue Heron",
+    "Belted Kingfisher",
+    "Pied-billed Grebe",
+    "Wild Turkey",
+    "Wood Thrush",
+    "Rose-breasted Grosbeak",
+    "Field Sparrow",
+    "Hooded Warbler",
+    "Northern Parula",
+    "Chestnut-sided Warbler",
+    "Blue-winged Warbler",
+    "Red-bellied Woodpecker",
+    "Yellow-billed Cuckoo",
+    "Gray Catbird",
+    "Northern Saw-whet Owl",
+    "Osprey",
+    "Common Nighthawk",
+    "Broad-winged Hawk",
+    "Black-throated Green Warbler",
+    "Great Horned Owl",
+    "Common Raven",
+    "Barred Owl",
+    "Canada Warbler",
+    "Magnolia Warbler",
+    "Black-and-white Warbler",
+    "Eastern Kingbird",
+    "Swainson's Thrush",
+    "Worm-eating Warbler",
+    "Prairie Warbler",
+    "Baltimore Oriole",
+    "Black-throated Blue Warbler",
+    "Louisiana Waterthrush",
+    "Blackburnian Warbler",
+    "Black-capped Chickadee",
+    "Cerulean Warbler",
+    "Red-shouldered Hawk",
+    "Cooper's Hawk",
+    "Yellow-throated Warbler",
+    "Blue-headed Vireo",
+    "Blackpoll Warbler",
+    "Ruffed Grouse",
+    "Kentucky Warbler",
+    "Hermit Thrush",
+    "Cedar Waxwing",
+    "Eastern Screech-Owl",
+    "Northern Goshawk",
+    "Green Heron",
+    "Red-tailed Hawk",
+    "Black Vulture",
+    "Hairy Woodpecker",
+    "Golden-crowned Kinglet",
+    "Ruby-crowned Kinglet",
+    "Bicknell's Thrush",
+    "Blue-gray Gnatcatcher",
+    "Veery",
+    "Pileated Woodpecker",
+    "Purple Finch",
+    "White-crowned Sparrow",
+    "Snow Bunting",
+    "Pine Grosbeak",
+    "American Tree Sparrow",
+    "Dark-eyed Junco",
+    "Snowy Owl",
+    "White-winged Crossbill",
+    "Red Crossbill",
+    "Common Redpoll",
+    "Northern Shrike",
+    "Northern Harrier",
+    "Rough-legged Hawk",
+    "Long-eared Owl",
+    "Evening Grosbeak",
+    "Northern Pintail",
+    "American Black Duck",
+    "Mallard",
+    "Canvasback",
+    "Redhead",
+    "Ring-necked Duck",
+    "Greater Scaup",
+    "Lesser Scaup",
+    "Bufflehead",
+    "Common Goldeneye",
+    "Hooded Merganser",
+    "Common Merganser",
+    "Red-breasted Merganser",
+    "Ruddy Duck",
+    "Wood Duck",
+    "Gadwall",
+    "American Wigeon",
+    "Northern Shoveler",
+    "Green-winged Teal",
+    "Blue-winged Teal",
+    "Cinnamon Teal",
+    "Ringed Teal",
+    "Cape Teal",
+    "Northern Fulmar",
+    "Yellow-billed Loon",
+    "Red-throated Loon",
+    "Arctic Loon",
+    "Pacific Loon",
+    "Horned Grebe",
+    "Red-necked Grebe",
+    "Eared Grebe",
+    "Western Grebe",
+    "Clark's Grebe",
+    "Double-crested Cormorant",
+    "Pelagic Cormorant",
+    "Great Cormorant",
+    "American White Pelican",
+    "Brown Pelican",
+    "Brandt's Cormorant",
+    "Least Bittern",
+    "Great Egret",
+    "Snowy Egret",
+    "Little Blue Heron",
+    "Tricolored Heron",
+    "Reddish Egret",
+    "Black-crowned Night-Heron",
+    "Yellow-crowned Night-Heron",
+    "White Ibis",
+    "Glossy Ibis",
+    "Roseate Spoonbill",
+    "Wood Stork",
+    "Black-bellied Whistling-Duck",
+    "Fulvous Whistling-Duck",
+    "Greater White-fronted Goose",
+    "Snow Goose",
+    "Ross's Goose",
+    "Canada Goose",
+    "Brant",
+    "Mute Swan",
+    "Tundra Swan",
+    "Whooper Swan",
+    "Sandhill Crane",
+    "Black-necked Stilt",
+    "American Avocet",
+    "Northern Jacana",
+    "Greater Yellowlegs",
+    "Lesser Yellowlegs",
+    "Willet",
+    "Spotted Sandpiper",
+    "Upland Sandpiper",
+    "Whimbrel",
+    "Long-billed Curlew",
+    "Marbled Godwit",
+    "Ruddy Turnstone",
+    "Red Knot",
+    "Sanderling",
+    "Semipalmated Sandpiper",
+    "Western Sandpiper",
+    "Least Sandpiper",
+    "White-rumped Sandpiper",
+    "Baird's Sandpiper",
+    "Pectoral Sandpiper",
+    "Dunlin",
+    "Buff-breasted Sandpiper",
+    "Short-billed Dowitcher",
+    "Long-billed Dowitcher",
+    "Common Snipe",
+    "American Woodcock",
+    "Wilson's Phalarope",
+    "Red-necked Phalarope",
+    "Red Phalarope"
+]
+
+from pathlib import Path
+
+def remove_spaces(s):
+    return s.replace(" ", "")
+
+for species in SPECIES:
+    if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
+        continue
+    try:
+        q = Query(
+            name=species, q="A", length="10-30",
+        )
+
+        # retrieve metadata
+        metafiles = q.retrieve_meta(verbose=True)
+        # retrieve recordings
+        q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
+
+    except:
+        print("Failed to download " + species)
+        continue
setup.py CHANGED
@@ -28,12 +28,13 @@ setup(
     install_requires=[
         "torch",
         "argbind>=0.3.2",
-        "numpy==1.22",
+        "numpy==1.23",
         "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
         "lac @ git+https://github.com/hugofloresgarcia/lac.git",
         "descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
         "gradio",
-        "tensorboardX",
         "loralib",
+        "torch_pitch_shift",
+        "madmom",
     ],
 )
vampnet/interface.py CHANGED
@@ -120,17 +120,16 @@ class Interface(torch.nn.Module):
         if coarse_ckpt is not None:
             self.coarse.to("cpu")
             state_dict = torch.load(coarse_ckpt, map_location="cpu")
-
+            print(f"loading coarse from {coarse_ckpt}")
             self.coarse.load_state_dict(state_dict, strict=False)
             self.coarse.to(self.device)
         if c2f_ckpt is not None:
             self.c2f.to("cpu")
             state_dict = torch.load(c2f_ckpt, map_location="cpu")
-
+            print(f"loading c2f from {c2f_ckpt}")
             self.c2f.load_state_dict(state_dict, strict=False)
             self.c2f.to(self.device)
 
-
     def s2t(self, seconds: float):
         """seconds to tokens"""
         if isinstance(seconds, np.ndarray):
vampnet/mask.py CHANGED
@@ -191,29 +191,47 @@ def onset_mask(
     width: int = 1
 ):
     import librosa
-
-    onset_indices = librosa.onset.onset_detect(
-        y=sig.clone().to_mono().samples.cpu().numpy()[0, 0],
-        sr=sig.sample_rate,
-        hop_length=interface.codec.hop_length,
-        backtrack=True,
-    )
-
-    # create a mask, set onset
-    mask = torch.ones_like(z)
-    n_timesteps = z.shape[-1]
-
-    for onset_index in onset_indices:
-        onset_index = min(onset_index, n_timesteps - 1)
-        onset_index = max(onset_index, 0)
-        mask[:, :, onset_index - width:onset_index + width] = 0.0
-
-    print(mask)
+    import madmom
+    from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
+    import tempfile
+    import numpy as np
+
+    with tempfile.NamedTemporaryFile(suffix='.wav') as f:
+        sig = sig.clone()
+        sig.write(f.name)
+
+        proc = RNNOnsetProcessor(online=False)
+        onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
+                                              fps=sig.sample_rate/interface.codec.hop_length)
+
+        act = proc(f.name)
+        onset_times = onsetproc(act)
+
+        # convert to indices for z array
+        onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)
+
+        if onset_indices.shape[0] == 0:
+            mask = empty_mask(z)
+            print(f"no onsets found, returning empty mask")
+        else:
+            torch.set_printoptions(threshold=1000)
+            print("onset indices: ", onset_indices)
+            print("onset times: ", onset_times)
+
+            # create a mask, set onset
+            mask = torch.ones_like(z)
+            n_timesteps = z.shape[-1]
+
+            for onset_index in onset_indices:
+                onset_index = min(onset_index, n_timesteps - 1)
+                onset_index = max(onset_index, 0)
+                mask[:, :, onset_index - width:onset_index + width] = 0.0
+
+            print(mask)
 
     return mask
 
 
 
 if __name__ == "__main__":
-    torch.set_printoptions(threshold=10000)
-
+    pass
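On the `fps` argument above: `OnsetPeakPickingProcessor` reports onset times in seconds, and setting `fps = sample_rate / hop_length` keys its peak-picking grid to the codec's token frames before `librosa.time_to_frames` converts seconds to frame indices. A quick sanity check (the hop length of 768 is an assumed value for illustration, not necessarily the codec's):

```python
import librosa

sample_rate = 44100
hop_length = 768                # assumed codec hop length, for illustration
fps = sample_rate / hop_length  # ~57.4 frames per second

# an onset at 2.0 s lands near frame 2.0 * fps
print(round(fps, 1))                                                        # 57.4
print(librosa.time_to_frames(2.0, sr=sample_rate, hop_length=hop_length))  # 114
```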
 
vampnet/modules/transformer.py CHANGED
@@ -367,15 +367,6 @@ class TransformerLayer(nn.Module):
 
         return x, position_bias, encoder_decoder_position_bias
 
-def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
-    x = np.linspace(0, 1, n_steps)
-    a = (0.5 - min_temp) / (max_temp - min_temp)
-
-    x = (x * 12) - 6
-    x0 = np.log((1 / a - 1) + 1e-5) / k
-    y = (1 / (1 + np.exp(- k *(x-x0))))[::-1]
-
-    return y
 
 class TransformerStack(nn.Module):
     def __init__(
@@ -587,17 +578,18 @@ class VampNet(at.ml.BaseModel):
         self,
         codec,
         time_steps: int = 300,
-        sampling_steps: int = 24,
+        sampling_steps: int = 36,
         start_tokens: Optional[torch.Tensor] = None,
         sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
-        mask_temperature: float = 20.5,
+        mask_temperature: float = 10.5,
         typical_filtering=False,
         typical_mass=0.2,
         typical_min_tokens=1,
         top_p=None,
         return_signal=True,
-        seed: int = None
+        seed: int = None,
+        sample_cutoff: float = 0.5,
     ):
         if seed is not None:
             at.util.seed(seed)
@@ -650,7 +642,6 @@ class VampNet(at.ml.BaseModel):
         #################
         # begin sampling #
         #################
-        t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
 
         for i in range(sampling_steps):
             logging.debug(f"step {i} of {sampling_steps}")
@@ -676,10 +667,13 @@ class VampNet(at.ml.BaseModel):
             logging.debug(f"permuted logits with shape: {logits.shape}")
 
             sampled_z, selected_probs = sample_from_logits(
-                logits, sample=True, temperature=t_sched[i],
+                logits, sample=(
+                    (i / sampling_steps) <= sample_cutoff
+                ),
+                temperature=sampling_temperature,
                 typical_filtering=typical_filtering, typical_mass=typical_mass,
                 typical_min_tokens=typical_min_tokens,
-                top_k=None, top_p=top_p, return_probs=True
+                top_k=None, top_p=top_p, return_probs=True,
             )
 
             logging.debug(f"sampled z with shape: {sampled_z.shape}")
@@ -839,7 +833,11 @@ def sample_from_logits(
 
 
 
-def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
+def mask_by_random_topk(
+    num_to_mask: int,
+    probs: torch.Tensor,
+    temperature: float = 1.0,
+):
     """
     Args:
         num_to_mask (int): number of tokens to mask
@@ -852,7 +850,8 @@ def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: floa
     logging.debug(f"temperature: {temperature}")
     logging.debug("")
 
-    confidence = torch.log(probs) + temperature * gumbel_noise_like(probs)
+    noise = gumbel_noise_like(probs)
+    confidence = torch.log(probs) + temperature * noise
    logging.debug(f"confidence shape: {confidence.shape}")
 
     sorted_confidence, sorted_idx = confidence.sort(dim=-1)
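The sampling change above removes the sigmoid temperature schedule `t_schedule` in favor of a fixed `sampling_temperature` plus a `sample_cutoff`: tokens are sampled stochastically for the first `sample_cutoff` fraction of decoding steps, after which `sample_from_logits` receives `sample=False` and decoding proceeds greedily. A toy sketch of that schedule (names are illustrative, not vampnet's API):

```python
# toy illustration of the sample_cutoff schedule, not the actual generate() code
def sampling_modes(sampling_steps: int = 36, sample_cutoff: float = 0.5):
    # stochastic sampling early, greedy (argmax) decoding late
    return [
        "sample" if (i / sampling_steps) <= sample_cutoff else "argmax"
        for i in range(sampling_steps)
    ]

print(sampling_modes(8, 0.5))
# ['sample', 'sample', 'sample', 'sample', 'sample', 'argmax', 'argmax', 'argmax']
```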