Hugo Flores Garcia committed
Commit 91f8638
1 Parent(s): a63cce0

towards beat tracking in the interface


beats (squash me)

beat tracker [squash]

Files changed (3)
  1. vampnet/beats.py +252 -0
  2. vampnet/interface.py +65 -4
  3. vampnet/modules/base.py +7 -0
vampnet/beats.py ADDED
@@ -0,0 +1,252 @@
+import json
+import logging
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import librosa
+import numpy as np
+from audiotools import AudioSignal
+
+
+logging.basicConfig(level=logging.INFO)
+
+###################
+# beat sync utils #
+###################
+
+AGGREGATOR_REGISTRY = {
+    "mean": np.mean,
+    "median": np.median,
+    "max": np.max,
+    "min": np.min,
+}
+
+
+def list_aggregators() -> list:
+    return list(AGGREGATOR_REGISTRY.keys())
+
+
+@dataclass
+class TimeSegment:
+    start: float
+    end: float
+
+    @property
+    def duration(self):
+        return self.end - self.start
+
+    def __str__(self) -> str:
+        return f"{self.start} - {self.end}"
+
+    def find_overlapping_segment(
+        self, segments: List["TimeSegment"]
+    ) -> Union["TimeSegment", None]:
+        """Find the first segment that fully contains this segment, or None if none does."""
+        for s in segments:
+            if s.start <= self.start and s.end >= self.end:
+                return s
+        return None
+
+
+def mkdir(path: Union[Path, str]) -> Path:
+    p = Path(path)
+    p.mkdir(parents=True, exist_ok=True)
+    return p
+
+
+
+###################
+#    beat data    #
+###################
+@dataclass
+class BeatSegment(TimeSegment):
+    downbeat: bool = False  # whether a downbeat falls on the segment's start time
+
+
+class Beats:
+    def __init__(self, beat_times, downbeat_times):
+        if isinstance(beat_times, np.ndarray):
+            beat_times = beat_times.tolist()
+        if isinstance(downbeat_times, np.ndarray):
+            downbeat_times = downbeat_times.tolist()
+        self._beat_times = beat_times
+        self._downbeat_times = downbeat_times
+        self._use_downbeats = False
+
+    def use_downbeats(self, use_downbeats: bool = True):
+        """use downbeats instead of beats when calling beat_times"""
+        self._use_downbeats = use_downbeats
+
+    def beat_segments(self, signal: AudioSignal) -> List[BeatSegment]:
+        """
+        Segments a song into time segments corresponding to beats.
+        The first segment starts at 0 and ends at the first beat time.
+        The last segment starts at the last beat time and ends at the end of the song.
+        """
+        beat_times = self._beat_times.copy()
+        downbeat_times = self._downbeat_times
+        beat_times.insert(0, 0)
+        beat_times.append(signal.signal_duration)
+
+        downbeat_ids = np.intersect1d(beat_times, downbeat_times, return_indices=True)[
+            1
+        ]
+        is_downbeat = [
+            i in downbeat_ids for i in range(len(beat_times))
+        ]
+        segments = [
+            BeatSegment(start_time, end_time, downbeat)
+            for start_time, end_time, downbeat in zip(
+                beat_times[:-1], beat_times[1:], is_downbeat
+            )
+        ]
+        return segments
+
+    def get_beats(self) -> np.ndarray:
+        """returns an array of beat times, in seconds.
+        if use_downbeats is set, returns an array of downbeat times instead.
+        """
+        return np.array(
+            self._downbeat_times if self._use_downbeats else self._beat_times
+        )
+
+    @property
+    def beat_times(self) -> np.ndarray:
+        """return beat times"""
+        return np.array(self._beat_times)
+
+    @property
+    def downbeat_times(self) -> np.ndarray:
+        """return downbeat times"""
+        return np.array(self._downbeat_times)
+
+    def beat_times_to_feature_frames(
+        self, signal: AudioSignal, features: np.ndarray
+    ) -> np.ndarray:
+        """convert beat times to frames, given an array of time-varying features"""
+        beat_times = self.get_beats()
+        beat_frames = (
+            beat_times * signal.sample_rate / signal.signal_length * features.shape[-1]
+        ).astype(np.int64)
+        return beat_frames
+
+    def sync_features(
+        self, feature_frames: np.ndarray, features: np.ndarray, aggregate="median"
+    ) -> np.ndarray:
+        """sync features to beats"""
+        if aggregate not in AGGREGATOR_REGISTRY:
+            raise ValueError(f"unknown aggregation method {aggregate}")
+
+        return librosa.util.sync(
+            features, feature_frames, aggregate=AGGREGATOR_REGISTRY[aggregate]
+        )
+
+    def to_json(self) -> dict:
+        """return beats and downbeats as a json-serializable dict"""
+        return {
+            "beats": self._beat_times,
+            "downbeats": self._downbeat_times,
+            "use_downbeats": self._use_downbeats,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        """load beats and downbeats from a dict"""
+        inst = cls(data["beats"], data["downbeats"])
+        inst.use_downbeats(data["use_downbeats"])
+        return inst
+
+    def save(self, output_dir: Path):
+        """save beats and downbeats to json"""
+        output_dir = mkdir(output_dir)
+        with open(output_dir / "beats.json", "w") as f:
+            json.dump(self.to_json(), f)
+
+    @classmethod
+    def load(cls, input_dir: Path):
+        """load beats and downbeats from json"""
+        beats_file = Path(input_dir) / "beats.json"
+        with open(beats_file, "r") as f:
+            data = json.load(f)
+        return cls.from_dict(data)
+
+
+###################
+#  beat tracking  #
+###################
+
+
+class BeatTracker:
+    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
+        """extract beats from an audio signal"""
+        raise NotImplementedError
+
+    def __call__(self, signal: AudioSignal) -> Beats:
+        """extract beats from an audio signal.
+        NOTE: if the first beat (and/or downbeat) is detected within the first 100ms of the audio,
+        it is discarded. This is to avoid empty bins with no beat-synced features in the first beat.
+        Args:
+            signal (AudioSignal): signal to beat track
+        Returns:
+            Beats: a Beats object holding the beat and downbeat times
+        """
+        beats, downbeats = self.extract_beats(signal)
+        return Beats(beats, downbeats)
+
+
+class WaveBeat(BeatTracker):
+    def __init__(self, ckpt_dir: str = "checkpoints/wavebeat", device: str = "cpu"):
+        from wavebeat.dstcn import dsTCNModel
+
+        ckpts = list(Path(ckpt_dir).glob("*.ckpt"))
+        assert len(ckpts) > 0, f"no checkpoints found for wavebeat in {ckpt_dir}"
+
+        model = dsTCNModel.load_from_checkpoint(ckpts[-1])
+        model.eval()
+
+        self.device = device
+        self.model = model
+
+    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
+        """returns beat and downbeat times, in seconds"""
+        # extract beats
+        beats, downbeats = self.model.predict_beats_from_array(
+            audio=signal.audio_data.squeeze(0),
+            sr=signal.sample_rate,
+            use_gpu=self.device != "cpu",
+        )
+
+        return beats, downbeats
+
+
+class MadmomBeats(BeatTracker):
+    def __init__(self):
+        raise NotImplementedError
+
+    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
+        """returns beat and downbeat times, in seconds"""
+        pass
+
+
+BEAT_TRACKER_REGISTRY = {
+    "wavebeat": WaveBeat,
+    "madmom": MadmomBeats,
+}
+
+
+def list_beat_trackers() -> list:
+    return list(BEAT_TRACKER_REGISTRY.keys())
+
+
+def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker:
+    if beat_tracker not in BEAT_TRACKER_REGISTRY:
+        raise ValueError(
+            f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
+        )
+
+    return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs)
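
For orientation, a minimal usage sketch of the new module (not part of this commit); it assumes a wavebeat checkpoint exists under checkpoints/wavebeat and that audio.wav is a readable audio file:

```python
# Hedged usage sketch for vampnet/beats.py (not part of the commit).
# Assumes a wavebeat checkpoint under checkpoints/wavebeat and a local audio.wav.
from pathlib import Path

from audiotools import AudioSignal
from vampnet.beats import Beats, list_beat_trackers, load_beat_tracker

print(list_beat_trackers())  # ['wavebeat', 'madmom']

signal = AudioSignal("audio.wav")
tracker = load_beat_tracker("wavebeat", ckpt_dir="checkpoints/wavebeat")

beats = tracker(signal)        # BeatTracker.__call__ wraps extract_beats in a Beats object
print(beats.beat_times)        # beat times, in seconds
print(beats.downbeat_times)    # downbeat times, in seconds

beats.save(Path("out"))                # writes out/beats.json
restored = Beats.load(Path("out"))     # round-trips via Beats.load
```

Beats.beat_segments and Beats.sync_features follow the same pattern for beat-synchronous feature aggregation.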
vampnet/interface.py CHANGED
@@ -7,10 +7,10 @@ from audiotools import AudioSignal
 import tqdm
 
 from .modules.transformer import VampNet
+from .beats import WaveBeat
 from lac.model.lac import LAC
 
 
-
 def signal_concat(
     audio_signals: list,
 ):
@@ -83,12 +83,72 @@ class Interface:
             .ensure_max_of_audio(1.0)
         )
         return signal
+
     @torch.inference_mode()
     def encode(self, signal: AudioSignal):
         signal = self.preprocess(signal).to(self.device)
         z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
         return z
 
+    def make_beat_mask(self,
+        signal: AudioSignal,
+        before_beat_s: float = 0.1,
+        after_beat_s: float = 0.1,
+        mask_downbeats: bool = True,
+        mask_upbeats: bool = True,
+        downbeat_downsample_factor: int = None,
+        beat_downsample_factor: int = None,
+        invert: bool = False,
+    ):
+        """make a beat-synced mask, i.e. a mask that
+        places 1s at and around each beat, and 0s everywhere else.
+        """
+        assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
+
+        # get the beat times
+        beats, downbeats = self.beat_tracker.extract_beats(signal)
+
+        # get the beat indices in z
+        beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
+
+        # remove downbeats from beats
+        beats_z = beats_z[~torch.isin(beats_z, downbeats_z)]
+
+        # make the mask
+        seq_len = self.s2t(signal.duration)
+        mask = torch.zeros(seq_len, device=self.device)
+
+        mask_b4 = self.s2t(before_beat_s)
+        mask_after = self.s2t(after_beat_s)
+
+        if beat_downsample_factor is not None:
+            if beat_downsample_factor < 1:
+                raise ValueError("beat_downsample_factor must be >= 1 or None")
+        else:
+            beat_downsample_factor = 1
+
+        if downbeat_downsample_factor is not None:
+            if downbeat_downsample_factor < 1:
+                raise ValueError("downbeat_downsample_factor must be >= 1 or None")
+        else:
+            downbeat_downsample_factor = 1
+
+        beats_z = beats_z[::beat_downsample_factor]
+        downbeats_z = downbeats_z[::downbeat_downsample_factor]
+
+        if mask_upbeats:
+            for beat_idx in beats_z:
+                mask[max(beat_idx - mask_b4, 0):beat_idx + mask_after] = 1
+
+        if mask_downbeats:
+            for downbeat_idx in downbeats_z:
+                mask[max(downbeat_idx - mask_b4, 0):downbeat_idx + mask_after] = 1
+
+        if invert:
+            mask = 1 - mask
+
+        return mask
+
     def coarse_to_fine(
         self,
         coarse_z: torch.Tensor,
@@ -231,7 +291,8 @@
         downsample_factor: int = None,
         intensity: float = 1.0,
         debug=False,
-        swap_prefix_suffix=False,
+        swap_prefix_suffix=False,
+        ext_mask=None,
         **kwargs
     ):
         z = self.encode(signal)
@@ -265,7 +326,8 @@
             n_prefix=n_prefix,
             n_suffix=n_suffix,
             downsample_factor=downsample_factor,
-            mask=cz_mask
+            mask=cz_mask,
+            ext_mask=ext_mask
         )
         if debug:
             print("tokens to infer")
@@ -415,7 +477,6 @@
         output.truncate_samples(original_length)
         return output
 
-
    # create a loop of a single region with variations
    # TODO: this would work nicer if we could trim at the beat
    # otherwise the model has to awkwardly fill up space that won't match
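
One wrinkle worth noting: make_beat_mask returns a 1-D mask over the token sequence, while the ext_mask threaded through to VampBase (see vampnet/modules/base.py below) is asserted to be (batch, n_codebooks, seq). A hedged sketch of the glue, assuming `interface` is a constructed Interface with a beat_tracker attached (variable names here are illustrative, not from the commit):

```python
# Hedged sketch (not in the commit): broadcast the 1-D beat mask up to the
# (batch, n_codebooks, seq) shape that VampBase expects for ext_mask.
mask_1d = interface.make_beat_mask(
    signal,
    before_beat_s=0.05,
    after_beat_s=0.05,
    invert=True,   # 1s *between* beats, so the regions around beats are kept
)                  # shape: (seq,)

z = interface.encode(signal)                         # tokens: (batch, n_codebooks, seq)
ext_mask = mask_1d[None, None, :].expand(*z.shape).long()

# ext_mask can now be passed through the new ext_mask kwarg added above,
# which forwards it into VampBase's mask construction.
```

Since VampBase keeps positions where mask == 0 and resamples positions where mask == 1, inverting the beat mask this way would preserve the audio around each beat and regenerate the material in between.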
vampnet/modules/base.py CHANGED
@@ -42,6 +42,7 @@ class VampBase(at.ml.BaseModel):
         r: torch.Tensor,
         random_x: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
+        ext_mask: Optional[torch.Tensor] = None,
         n_prefix: Optional[torch.Tensor] = None,
         n_suffix: Optional[torch.Tensor] = None,
         downsample_factor: Optional[int] = None,
@@ -99,6 +100,12 @@ class VampBase(at.ml.BaseModel):
         else:
             raise ValueError(f"invalid noise mode {self.noise_mode}")
 
+        # add the external mask if we were given one
+        if ext_mask is not None:
+            assert ext_mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
+            assert ext_mask.shape == x.shape, "mask must be same shape as x"
+            mask = (mask + ext_mask).bool().long()
+
         x = x * (1 - mask) + random_x * mask
         return x, mask
 
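
The combination rule in the second hunk is an elementwise OR of the two masks, kept in {0, 1} long form so the downstream `x * (1 - mask) + random_x * mask` arithmetic still works. A toy check (shapes and values invented for illustration):

```python
import torch

# toy shapes: (batch=1, n_codebooks=2, seq=4); values are illustrative only
mask     = torch.tensor([[[0, 1, 0, 0], [0, 0, 1, 0]]])
ext_mask = torch.tensor([[[1, 1, 0, 0], [0, 0, 0, 1]]])

combined = (mask + ext_mask).bool().long()  # any overlap (sum == 2) collapses back to 1
print(combined)
# tensor([[[1, 1, 0, 0],
#          [0, 0, 1, 1]]])
```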
111