Spaces:
Sleeping
Sleeping
Hugo Flores
commited on
Commit
·
5a0a80a
1
Parent(s):
91f8638
beat tracker bugfixes
Browse files- requirements.txt +2 -1
- vampnet/beats.py +2 -5
- vampnet/interface.py +41 -10
- vampnet/modules/base.py +1 -2
requirements.txt
CHANGED
@@ -2,7 +2,8 @@ argbind>=0.3.1
|
|
2 |
pytorch-ignite
|
3 |
rich
|
4 |
audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@hf/backup-info
|
5 |
-
lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@
|
|
|
6 |
tqdm
|
7 |
tensorboard
|
8 |
google-cloud-logging==2.2.0
|
|
|
2 |
pytorch-ignite
|
3 |
rich
|
4 |
audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@hf/backup-info
|
5 |
+
lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@hf/vampnet-temp
|
6 |
+
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git
|
7 |
tqdm
|
8 |
tensorboard
|
9 |
google-cloud-logging==2.2.0
|
vampnet/beats.py
CHANGED
@@ -200,13 +200,10 @@ class BeatTracker:
|
|
200 |
|
201 |
|
202 |
class WaveBeat(BeatTracker):
|
203 |
-
def __init__(self,
|
204 |
from wavebeat.dstcn import dsTCNModel
|
205 |
|
206 |
-
|
207 |
-
assert len(ckpts) > 0, f"no checkpoints found for wavebeat in {ckpt_dir}"
|
208 |
-
|
209 |
-
model = dsTCNModel.load_from_checkpoint(ckpts[-1])
|
210 |
model.eval()
|
211 |
|
212 |
self.device = device
|
|
|
200 |
|
201 |
|
202 |
class WaveBeat(BeatTracker):
|
203 |
+
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
204 |
from wavebeat.dstcn import dsTCNModel
|
205 |
|
206 |
+
model = dsTCNModel.load_from_checkpoint(ckpt_path)
|
|
|
|
|
|
|
207 |
model.eval()
|
208 |
|
209 |
self.device = device
|
vampnet/interface.py
CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
|
|
3 |
import math
|
4 |
|
5 |
import torch
|
|
|
6 |
from audiotools import AudioSignal
|
7 |
import tqdm
|
8 |
|
@@ -50,7 +51,10 @@ class Interface:
|
|
50 |
|
51 |
def s2t(self, seconds: float):
|
52 |
"""seconds to tokens"""
|
53 |
-
|
|
|
|
|
|
|
54 |
|
55 |
def s2t2s(self, seconds: float):
|
56 |
"""seconds to tokens to seconds"""
|
@@ -94,11 +98,12 @@ class Interface:
|
|
94 |
signal: AudioSignal,
|
95 |
before_beat_s: float = 0.1,
|
96 |
after_beat_s: float = 0.1,
|
97 |
-
mask_downbeats:
|
98 |
-
mask_upbeats:
|
99 |
downbeat_downsample_factor: int = None,
|
100 |
beat_downsample_factor: int = None,
|
101 |
-
|
|
|
102 |
):
|
103 |
"""make a beat synced mask. that is, make a mask that
|
104 |
places 1s at and around the beat, and 0s everywhere else.
|
@@ -112,7 +117,9 @@ class Interface:
|
|
112 |
beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
|
113 |
|
114 |
# remove downbeats from beats
|
115 |
-
beats_z = beats_z[~torch.isin(beats_z, downbeats_z)]
|
|
|
|
|
116 |
|
117 |
# make the mask
|
118 |
seq_len = self.s2t(signal.duration)
|
@@ -138,16 +145,26 @@ class Interface:
|
|
138 |
|
139 |
if mask_upbeats:
|
140 |
for beat_idx in beats_z:
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
if mask_downbeats:
|
144 |
for downbeat_idx in downbeats_z:
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
if invert:
|
148 |
mask = 1 - mask
|
149 |
|
150 |
-
return mask
|
151 |
|
152 |
def coarse_to_fine(
|
153 |
self,
|
@@ -293,6 +310,7 @@ class Interface:
|
|
293 |
debug=False,
|
294 |
swap_prefix_suffix=False,
|
295 |
ext_mask=None,
|
|
|
296 |
**kwargs
|
297 |
):
|
298 |
z = self.encode(signal)
|
@@ -319,7 +337,8 @@ class Interface:
|
|
319 |
|
320 |
_cz = cz.clone()
|
321 |
cz_mask = None
|
322 |
-
|
|
|
323 |
# add noise
|
324 |
cz_masked, cz_mask = self.coarse.add_noise(
|
325 |
_cz, r=1.0-intensity,
|
@@ -428,8 +447,9 @@ class Interface:
|
|
428 |
def variation(
|
429 |
self,
|
430 |
signal: AudioSignal,
|
431 |
-
overlap_hop_ratio: float = 1.0, # TODO: should this be fixed to 1.0? or should we overlap and replace instead of overlap add
|
432 |
verbose: bool = False,
|
|
|
|
|
433 |
**kwargs
|
434 |
):
|
435 |
signal = signal.clone()
|
@@ -442,6 +462,9 @@ class Interface:
|
|
442 |
math.ceil(signal.duration / self.coarse.chunk_size_s)
|
443 |
* self.coarse.chunk_size_s
|
444 |
)
|
|
|
|
|
|
|
445 |
hop_duration = self.coarse.chunk_size_s * overlap_hop_ratio
|
446 |
original_length = signal.length
|
447 |
|
@@ -460,10 +483,18 @@ class Interface:
|
|
460 |
signal.samples[i,...], signal.sample_rate
|
461 |
)
|
462 |
sig.to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
out_z = self.coarse_vamp_v2(
|
464 |
sig,
|
465 |
num_vamps=1,
|
466 |
swap_prefix_suffix=False,
|
|
|
|
|
467 |
**kwargs
|
468 |
)
|
469 |
if self.c2f is not None:
|
|
|
3 |
import math
|
4 |
|
5 |
import torch
|
6 |
+
import numpy as np
|
7 |
from audiotools import AudioSignal
|
8 |
import tqdm
|
9 |
|
|
|
51 |
|
52 |
def s2t(self, seconds: float):
|
53 |
"""seconds to tokens"""
|
54 |
+
if isinstance(seconds, np.ndarray):
|
55 |
+
return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
|
56 |
+
else:
|
57 |
+
return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
|
58 |
|
59 |
def s2t2s(self, seconds: float):
|
60 |
"""seconds to tokens to seconds"""
|
|
|
98 |
signal: AudioSignal,
|
99 |
before_beat_s: float = 0.1,
|
100 |
after_beat_s: float = 0.1,
|
101 |
+
mask_downbeats: bool = True,
|
102 |
+
mask_upbeats: bool = True,
|
103 |
downbeat_downsample_factor: int = None,
|
104 |
beat_downsample_factor: int = None,
|
105 |
+
dropout: float = 0.7,
|
106 |
+
invert: bool = True,
|
107 |
):
|
108 |
"""make a beat synced mask. that is, make a mask that
|
109 |
places 1s at and around the beat, and 0s everywhere else.
|
|
|
117 |
beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
|
118 |
|
119 |
# remove downbeats from beats
|
120 |
+
beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
|
121 |
+
beats_z = beats_z.tolist()
|
122 |
+
downbeats_z = downbeats_z.tolist()
|
123 |
|
124 |
# make the mask
|
125 |
seq_len = self.s2t(signal.duration)
|
|
|
145 |
|
146 |
if mask_upbeats:
|
147 |
for beat_idx in beats_z:
|
148 |
+
_slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
|
149 |
+
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
150 |
+
_m = torch.ones(num_steps, device=self.device)
|
151 |
+
_m = torch.nn.functional.dropout(_m, p=dropout)
|
152 |
+
|
153 |
+
mask[_slice[0]:_slice[1]] = _m
|
154 |
|
155 |
if mask_downbeats:
|
156 |
for downbeat_idx in downbeats_z:
|
157 |
+
_slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
|
158 |
+
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
159 |
+
_m = torch.ones(num_steps, device=self.device)
|
160 |
+
_m = torch.nn.functional.dropout(_m, p=dropout)
|
161 |
+
|
162 |
+
mask[_slice[0]:_slice[1]] = _m
|
163 |
|
164 |
if invert:
|
165 |
mask = 1 - mask
|
166 |
|
167 |
+
return mask[None, None, :].bool().long()
|
168 |
|
169 |
def coarse_to_fine(
|
170 |
self,
|
|
|
310 |
debug=False,
|
311 |
swap_prefix_suffix=False,
|
312 |
ext_mask=None,
|
313 |
+
verbose=False,
|
314 |
**kwargs
|
315 |
):
|
316 |
z = self.encode(signal)
|
|
|
337 |
|
338 |
_cz = cz.clone()
|
339 |
cz_mask = None
|
340 |
+
range_fn = tqdm.trange if verbose else range
|
341 |
+
for _ in range_fn(num_vamps):
|
342 |
# add noise
|
343 |
cz_masked, cz_mask = self.coarse.add_noise(
|
344 |
_cz, r=1.0-intensity,
|
|
|
447 |
def variation(
|
448 |
self,
|
449 |
signal: AudioSignal,
|
|
|
450 |
verbose: bool = False,
|
451 |
+
beat_mask: bool = False,
|
452 |
+
beat_mask_kwargs: dict = {},
|
453 |
**kwargs
|
454 |
):
|
455 |
signal = signal.clone()
|
|
|
462 |
math.ceil(signal.duration / self.coarse.chunk_size_s)
|
463 |
* self.coarse.chunk_size_s
|
464 |
)
|
465 |
+
# eventually we DO want overlap, but we want overlap-replace not
|
466 |
+
# overlap-add
|
467 |
+
overlap_hop_ratio = 1.0
|
468 |
hop_duration = self.coarse.chunk_size_s * overlap_hop_ratio
|
469 |
original_length = signal.length
|
470 |
|
|
|
483 |
signal.samples[i,...], signal.sample_rate
|
484 |
)
|
485 |
sig.to(self.device)
|
486 |
+
|
487 |
+
if beat_mask:
|
488 |
+
ext_mask = self.make_beat_mask(sig, **beat_mask_kwargs)
|
489 |
+
else:
|
490 |
+
ext_mask = None
|
491 |
+
|
492 |
out_z = self.coarse_vamp_v2(
|
493 |
sig,
|
494 |
num_vamps=1,
|
495 |
swap_prefix_suffix=False,
|
496 |
+
ext_mask=ext_mask,
|
497 |
+
verbose=verbose,
|
498 |
**kwargs
|
499 |
)
|
500 |
if self.c2f is not None:
|
vampnet/modules/base.py
CHANGED
@@ -103,8 +103,7 @@ class VampBase(at.ml.BaseModel):
|
|
103 |
# add the external mask if we were given one
|
104 |
if ext_mask is not None:
|
105 |
assert ext_mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
|
106 |
-
|
107 |
-
mask = (mask + ext_mask).bool().long()
|
108 |
|
109 |
x = x * (1 - mask) + random_x * mask
|
110 |
return x, mask
|
|
|
103 |
# add the external mask if we were given one
|
104 |
if ext_mask is not None:
|
105 |
assert ext_mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
|
106 |
+
mask = (mask * ext_mask).bool().long()
|
|
|
107 |
|
108 |
x = x * (1 - mask) + random_x * mask
|
109 |
return x, mask
|