Hugo Flores Garcia commited on
Commit
7b88c07
1 Parent(s): bf35d45

better onset detection!!!!!

Browse files

remove annealing from sampling temperature
add pitch shifting w/ torch pitch shift
improvements to lora config
(TODO: fix a lora bug where the lora weights won't load correctly)
add helper scripts for collecting xeno canto data

app.py CHANGED
@@ -18,6 +18,16 @@ Interface = argbind.bind(Interface)
18
 
19
  conf = argbind.parse_args()
20
 
 
 
 
 
 
 
 
 
 
 
21
  def load_interface():
22
  with argbind.scope(conf):
23
  interface = Interface()
@@ -95,6 +105,10 @@ def _vamp(data, return_mask=False):
95
  out_dir = OUT_DIR / str(uuid.uuid4())
96
  out_dir.mkdir()
97
  sig = at.AudioSignal(data[input_audio])
 
 
 
 
98
 
99
  z = interface.encode(sig)
100
 
@@ -134,7 +148,27 @@ def _vamp(data, return_mask=False):
134
  mask = pmask.codebook_unmask(mask, ncc)
135
 
136
 
137
- print(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  _top_p = data[top_p] if data[top_p] > 0 else None
139
  # save the mask as a txt file
140
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
@@ -153,6 +187,7 @@ def _vamp(data, return_mask=False):
153
  top_p=_top_p,
154
  gen_fn=interface.coarse.generate,
155
  seed=_seed,
 
156
  )
157
 
158
  if use_coarse2fine:
@@ -356,7 +391,7 @@ with gr.Blocks() as demo:
356
  onset_mask_width = gr.Slider(
357
  label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
358
  minimum=0,
359
- maximum=20,
360
  step=1,
361
  value=5,
362
  )
@@ -374,6 +409,14 @@ with gr.Blocks() as demo:
374
 
375
 
376
  with gr.Accordion("extras ", open=False):
 
 
 
 
 
 
 
 
377
  rand_mask_intensity = gr.Slider(
378
  label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
379
  minimum=0.0,
@@ -436,14 +479,15 @@ with gr.Blocks() as demo:
436
  masktemp = gr.Slider(
437
  label="mask temperature",
438
  minimum=0.0,
439
- maximum=10.0,
440
  value=1.5
441
  )
442
  sampletemp = gr.Slider(
443
  label="sample temperature",
444
  minimum=0.1,
445
- maximum=2.0,
446
- value=1.0
 
447
  )
448
 
449
 
@@ -459,7 +503,7 @@ with gr.Blocks() as demo:
459
  label="typical filtering ",
460
  value=False
461
  )
462
- typical_mass = gr.Slider(
463
  label="typical mass (should probably stay between 0.1 and 0.5)",
464
  minimum=0.01,
465
  maximum=0.99,
@@ -472,6 +516,13 @@ with gr.Blocks() as demo:
472
  step=1,
473
  value=64
474
  )
 
 
 
 
 
 
 
475
 
476
  use_coarse2fine = gr.Checkbox(
477
  label="use coarse2fine",
@@ -495,10 +546,6 @@ with gr.Blocks() as demo:
495
  value=0.0
496
  )
497
 
498
- use_new_trick = gr.Checkbox(
499
- label="new trick",
500
- value=False
501
- )
502
 
503
  seed = gr.Number(
504
  label="seed (0 for random)",
@@ -560,6 +607,8 @@ with gr.Blocks() as demo:
560
  beat_mask_downbeats,
561
  seed,
562
  lora_choice,
 
 
563
  }
564
 
565
  # connect widgets
@@ -589,4 +638,4 @@ with gr.Blocks() as demo:
589
  outputs=[thank_you, download_file]
590
  )
591
 
592
- demo.launch(share=True, enable_queue=False, debug=True)
 
18
 
19
  conf = argbind.parse_args()
20
 
21
+
22
+ from torch_pitch_shift import pitch_shift, get_fast_shifts
23
+ def shift_pitch(signal, interval: int):
24
+ signal.samples = pitch_shift(
25
+ signal.samples,
26
+ shift=interval,
27
+ sample_rate=signal.sample_rate
28
+ )
29
+ return signal
30
+
31
  def load_interface():
32
  with argbind.scope(conf):
33
  interface = Interface()
 
105
  out_dir = OUT_DIR / str(uuid.uuid4())
106
  out_dir.mkdir()
107
  sig = at.AudioSignal(data[input_audio])
108
+ sig = interface.preprocess(sig)
109
+
110
+ if data[pitch_shift_amt] != 0:
111
+ sig = shift_pitch(sig, data[pitch_shift_amt])
112
 
113
  z = interface.encode(sig)
114
 
 
148
  mask = pmask.codebook_unmask(mask, ncc)
149
 
150
 
151
+ print(f"dropout {data[dropout]}")
152
+ print(f"masktemp {data[masktemp]}")
153
+ print(f"sampletemp {data[sampletemp]}")
154
+ print(f"top_p {data[top_p]}")
155
+ print(f"prefix_s {data[prefix_s]}")
156
+ print(f"suffix_s {data[suffix_s]}")
157
+ print(f"rand_mask_intensity {data[rand_mask_intensity]}")
158
+ print(f"num_steps {data[num_steps]}")
159
+ print(f"periodic_p {data[periodic_p]}")
160
+ print(f"periodic_w {data[periodic_w]}")
161
+ print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
162
+ print(f"use_coarse2fine {data[use_coarse2fine]}")
163
+ print(f"onset_mask_width {data[onset_mask_width]}")
164
+ print(f"beat_mask_width {data[beat_mask_width]}")
165
+ print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
166
+ print(f"stretch_factor {data[stretch_factor]}")
167
+ print(f"seed {data[seed]}")
168
+ print(f"pitch_shift_amt {data[pitch_shift_amt]}")
169
+ print(f"sample_cutoff {data[sample_cutoff]}")
170
+
171
+
172
  _top_p = data[top_p] if data[top_p] > 0 else None
173
  # save the mask as a txt file
174
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
187
  top_p=_top_p,
188
  gen_fn=interface.coarse.generate,
189
  seed=_seed,
190
+ sample_cutoff=data[sample_cutoff],
191
  )
192
 
193
  if use_coarse2fine:
 
391
  onset_mask_width = gr.Slider(
392
  label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
393
  minimum=0,
394
+ maximum=100,
395
  step=1,
396
  value=5,
397
  )
 
409
 
410
 
411
  with gr.Accordion("extras ", open=False):
412
+ pitch_shift_amt = gr.Slider(
413
+ label="pitch shift amount (semitones)",
414
+ minimum=-12,
415
+ maximum=12,
416
+ step=1,
417
+ value=0,
418
+ )
419
+
420
  rand_mask_intensity = gr.Slider(
421
  label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
422
  minimum=0.0,
 
479
  masktemp = gr.Slider(
480
  label="mask temperature",
481
  minimum=0.0,
482
+ maximum=100.0,
483
  value=1.5
484
  )
485
  sampletemp = gr.Slider(
486
  label="sample temperature",
487
  minimum=0.1,
488
+ maximum=10.0,
489
+ value=1.0,
490
+ step=0.001
491
  )
492
 
493
 
 
503
  label="typical filtering ",
504
  value=False
505
  )
506
+ typical_mass = gr.Slider(
507
  label="typical mass (should probably stay between 0.1 and 0.5)",
508
  minimum=0.01,
509
  maximum=0.99,
 
516
  step=1,
517
  value=64
518
  )
519
+ sample_cutoff = gr.Slider(
520
+ label="sample cutoff",
521
+ minimum=0.0,
522
+ maximum=1.0,
523
+ value=0.5,
524
+ step=0.01
525
+ )
526
 
527
  use_coarse2fine = gr.Checkbox(
528
  label="use coarse2fine",
 
546
  value=0.0
547
  )
548
 
 
 
 
 
549
 
550
  seed = gr.Number(
551
  label="seed (0 for random)",
 
607
  beat_mask_downbeats,
608
  seed,
609
  lora_choice,
610
+ pitch_shift_amt,
611
+ sample_cutoff
612
  }
613
 
614
  # connect widgets
 
638
  outputs=[thank_you, download_file]
639
  )
640
 
641
+ demo.launch(share=True, enable_queue=True, debug=True)
conf/lora/lora.yml CHANGED
@@ -4,14 +4,16 @@ $include:
4
  fine_tune: True
5
 
6
  train/AudioDataset.n_examples: 100000000
7
- val/AudioDataset.n_examples: 100
8
 
9
 
10
  NoamScheduler.warmup: 500
11
 
12
  batch_size: 7
13
  num_workers: 7
14
- save_iters: [100000, 200000, 300000, 4000000, 500000]
 
 
15
 
16
  AdamW.lr: 0.0001
17
 
 
4
  fine_tune: True
5
 
6
  train/AudioDataset.n_examples: 100000000
7
+ val/AudioDataset.n_examples: 500
8
 
9
 
10
  NoamScheduler.warmup: 500
11
 
12
  batch_size: 7
13
  num_workers: 7
14
+ save_iters: [10000, 20000, 30000, 40000, 50000]
15
+ sample_freq: 1000
16
+ val_freq: 500
17
 
18
  AdamW.lr: 0.0001
19
 
scripts/exp/fine_tune.py CHANGED
@@ -48,10 +48,10 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
48
  }
49
 
50
  interface_conf = {
51
- "Interface.coarse_ckpt": f"./models/vampnet/coarse.pth",
52
  "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
53
 
54
- "Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
55
  "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
56
  "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
57
 
 
48
  }
49
 
50
  interface_conf = {
51
+ "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
52
  "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
53
 
54
+ "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
55
  "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
56
  "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
57
 
scripts/utils/augment.py CHANGED
@@ -5,34 +5,19 @@ from audiotools import AudioSignal
5
 
6
  import argbind
7
  import tqdm
 
8
 
9
 
10
- from pedalboard import (
11
- Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard
12
- )
13
- from pedalboard.io import AudioFile
14
 
15
- # Read in a whole file, resampling to our desired sample rate:
16
- samplerate = 44100.0
17
- with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
18
- audio = f.read(f.frames)
19
-
20
- # Make a pretty interesting sounding guitar pedalboard:
21
- board = Pedalboard([
22
- Compressor(threshold_db=-50, ratio=25),
23
- Gain(gain_db=30),
24
- Chorus(),
25
- LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
26
- Phaser(),
27
- Convolution("./guitar_amp.wav", 1.0),
28
- Reverb(room_size=0.25),
29
- ])
30
 
31
 
32
  @argbind.bind(without_prefix=True)
33
  def augment(
34
- audio_folder: Path,
35
- dest_folder: Path,
36
  n_augmentations: int = 10,
37
  ):
38
  """
@@ -41,7 +26,8 @@ def augment(
41
  The dest foler will contain a folder for each of the clean dataset's files.
42
  Under each of these folders, there will be a clean file and many augmented files.
43
  """
44
-
 
45
  audio_files = at.util.find_audio(audio_folder)
46
 
47
  for audio_file in tqdm.tqdm(audio_files):
@@ -49,5 +35,33 @@ def augment(
49
  subdir = subtree / audio_file.stem
50
  subdir.mkdir(parents=True, exist_ok=True)
51
 
52
- # apply pedalboard transforms
53
- for i in range(n_augmentations):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  import argbind
7
  import tqdm
8
+ import torch
9
 
10
 
11
+ from torch_pitch_shift import pitch_shift, get_fast_shifts
12
+ from torch_time_stretch import time_stretch, get_fast_stretches
 
 
13
 
14
+ from audiotools.core.util import sample_from_dist
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  @argbind.bind(without_prefix=True)
18
  def augment(
19
+ audio_folder: Path = None,
20
+ dest_folder: Path = None,
21
  n_augmentations: int = 10,
22
  ):
23
  """
 
26
  The dest foler will contain a folder for each of the clean dataset's files.
27
  Under each of these folders, there will be a clean file and many augmented files.
28
  """
29
+ assert audio_folder is not None
30
+ assert dest_folder is not None
31
  audio_files = at.util.find_audio(audio_folder)
32
 
33
  for audio_file in tqdm.tqdm(audio_files):
 
35
  subdir = subtree / audio_file.stem
36
  subdir.mkdir(parents=True, exist_ok=True)
37
 
38
+ src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
39
+
40
+
41
+ for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
42
+ # apply pedalboard transforms
43
+ for j in range(n_augmentations):
44
+ # pitch shift between -7 and 7 semitones
45
+ import random
46
+ dst = chunk.clone()
47
+ dst.samples = pitch_shift(
48
+ dst.samples,
49
+ shift=random.choice(get_fast_shifts(src.sample_rate,
50
+ condition=lambda x: x >= 0.25 and x <= 1.0)),
51
+ sample_rate=src.sample_rate
52
+ )
53
+ dst.samples = time_stretch(
54
+ dst.samples,
55
+ stretch=random.choice(get_fast_stretches(src.sample_rate,
56
+ condition=lambda x: x >= 0.667 and x <= 1.5, )),
57
+ sample_rate=src.sample_rate,
58
+ )
59
+
60
+ dst.cpu().write(subdir / f"{i}-{j}.wav")
61
+
62
+
63
+ if __name__ == "__main__":
64
+ args = argbind.parse_args()
65
+
66
+ with argbind.scope(args):
67
+ augment()
scripts/utils/remove_quiet_files.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # removes files with loudness below 24db
2
+
3
+ from pathlib import Path
4
+ import shutil
5
+ import audiotools as at
6
+ import argbind
7
+
8
+ @argbind.bind(without_prefix=True)
9
+ def remove_quiet_files(
10
+ src_dir: Path = None,
11
+ dest_dir: Path = None,
12
+ min_loudness: float = -30,
13
+ ):
14
+ # copy src to dest
15
+ dest_dir.mkdir(parents=True, exist_ok=True)
16
+ shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
17
+
18
+ audio_files = at.util.find_audio(dest_dir)
19
+ for audio_file in audio_files:
20
+ sig = at.AudioSignal(audio_file)
21
+ if sig.loudness() < min_loudness:
22
+ audio_file.unlink()
23
+ print(f"removed {audio_file}")
24
+
25
+ if __name__ == "__main__":
26
+ args = argbind.parse_args()
27
+
28
+ with argbind.scope(args):
29
+ remove_quiet_files()
scripts/xeno-canto-dl.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from xenopy import Query
2
+
3
+
4
+ SPECIES = [
5
+ "American Robin",
6
+ "Northern Cardinal",
7
+ "Mourning Dove",
8
+ "American Crow",
9
+ "Baltimore Oriole",
10
+ "Blue Jay",
11
+ "Eastern Bluebird",
12
+ "House Finch",
13
+ "American Goldfinch",
14
+ "House Sparrow",
15
+ "Song Sparrow",
16
+ "Tufted Titmouse",
17
+ "White-breasted Nuthatch",
18
+ "European Starling",
19
+ "American Redstart",
20
+ "Red-winged Blackbird",
21
+ "Brown-headed Cowbird",
22
+ "Common Grackle",
23
+ "Boat-tailed Grackle",
24
+ "Common Yellowthroat",
25
+ "Northern Mockingbird",
26
+ "Carolina Wren",
27
+ "Eastern Meadowlark",
28
+ "Chipping Sparrow",
29
+ "Tree Swallow",
30
+ "Barn Swallow",
31
+ "Cliff Swallow",
32
+ "Pine Siskin",
33
+ "Indigo Bunting",
34
+ "Eastern Towhee",
35
+ "Carolina Chickadee",
36
+ "Great Crested Flycatcher",
37
+ "Eastern Wood-Pewee",
38
+ "Ovenbird",
39
+ "Northern Flicker",
40
+ "Red-eyed Vireo",
41
+ "American Woodcock",
42
+ "Eastern Phoebe",
43
+ "Downy Woodpecker",
44
+ "Scarlet Tanager",
45
+ "Yellow Warbler",
46
+ "White-eyed Vireo",
47
+ "Common Loon",
48
+ "White-throated Sparrow",
49
+ "Yellow-throated Vireo",
50
+ "Great Blue Heron",
51
+ "Belted Kingfisher",
52
+ "Pied-billed Grebe",
53
+ "Wild Turkey",
54
+ "Wood Thrush",
55
+ "Rose-breasted Grosbeak",
56
+ "Field Sparrow",
57
+ "Hooded Warbler",
58
+ "Northern Parula",
59
+ "Chestnut-sided Warbler",
60
+ "Blue-winged Warbler",
61
+ "Red-bellied Woodpecker",
62
+ "Yellow-billed Cuckoo",
63
+ "Gray Catbird",
64
+ "Northern Saw-whet Owl",
65
+ "Osprey",
66
+ "Common Nighthawk",
67
+ "Broad-winged Hawk",
68
+ "Black-throated Green Warbler",
69
+ "Great Horned Owl",
70
+ "Common Raven",
71
+ "Barred Owl",
72
+ "Canada Warbler",
73
+ "Magnolia Warbler",
74
+ "Black-and-white Warbler",
75
+ "Eastern Kingbird",
76
+ "Swainson's Thrush",
77
+ "Worm-eating Warbler",
78
+ "Prairie Warbler",
79
+ "Baltimore Oriole",
80
+ "Black-throated Blue Warbler",
81
+ "Louisiana Waterthrush",
82
+ "Blackburnian Warbler",
83
+ "Black-capped Chickadee",
84
+ "Cerulean Warbler",
85
+ "Red-shouldered Hawk",
86
+ "Cooper's Hawk",
87
+ "Yellow-throated Warbler",
88
+ "Blue-headed Vireo",
89
+ "Blackpoll Warbler",
90
+ "Ruffed Grouse",
91
+ "Kentucky Warbler",
92
+ "Hermit Thrush",
93
+ "Cedar Waxwing",
94
+ "Eastern Screech-Owl",
95
+ "Northern Goshawk",
96
+ "Green Heron",
97
+ "Red-tailed Hawk",
98
+ "Black Vulture",
99
+ "Hairy Woodpecker",
100
+ "Golden-crowned Kinglet",
101
+ "Ruby-crowned Kinglet",
102
+ "Bicknell's Thrush",
103
+ "Blue-gray Gnatcatcher",
104
+ "Veery",
105
+ "Pileated Woodpecker",
106
+ "Purple Finch",
107
+ "White-crowned Sparrow",
108
+ "Snow Bunting",
109
+ "Pine Grosbeak",
110
+ "American Tree Sparrow",
111
+ "Dark-eyed Junco",
112
+ "Snowy Owl",
113
+ "White-winged Crossbill",
114
+ "Red Crossbill",
115
+ "Common Redpoll",
116
+ "Northern Shrike",
117
+ "Northern Harrier",
118
+ "Rough-legged Hawk",
119
+ "Long-eared Owl",
120
+ "Evening Grosbeak",
121
+ "Northern Pintail",
122
+ "American Black Duck",
123
+ "Mallard",
124
+ "Canvasback",
125
+ "Redhead",
126
+ "Ring-necked Duck",
127
+ "Greater Scaup",
128
+ "Lesser Scaup",
129
+ "Bufflehead",
130
+ "Common Goldeneye",
131
+ "Hooded Merganser",
132
+ "Common Merganser",
133
+ "Red-breasted Merganser",
134
+ "Ruddy Duck",
135
+ "Wood Duck",
136
+ "Gadwall",
137
+ "American Wigeon",
138
+ "Northern Shoveler",
139
+ "Green-winged Teal",
140
+ "Blue-winged Teal",
141
+ "Cinnamon Teal",
142
+ "Ringed Teal",
143
+ "Cape Teal",
144
+ "Northern Fulmar",
145
+ "Yellow-billed Loon",
146
+ "Red-throated Loon",
147
+ "Arctic Loon",
148
+ "Pacific Loon",
149
+ "Horned Grebe",
150
+ "Red-necked Grebe",
151
+ "Eared Grebe",
152
+ "Western Grebe",
153
+ "Clark's Grebe",
154
+ "Double-crested Cormorant",
155
+ "Pelagic Cormorant",
156
+ "Great Cormorant",
157
+ "American White Pelican",
158
+ "Brown Pelican",
159
+ "Brandt's Cormorant",
160
+ "Least Bittern",
161
+ "Great Egret",
162
+ "Snowy Egret",
163
+ "Little Blue Heron",
164
+ "Tricolored Heron",
165
+ "Reddish Egret",
166
+ "Black-crowned Night-Heron",
167
+ "Yellow-crowned Night-Heron",
168
+ "White Ibis",
169
+ "Glossy Ibis",
170
+ "Roseate Spoonbill",
171
+ "Wood Stork",
172
+ "Black-bellied Whistling-Duck",
173
+ "Fulvous Whistling-Duck",
174
+ "Greater White-fronted Goose",
175
+ "Snow Goose",
176
+ "Ross's Goose",
177
+ "Canada Goose",
178
+ "Brant",
179
+ "Mute Swan",
180
+ "Tundra Swan",
181
+ "Whooper Swan",
182
+ "Sandhill Crane",
183
+ "Black-necked Stilt",
184
+ "American Avocet",
185
+ "Northern Jacana",
186
+ "Greater Yellowlegs",
187
+ "Lesser Yellowlegs",
188
+ "Willet",
189
+ "Spotted Sandpiper",
190
+ "Upland Sandpiper",
191
+ "Whimbrel",
192
+ "Long-billed Curlew",
193
+ "Marbled Godwit",
194
+ "Ruddy Turnstone",
195
+ "Red Knot",
196
+ "Sanderling",
197
+ "Semipalmated Sandpiper",
198
+ "Western Sandpiper",
199
+ "Least Sandpiper",
200
+ "White-rumped Sandpiper",
201
+ "Baird's Sandpiper",
202
+ "Pectoral Sandpiper",
203
+ "Dunlin",
204
+ "Buff-breasted Sandpiper",
205
+ "Short-billed Dowitcher",
206
+ "Long-billed Dowitcher",
207
+ "Common Snipe",
208
+ "American Woodcock",
209
+ "Wilson's Phalarope",
210
+ "Red-necked Phalarope",
211
+ "Red Phalarope"
212
+ ]
213
+
214
+ from pathlib import Path
215
+
216
+ def remove_spaces(s):
217
+ return s.replace(" ", "")
218
+
219
+ for species in SPECIES:
220
+ if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
221
+ continue
222
+ try:
223
+ q = Query(
224
+ name=species, q="A", length="10-30",
225
+ )
226
+
227
+ # retrieve metadata
228
+ metafiles = q.retrieve_meta(verbose=True)
229
+ # retrieve recordings
230
+ q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
231
+
232
+ except:
233
+ print("Failed to download " + species)
234
+ continue
vampnet/mask.py CHANGED
@@ -191,29 +191,47 @@ def onset_mask(
191
  width: int = 1
192
  ):
193
  import librosa
194
-
195
- onset_indices = librosa.onset.onset_detect(
196
- y=sig.clone().to_mono().samples.cpu().numpy()[0, 0],
197
- sr=sig.sample_rate,
198
- hop_length=interface.codec.hop_length,
199
- backtrack=True,
200
- )
201
-
202
- # create a mask, set onset
203
- mask = torch.ones_like(z)
204
- n_timesteps = z.shape[-1]
205
-
206
- for onset_index in onset_indices:
207
- onset_index = min(onset_index, n_timesteps - 1)
208
- onset_index = max(onset_index, 0)
209
- mask[:, :, onset_index - width:onset_index + width] = 0.0
210
-
211
- print(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  return mask
214
 
215
 
216
 
217
  if __name__ == "__main__":
218
- torch.set_printoptions(threshold=10000)
219
-
 
191
  width: int = 1
192
  ):
193
  import librosa
194
+ import madmom
195
+ from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
196
+ import tempfile
197
+ import numpy as np
198
+
199
+ with tempfile.NamedTemporaryFile(suffix='.wav') as f:
200
+ sig = sig.clone()
201
+ sig.write(f.name)
202
+
203
+ proc = RNNOnsetProcessor(online=False)
204
+ onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
205
+ fps=sig.sample_rate/interface.codec.hop_length)
206
+
207
+ act = proc(f.name)
208
+ onset_times = onsetproc(act)
209
+
210
+ # convert to indices for z array
211
+ onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)
212
+
213
+ if onset_indices.shape[0] == 0:
214
+ mask = empty_mask(z)
215
+ print(f"no onsets found, returning empty mask")
216
+ else:
217
+ torch.set_printoptions(threshold=1000)
218
+ print("onset indices: ", onset_indices)
219
+ print("onset times: ", onset_times)
220
+
221
+ # create a mask, set onset
222
+ mask = torch.ones_like(z)
223
+ n_timesteps = z.shape[-1]
224
+
225
+ for onset_index in onset_indices:
226
+ onset_index = min(onset_index, n_timesteps - 1)
227
+ onset_index = max(onset_index, 0)
228
+ mask[:, :, onset_index - width:onset_index + width] = 0.0
229
+
230
+ print(mask)
231
 
232
  return mask
233
 
234
 
235
 
236
  if __name__ == "__main__":
237
+ pass
 
vampnet/modules/transformer.py CHANGED
@@ -367,15 +367,6 @@ class TransformerLayer(nn.Module):
367
 
368
  return x, position_bias, encoder_decoder_position_bias
369
 
370
- def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
371
- x = np.linspace(0, 1, n_steps)
372
- a = (0.5 - min_temp) / (max_temp - min_temp)
373
-
374
- x = (x * 12) - 6
375
- x0 = np.log((1 / a - 1) + 1e-5) / k
376
- y = (1 / (1 + np.exp(- k *(x-x0))))[::-1]
377
-
378
- return y
379
 
380
  class TransformerStack(nn.Module):
381
  def __init__(
@@ -598,7 +589,7 @@ class VampNet(at.ml.BaseModel):
598
  top_p=None,
599
  return_signal=True,
600
  seed: int = None,
601
- sample_cutoff: float = 0.5
602
  ):
603
  if seed is not None:
604
  at.util.seed(seed)
@@ -651,7 +642,6 @@ class VampNet(at.ml.BaseModel):
651
  #################
652
  # begin sampling #
653
  #################
654
- t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
655
 
656
  for i in range(sampling_steps):
657
  logging.debug(f"step {i} of {sampling_steps}")
@@ -680,7 +670,7 @@ class VampNet(at.ml.BaseModel):
680
  logits, sample=(
681
  (i / sampling_steps) <= sample_cutoff
682
  ),
683
- temperature=t_sched[i],
684
  typical_filtering=typical_filtering, typical_mass=typical_mass,
685
  typical_min_tokens=typical_min_tokens,
686
  top_k=None, top_p=top_p, return_probs=True,
@@ -843,7 +833,11 @@ def sample_from_logits(
843
 
844
 
845
 
846
- def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
 
 
 
 
847
  """
848
  Args:
849
  num_to_mask (int): number of tokens to mask
@@ -856,7 +850,8 @@ def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: floa
856
  logging.debug(f"temperature: {temperature}")
857
  logging.debug("")
858
 
859
- confidence = torch.log(probs) + temperature * gumbel_noise_like(probs)
 
860
  logging.debug(f"confidence shape: {confidence.shape}")
861
 
862
  sorted_confidence, sorted_idx = confidence.sort(dim=-1)
 
367
 
368
  return x, position_bias, encoder_decoder_position_bias
369
 
 
 
 
 
 
 
 
 
 
370
 
371
  class TransformerStack(nn.Module):
372
  def __init__(
 
589
  top_p=None,
590
  return_signal=True,
591
  seed: int = None,
592
+ sample_cutoff: float = 0.5,
593
  ):
594
  if seed is not None:
595
  at.util.seed(seed)
 
642
  #################
643
  # begin sampling #
644
  #################
 
645
 
646
  for i in range(sampling_steps):
647
  logging.debug(f"step {i} of {sampling_steps}")
 
670
  logits, sample=(
671
  (i / sampling_steps) <= sample_cutoff
672
  ),
673
+ temperature=sampling_temperature,
674
  typical_filtering=typical_filtering, typical_mass=typical_mass,
675
  typical_min_tokens=typical_min_tokens,
676
  top_k=None, top_p=top_p, return_probs=True,
 
833
 
834
 
835
 
836
+ def mask_by_random_topk(
837
+ num_to_mask: int,
838
+ probs: torch.Tensor,
839
+ temperature: float = 1.0,
840
+ ):
841
  """
842
  Args:
843
  num_to_mask (int): number of tokens to mask
 
850
  logging.debug(f"temperature: {temperature}")
851
  logging.debug("")
852
 
853
+ noise = gumbel_noise_like(probs)
854
+ confidence = torch.log(probs) + temperature * noise
855
  logging.debug(f"confidence shape: {confidence.shape}")
856
 
857
  sorted_confidence, sorted_idx = confidence.sort(dim=-1)