| | """Beat/kick detection using madmom's RNN beat tracker.""" |
| |
|
| | import json |
| | import subprocess |
| | import tempfile |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import numpy as np |
| | from madmom.features.beats import DBNBeatTrackingProcessor, RNNBeatProcessor |
| |
|
| | |
| | HIGHPASS_CUTOFF = 50 |
| | LOWPASS_CUTOFF = 500 |
| |
|
| |
|
def _bandpass_filter(input_path: Path) -> Path:
    """Band-pass filter the audio to isolate kick-drum transients.

    Applies an ffmpeg highpass/lowpass chain using HIGHPASS_CUTOFF and
    LOWPASS_CUTOFF (currently 50-500 Hz) — the previous docstring said
    50-200 Hz, which did not match the constants.

    Args:
        input_path: Path to the source WAV file.

    Returns:
        Path to a temporary filtered WAV file. The CALLER is responsible
        for deleting it when done.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits non-zero.
        FileNotFoundError: If the ffmpeg binary is not on PATH.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    out_path = Path(tmp.name)
    try:
        subprocess.run([
            "ffmpeg", "-y",
            "-i", str(input_path),
            "-af", f"highpass=f={HIGHPASS_CUTOFF},lowpass=f={LOWPASS_CUTOFF}",
            str(out_path),
        ], check=True, capture_output=True)
    except Exception:
        # Don't leak the temp file if ffmpeg fails or is missing.
        out_path.unlink(missing_ok=True)
        raise
    return out_path
| |
|
| |
|
def detect_beats(
    drum_stem_path: str | Path,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
    transition_lambda: float = 100,
    fps: int = 1000,
) -> np.ndarray:
    """Detect beat timestamps from a drum stem using madmom.

    Uses an ensemble of bidirectional LSTMs to produce a beat activation
    function, then a Dynamic Bayesian Network to decode beat positions.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        min_bpm: Minimum expected tempo. Narrow this if you know the song's
            approximate BPM for better accuracy.
        max_bpm: Maximum expected tempo.
        transition_lambda: Tempo smoothness — higher values penalise tempo
            changes more (100 = very steady, good for most pop/rock).
        fps: Frames per second for the DBN decoder. The RNN outputs at 100fps;
            higher values interpolate for finer timestamp resolution (1ms at 1000fps).

    Returns:
        1D numpy array of beat timestamps in seconds, sorted chronologically.
    """
    drum_stem_path = Path(drum_stem_path)

    # Pre-filter to emphasise kick transients before the RNN sees the audio.
    filtered_path = _bandpass_filter(drum_stem_path)
    try:
        act_proc = RNNBeatProcessor()
        activations = act_proc(str(filtered_path))
    finally:
        # Previously the temp file leaked if the RNN raised; always clean up.
        filtered_path.unlink(missing_ok=True)

    # madmom's RNN emits activations at a fixed 100 fps; upsample them so the
    # DBN can decode at the requested (finer) frame rate.
    if fps != 100:
        from scipy.interpolate import interp1d
        n_frames = len(activations)
        t_orig = np.linspace(0, n_frames / 100, n_frames, endpoint=False)
        n_new = int(n_frames * fps / 100)
        t_new = np.linspace(0, n_frames / 100, n_new, endpoint=False)
        activations = interp1d(t_orig, activations, kind="cubic", fill_value="extrapolate")(t_new)
        # Cubic interpolation can overshoot below zero; activations must be >= 0.
        activations = np.clip(activations, 0, None)

    beat_proc = DBNBeatTrackingProcessor(
        min_bpm=min_bpm,
        max_bpm=max_bpm,
        transition_lambda=transition_lambda,
        fps=fps,
        correct=False,
    )
    beats = beat_proc(activations)

    return beats
| |
|
| |
|
def detect_drop(
    audio_path: str | Path,
    beat_times: np.ndarray,
    window_sec: float = 0.5,
) -> float:
    """Find the beat where the biggest energy jump occurs (the drop).

    Computes RMS energy in a window around each beat and returns the beat
    with the largest increase compared to the previous beat.

    Args:
        audio_path: Path to the full mix audio file.
        beat_times: Array of beat timestamps in seconds. Must contain at
            least two beats.
        window_sec: Duration of the analysis window around each beat.

    Returns:
        Timestamp (seconds) of the detected drop beat.

    Raises:
        ValueError: If fewer than two beats are supplied (no energy
            difference can be computed).
    """
    import librosa

    # Guard: np.diff on < 2 beats yields an empty array and argmax would
    # crash with an opaque numpy error.
    if len(beat_times) < 2:
        raise ValueError("detect_drop requires at least two beat timestamps")

    y, sr = librosa.load(str(audio_path), sr=None, mono=True)
    half_win = int(window_sec / 2 * sr)

    rms_values = []
    for t in beat_times:
        center = int(t * sr)
        start = max(0, center - half_win)
        end = min(len(y), center + half_win)
        segment = y[start:end]
        # Beats near the file edges may have a truncated (or empty) window.
        rms = np.sqrt(np.mean(segment ** 2)) if len(segment) > 0 else 0.0
        rms_values.append(rms)

    rms_values = np.array(rms_values)

    # diffs[i] is the energy change going INTO beat i+1, so the drop beat
    # index is argmax + 1.
    diffs = np.diff(rms_values)
    drop_idx = int(np.argmax(diffs)) + 1
    drop_time = float(beat_times[drop_idx])

    print(f"  Drop detected at beat {drop_idx + 1}: {drop_time:.3f}s "
          f"(energy jump: {diffs[drop_idx - 1]:.4f})")
    return drop_time
| |
|
| |
|
def select_beats(
    beats: np.ndarray,
    max_duration: float = 15.0,
    min_interval: float = 0.3,
) -> np.ndarray:
    """Select a subset of beats for video generation.

    Keeps only beats within ``max_duration`` seconds, then greedily walks
    through them, keeping a beat only if it is at least ``min_interval``
    seconds after the last one kept (caps the number of generated frames).

    Args:
        beats: Array of beat timestamps in seconds.
        max_duration: Maximum video duration in seconds.
        min_interval: Minimum time between selected beats in seconds.
            Beats closer together than this are skipped.

    Returns:
        Filtered array of beat timestamps.
    """
    if beats.size == 0:
        return beats

    within_limit = beats[beats <= max_duration]
    if within_limit.size == 0:
        return within_limit

    # Greedy selection: always keep the first beat, then enforce spacing.
    kept = [within_limit[0]]
    last_kept = within_limit[0]
    for ts in within_limit[1:]:
        if ts - last_kept >= min_interval:
            kept.append(ts)
            last_kept = ts

    return np.array(kept)
| |
|
| |
|
| | def save_beats( |
| | beats: np.ndarray, |
| | output_path: str | Path, |
| | ) -> Path: |
| | """Save beat timestamps to a JSON file. |
| | |
| | Format matches the project convention (same style as lyrics.json): |
| | a list of objects with beat index and timestamp. |
| | |
| | Args: |
| | beats: Array of beat timestamps in seconds. |
| | output_path: Path to save the JSON file. |
| | |
| | Returns: |
| | Path to the saved JSON file. |
| | """ |
| | output_path = Path(output_path) |
| | output_path.parent.mkdir(parents=True, exist_ok=True) |
| |
|
| | data = [ |
| | {"beat": i + 1, "time": round(float(t), 3)} |
| | for i, t in enumerate(beats) |
| | ] |
| |
|
| | with open(output_path, "w") as f: |
| | json.dump(data, f, indent=2) |
| |
|
| | return output_path |
| |
|
| |
|
def run(
    drum_stem_path: str | Path,
    output_dir: Optional[str | Path] = None,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
) -> dict:
    """Full beat detection pipeline: detect, select, and save.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        output_dir: Directory to save beats.json. Defaults to the
            parent of the drum stem's parent (e.g. data/Gone/ if
            stem is at data/Gone/stems/drums.wav).
        min_bpm: Minimum expected tempo.
        max_bpm: Maximum expected tempo.

    Returns:
        Dict with 'all_beats', 'selected_beats', 'beats_path', and
        'drop_time' (None when no full-mix audio was found or there were
        too few beats to analyse).
    """
    stem_path = Path(drum_stem_path)

    target_dir = Path(output_dir) if output_dir is not None else stem_path.parent.parent

    all_beats = detect_beats(stem_path, min_bpm=min_bpm, max_bpm=max_bpm)
    selected = select_beats(all_beats)

    # Locate the full mix for drop detection: run_* dirs live inside the
    # song dir, otherwise the output dir IS the song dir.
    song_dir = target_dir.parent if target_dir.name.startswith("run_") else target_dir
    audio_path = None
    for ext in [".wav", ".mp3", ".flac", ".m4a"]:
        matches = list(song_dir.glob(f"*{ext}"))
        if matches:
            audio_path = matches[0]
            break

    drop_time = None
    if audio_path and len(all_beats) > 2:
        drop_time = detect_drop(audio_path, all_beats)

    beats_path = save_beats(all_beats, target_dir / "beats.json")

    if drop_time is not None:
        with open(target_dir / "drop.json", "w") as f:
            json.dump({"drop_time": round(drop_time, 3)}, f, indent=2)

    return {
        "all_beats": all_beats,
        "selected_beats": selected,
        "beats_path": beats_path,
        "drop_time": drop_time,
    }
| |
|
| |
|
if __name__ == "__main__":
    import sys

    # Require the drum stem path as the single CLI argument.
    if len(sys.argv) < 2:
        print("Usage: python -m src.beat_detector <drum_stem.wav>")
        sys.exit(1)

    outcome = run(sys.argv[1])
    detected = outcome["all_beats"]
    chosen = outcome["selected_beats"]

    print(f"Detected {len(detected)} beats (saved to {outcome['beats_path']})")
    print(f"Selected {len(chosen)} beats (max 15s, min 0.3s apart):")
    for idx, ts in enumerate(chosen):
        print(f"  Beat {idx + 1}: {ts:.3f}s")
|