| import torch |
| import numpy as np |
| from tqdm import tqdm |
| from scipy.signal import find_peaks |
| import argparse |
| import os |
|
|
| from .model import ResNet |
| from ..baseline1.utils import MultiViewSpectrogram |
| from ..data.load import ds |
| from ..data.eval import evaluate_all, format_results |
|
|
|
|
def get_activation_function(model, waveform, device):
    """
    Compute a frame-level activation (probability) curve over time.

    Args:
        model: Trained classifier mapping a 101-frame spectrogram window to a
            per-window activation (assumed already in eval mode — set by caller).
        waveform: 1-D audio tensor; a batch dim is added here.
        device: Torch device string for the spectrogram processor and input.

    Returns:
        np.ndarray: Flat 1-D array with one activation value per spectrogram frame.
    """
    processor = MultiViewSpectrogram().to(device)
    waveform = waveform.unsqueeze(0).to(device)

    # BUG FIX: the original only wrapped the spectrogram computation in
    # no_grad; the model forward passes below ran with autograd enabled,
    # building needless graphs and inflating memory during inference.
    # The entire pipeline is now gradient-free.
    with torch.no_grad():
        spec = processor(waveform)

        # Standardize each channel over its (freq, time) plane.
        mean = spec.mean(dim=(2, 3), keepdim=True)
        std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
        spec = (spec - mean) / std

        # Pad 50 frames on each side so every original frame gets a centered
        # 101-frame window, then slide a stride-1 window along the time axis.
        spec = torch.nn.functional.pad(spec, (50, 50))
        windows = spec.unfold(3, 101, 1)
        # (windows, channels, freq, frames): one model input per time step.
        windows = windows.permute(3, 0, 1, 2, 4).squeeze(1)

        # Batched inference to bound peak memory on long tracks.
        activations = []
        batch_size = 128
        for i in range(0, len(windows), batch_size):
            batch = windows[i : i + batch_size]
            out = model(batch)
            activations.append(out.cpu().numpy())

    return np.concatenate(activations).flatten()
|
|
|
|
def pick_peaks(activations, hop_length=160, sr=16000):
    """
    Convert a frame-level activation curve into beat timestamps.

    The curve is smoothed with a normalized 5-tap Hamming window, local
    maxima above 0.5 (at least 5 frames apart) are located, and each peak
    index is mapped to seconds via hop_length / sr.

    Returns:
        list[float]: Peak times in seconds.
    """
    kernel = np.hamming(5)
    kernel = kernel / kernel.sum()  # normalize so smoothing preserves scale
    smoothed = np.convolve(activations, kernel, mode="same")

    peak_indices, _ = find_peaks(smoothed, height=0.5, distance=5)

    return (peak_indices * hop_length / sr).tolist()
|
|
|
|
def visualize_track(
    audio: np.ndarray,
    sr: int,
    pred_beats: list[float],
    pred_downbeats: list[float],
    gt_beats: list[float],
    gt_downbeats: list[float],
    output_dir: str,
    track_idx: int,
    time_range: tuple[float, float] | None = None,
):
    """
    Render one track's waveform with predicted vs. ground-truth beat markers
    and save the figure as track_XXX.png under output_dir.
    """
    # Imported lazily so plotting deps are only needed when visualizing.
    from ..data.viz import plot_waveform_with_beats, save_figure

    os.makedirs(output_dir, exist_ok=True)

    figure = plot_waveform_with_beats(
        audio,
        sr,
        pred_beats,
        gt_beats,
        pred_downbeats,
        gt_downbeats,
        title=f"Track {track_idx}: Beat Comparison",
        time_range=time_range,
    )
    out_path = os.path.join(output_dir, f"track_{track_idx:03d}.png")
    save_figure(figure, out_path)
|
|
|
|
def synthesize_audio(
    audio: np.ndarray,
    sr: int,
    pred_beats: list[float],
    pred_downbeats: list[float],
    gt_beats: list[float],
    gt_downbeats: list[float],
    output_dir: str,
    track_idx: int,
    click_volume: float = 0.5,
):
    """
    Mix click tracks into the audio and save three WAVs per track:
    predictions only, ground truth only, and both overlaid.
    """
    # Imported lazily so audio-synthesis deps are only needed when requested.
    from ..data.audio import create_comparison_audio, save_audio

    os.makedirs(output_dir, exist_ok=True)

    audio_pred, audio_gt, audio_both = create_comparison_audio(
        audio,
        pred_beats,
        pred_downbeats,
        gt_beats,
        gt_downbeats,
        sr=sr,
        click_volume=click_volume,
    )

    # Write the three variants with a shared filename scheme.
    for suffix, clip in (("pred", audio_pred), ("gt", audio_gt), ("both", audio_both)):
        save_audio(
            clip, os.path.join(output_dir, f"track_{track_idx:03d}_{suffix}.wav"), sr
        )
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser (flags unchanged from the original script)."""
    parser = argparse.ArgumentParser(
        description="Evaluate beat tracking models with visualization and audio synthesis"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default="outputs/baseline2",
        help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=116,
        help="Number of samples to evaluate",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/eval_baseline2",
        help="Directory to save visualizations and audio",
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Generate visualization plots for each track",
    )
    parser.add_argument(
        "--synthesize",
        action="store_true",
        help="Generate audio files with click tracks",
    )
    parser.add_argument(
        "--viz-tracks",
        type=int,
        default=5,
        help="Number of tracks to visualize/synthesize (default: 5)",
    )
    parser.add_argument(
        "--time-range",
        type=float,
        nargs=2,
        default=None,
        metavar=("START", "END"),
        help="Time range for visualization in seconds (default: full track)",
    )
    parser.add_argument(
        "--click-volume",
        type=float,
        default=0.5,
        help="Volume of click sounds relative to audio (0.0 to 1.0)",
    )
    parser.add_argument(
        "--summary-plot",
        action="store_true",
        help="Generate summary evaluation plot",
    )
    return parser


def _load_model(model_dir: str, label: str, device: str):
    """
    Load a ResNet checkpoint from model_dir if its weights file exists.

    Returns the model in eval mode on `device`, or None if no checkpoint is
    found (a warning is printed in that case).
    """
    if os.path.exists(os.path.join(model_dir, "model.safetensors")):
        model = ResNet.from_pretrained(model_dir).to(device)
        model.eval()
        print(f"Loaded {label} Model from {model_dir}")
        return model
    print(f"Warning: No {label.lower()} model found in {model_dir}")
    return None


def main():
    """Run beat/downbeat evaluation plus optional visualization, audio, and summary plot."""
    args = _build_parser().parse_args()

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Model loading was previously duplicated for beats/downbeats; the
    # has_beats/has_downbeats flags are replaced by `is not None` checks.
    beat_model = _load_model(os.path.join(args.model_dir, "beats"), "Beat", DEVICE)
    downbeat_model = _load_model(
        os.path.join(args.model_dir, "downbeats"), "Downbeat", DEVICE
    )

    if beat_model is None and downbeat_model is None:
        print("No models found. Please run training first.")
        return

    predictions = []
    ground_truths = []
    audio_data = []

    # NOTE(review): this evaluates on ds["train"] — confirm the dataset has
    # no dedicated test split, otherwise this measures training performance.
    test_set = ds["train"].select(range(args.num_samples))

    print("Running evaluation...")
    for i, item in enumerate(tqdm(test_set)):
        waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
        waveform_device = waveform.to(DEVICE)

        pred_entry = {"beats": [], "downbeats": []}

        if beat_model is not None:
            act_b = get_activation_function(beat_model, waveform_device, DEVICE)
            pred_entry["beats"] = pick_peaks(act_b)

        if downbeat_model is not None:
            act_d = get_activation_function(downbeat_model, waveform_device, DEVICE)
            pred_entry["downbeats"] = pick_peaks(act_d)

        predictions.append(pred_entry)
        ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})

        # Keep raw audio only for the first viz_tracks items, and only when
        # some output (plots or audio) will actually use it.
        if (args.visualize or args.synthesize) and i < args.viz_tracks:
            audio_data.append(
                {
                    "audio": waveform.numpy(),
                    "sr": item["audio"]["sampling_rate"],
                    "pred": pred_entry,
                    "gt": ground_truths[-1],
                }
            )

    results = evaluate_all(predictions, ground_truths)
    print(format_results(results))

    if args.visualize or args.synthesize or args.summary_plot:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.visualize:
        print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
        viz_dir = os.path.join(args.output_dir, "plots")
        # Hoisted out of the loop: time_range is loop-invariant (the original
        # rebuilt the tuple once per track).
        time_range = tuple(args.time_range) if args.time_range else None
        for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
            visualize_track(
                data["audio"],
                data["sr"],
                data["pred"]["beats"],
                data["pred"]["downbeats"],
                data["gt"]["beats"],
                data["gt"]["downbeats"],
                viz_dir,
                i,
                time_range=time_range,
            )
        print(f"Saved visualizations to {viz_dir}")

    if args.synthesize:
        print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
        audio_dir = os.path.join(args.output_dir, "audio")
        for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
            synthesize_audio(
                data["audio"],
                data["sr"],
                data["pred"]["beats"],
                data["pred"]["downbeats"],
                data["gt"]["beats"],
                data["gt"]["downbeats"],
                audio_dir,
                i,
                click_volume=args.click_volume,
            )
        print(f"Saved audio files to {audio_dir}")
        print("  *_pred.wav - Original audio with predicted beat clicks")
        print("  *_gt.wav - Original audio with ground truth beat clicks")
        print("  *_both.wav - Original audio with both predicted and GT clicks")

    if args.summary_plot:
        from ..data.viz import plot_evaluation_summary, save_figure

        print("\nGenerating summary plot...")
        fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
        summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
        save_figure(fig, summary_path)
        print(f"Saved summary plot to {summary_path}")
|
|
|
|
# Script entry point: run the CLI evaluation only when executed directly.
if __name__ == "__main__":
    main()
|
|