Hugo Flores committed
Commit 5a0a80a
1 Parent(s): 91f8638

beat tracker bugfixes

requirements.txt CHANGED
@@ -2,7 +2,8 @@ argbind>=0.3.1
 pytorch-ignite
 rich
 audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@hf/backup-info
-lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@main
+lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@hf/vampnet-temp
+wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git
 tqdm
 tensorboard
 google-cloud-logging==2.2.0
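
Note: the lac pin moves from main to the hf/vampnet-temp branch, and wavebeat is a new dependency. As a quick smoke test (not part of the commit; the importable module names are inferred from the package names and the beats.py diff below):

# hypothetical smoke test after reinstalling requirements.txt
import lac                              # codec package, now pinned to hf/vampnet-temp (module name inferred)
from wavebeat.dstcn import dsTCNModel   # new beat-tracking dependency
print("new dependencies import cleanly")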
vampnet/beats.py CHANGED
@@ -200,13 +200,10 @@ class BeatTracker:
 
 
 class WaveBeat(BeatTracker):
-    def __init__(self, ckpt_dir: str = "checkpoints/wavebeat", device: str = "cpu"):
+    def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
         from wavebeat.dstcn import dsTCNModel
 
-        ckpts = list((ckpt_dir).glob("*.ckpt"))
-        assert len(ckpts) > 0, f"no checkpoints found for wavebeat in {ckpt_dir}"
-
-        model = dsTCNModel.load_from_checkpoint(ckpts[-1])
+        model = dsTCNModel.load_from_checkpoint(ckpt_path)
         model.eval()
 
         self.device = device
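
A minimal usage sketch of the updated constructor (not from the commit): the argument is now handed straight to dsTCNModel.load_from_checkpoint, which expects a checkpoint file rather than the directory the old glob-based loader searched, so the .ckpt filename below is an assumption.

# hypothetical usage; "best.ckpt" is an assumed filename
from vampnet.beats import WaveBeat

tracker = WaveBeat(ckpt_path="checkpoints/wavebeat/best.ckpt", device="cpu")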
vampnet/interface.py CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
 import math
 
 import torch
+import numpy as np
 from audiotools import AudioSignal
 import tqdm
 
@@ -50,7 +51,10 @@ class Interface:
 
     def s2t(self, seconds: float):
         """seconds to tokens"""
-        return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
+        if isinstance(seconds, np.ndarray):
+            return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
+        else:
+            return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
 
     def s2t2s(self, seconds: float):
         """seconds to tokens to seconds"""
@@ -94,11 +98,12 @@ class Interface:
         signal: AudioSignal,
         before_beat_s: float = 0.1,
         after_beat_s: float = 0.1,
-        mask_downbeats: float = 0.1,
-        mask_upbeats: float = 0.1,
+        mask_downbeats: bool = True,
+        mask_upbeats: bool = True,
         downbeat_downsample_factor: int = None,
         beat_downsample_factor: int = None,
-        invert: bool = False,
+        dropout: float = 0.7,
+        invert: bool = True,
     ):
         """make a beat synced mask. that is, make a mask that
         places 1s at and around the beat, and 0s everywhere else.
@@ -112,7 +117,9 @@ class Interface:
         beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
 
         # remove downbeats from beats
-        beats_z = beats_z[~torch.isin(beats_z, downbeats_z)]
+        beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
+        beats_z = beats_z.tolist()
+        downbeats_z = downbeats_z.tolist()
 
         # make the mask
         seq_len = self.s2t(signal.duration)
@@ -138,16 +145,26 @@ class Interface:
 
         if mask_upbeats:
             for beat_idx in beats_z:
-                mask[beat_idx - mask_b4:beat_idx + mask_after] = 1
+                _slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
+                num_steps = mask[_slice[0]:_slice[1]].shape[0]
+                _m = torch.ones(num_steps, device=self.device)
+                _m = torch.nn.functional.dropout(_m, p=dropout)
+
+                mask[_slice[0]:_slice[1]] = _m
 
         if mask_downbeats:
             for downbeat_idx in downbeats_z:
-                mask[downbeat_idx - mask_b4:downbeat_idx + mask_after] = 1
+                _slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
+                num_steps = mask[_slice[0]:_slice[1]].shape[0]
+                _m = torch.ones(num_steps, device=self.device)
+                _m = torch.nn.functional.dropout(_m, p=dropout)
+
+                mask[_slice[0]:_slice[1]] = _m
 
         if invert:
             mask = 1 - mask
 
-        return mask
+        return mask[None, None, :].bool().long()
 
     def coarse_to_fine(
         self,
@@ -293,6 +310,7 @@ class Interface:
         debug=False,
         swap_prefix_suffix=False,
         ext_mask=None,
+        verbose=False,
         **kwargs
     ):
         z = self.encode(signal)
@@ -319,7 +337,8 @@ class Interface:
 
         _cz = cz.clone()
         cz_mask = None
-        for _ in range(num_vamps):
+        range_fn = tqdm.trange if verbose else range
+        for _ in range_fn(num_vamps):
             # add noise
             cz_masked, cz_mask = self.coarse.add_noise(
                 _cz, r=1.0-intensity,
@@ -428,8 +447,9 @@ class Interface:
     def variation(
         self,
         signal: AudioSignal,
-        overlap_hop_ratio: float = 1.0, # TODO: should this be fixed to 1.0? or should we overlap and replace instead of overlap add
         verbose: bool = False,
+        beat_mask: bool = False,
+        beat_mask_kwargs: dict = {},
        **kwargs
     ):
         signal = signal.clone()
@@ -442,6 +462,9 @@ class Interface:
             math.ceil(signal.duration / self.coarse.chunk_size_s)
             * self.coarse.chunk_size_s
         )
+        # eventually we DO want overlap, but we want overlap-replace not
+        # overlap-add
+        overlap_hop_ratio = 1.0
         hop_duration = self.coarse.chunk_size_s * overlap_hop_ratio
         original_length = signal.length
 
@@ -460,10 +483,18 @@ class Interface:
                 signal.samples[i,...], signal.sample_rate
             )
            sig.to(self.device)
+
+            if beat_mask:
+                ext_mask = self.make_beat_mask(sig, **beat_mask_kwargs)
+            else:
+                ext_mask = None
+
             out_z = self.coarse_vamp_v2(
                 sig,
                 num_vamps=1,
                 swap_prefix_suffix=False,
+                ext_mask=ext_mask,
+                verbose=verbose,
                 **kwargs
             )
             if self.c2f is not None:
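
A hedged sketch of how the new beat-synced masking path could be driven through variation(); `interface` (an Interface instance) and the input path are assumptions, and the beat_mask_kwargs keys mirror the make_beat_mask signature above.

# hypothetical call sketch
from audiotools import AudioSignal

sig = AudioSignal("input.wav")      # placeholder path
out = interface.variation(
    sig,
    verbose=True,                   # enables the tqdm.trange added above
    beat_mask=True,                 # builds ext_mask via make_beat_mask
    beat_mask_kwargs=dict(
        before_beat_s=0.1,
        after_beat_s=0.1,
        dropout=0.7,                # randomly zero ~70% of the steps around each beat
        invert=True,                # beat regions become 0s, so add_noise leaves them intact
    ),
)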
vampnet/modules/base.py CHANGED
@@ -103,8 +103,7 @@ class VampBase(at.ml.BaseModel):
         # add the external mask if we were given one
         if ext_mask is not None:
             assert ext_mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
-            assert ext_mask.shape == x.shape, "mask must be same shape as x"
-            mask = (mask + ext_mask).bool().long()
+            mask = (mask * ext_mask).bool().long()
 
         x = x * (1 - mask) + random_x * mask
         return x, mask
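
A toy example (made-up shapes) of what the changed combination rule does in VampBase's masking step, presumably inside add_noise: the external mask now intersects the sampled noise mask via multiplication instead of being unioned in via addition, and removing the shape assert lets a broadcastable ext_mask through.

# toy comparison of the old and new combination rules; shapes are (batch, n_codebooks, seq)
import torch

mask = torch.tensor([[[1, 1, 0, 1]]])       # sampled noise mask
ext_mask = torch.tensor([[[1, 0, 1, 1]]])   # e.g. an inverted beat mask

new = (mask * ext_mask).bool().long()       # tensor([[[1, 0, 0, 1]]]) -> noise only where both masks allow it
old = (mask + ext_mask).bool().long()       # tensor([[[1, 1, 1, 1]]]) -> old behavior noised the union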