| """ |
| Autonomous parameter optimizer for the drum extraction pipeline. |
| |
| Runs a loop: |
| 1. Generate synthetic songs with known ground truth |
| 2. Run the extraction pipeline with current params |
| 3. Evaluate extraction quality against ground truth |
| 4. Use results to tune parameters for next iteration |
| |
Uses Bayesian-ish optimization: maintain a history of (params → score),
| then perturb the best-so-far params toward improving weak metrics. |
| """ |
|
|
| import json |
| import time |
| import traceback |
| import numpy as np |
| from copy import deepcopy |
| from dataclasses import dataclass, field |
| from pathlib import Path |
|
|
| from synth_generator import generate_test_song, SyntheticSong |
| from evaluation import evaluate_extraction, report_to_dict, EvalReport |
| from quality_metrics import drum_sample_score |
|
|
|
|
@dataclass
class PipelineParams:
    """All tunable parameters of the extraction pipeline."""

    # --- Onset detection / segmentation ---
    pre_pad: float = 0.005              # seconds of audio kept before each onset
    min_hit_dur: float = 0.03           # shortest accepted hit, seconds
    max_hit_dur: float = 0.8            # longest accepted hit, seconds
    min_gap: float = 0.02               # minimum silence between hits, seconds
    energy_threshold_db: float = -40.0  # energy gate for onset acceptance

    # --- Overlap handling ---
    separate_overlaps: bool = True
    overlap_energy_threshold: float = 0.15

    # --- Embeddings ---
    use_clap: bool = False

    # --- Representative-selection weights ---
    w_completeness: float = 0.30
    w_cleanness: float = 0.40
    w_onset: float = 0.20
    w_representativeness: float = 0.10

    # --- Synthesis ---
    synthesize: bool = True
    synth_best_weight: float = 2.0

    def to_dict(self) -> dict:
        """Return a shallow copy of all attributes as a plain dict."""
        return dict(self.__dict__)

    @classmethod
    def from_dict(cls, d: dict) -> 'PipelineParams':
        """Build params from a dict, silently dropping unknown keys."""
        known = cls.__dataclass_fields__
        return cls(**{key: value for key, value in d.items() if key in known})
|
|
|
|
@dataclass
class IterationResult:
    """Result of one optimization iteration.

    One record per iteration is appended to ``OptimizerState.history``,
    including failed iterations (which carry an ``{'error': ...}`` report
    and a score of 0.0).
    """
    iteration: int           # 0-based iteration index
    params: dict             # PipelineParams.to_dict() snapshot used this iteration
    eval_report: dict        # report_to_dict(EvalReport), or {'error': msg} on failure
    overall_score: float     # evaluation score for this iteration (0.0 on failure)
    duration_seconds: float  # wall-clock time spent on the whole iteration
    test_config: dict        # test-song settings, e.g. {'pattern', 'bpm', 'seed'}
    timestamp: str           # local time formatted '%Y-%m-%d %H:%M:%S'
|
|
|
|
@dataclass
class OptimizerState:
    """Persistent state of the optimizer.

    Returned by the optimization loop; tracks the per-iteration history and
    the best-scoring parameter set observed so far.
    """
    history: list = field(default_factory=list)      # list of IterationResult
    best_params: dict = field(default_factory=dict)  # PipelineParams.to_dict() of best iteration
    best_score: float = 0.0                          # highest overall_score seen so far
    iteration: int = 0                               # number of iterations completed
|
|
|
|
| |
| |
| |
|
|
def diagnose_and_perturb(params: "PipelineParams", report: "EvalReport",
                         rng: np.random.RandomState) -> "tuple[PipelineParams, list]":
    """Analyze an evaluation report and intelligently perturb parameters.

    Instead of random search, diagnose specific failure modes from the
    evaluation metrics and adjust the relevant parameters.

    Args:
        params: current pipeline parameters (never mutated; a deep copy is
            adjusted and returned).
        report: evaluation report for the most recent extraction run.
        rng: random state driving all perturbation magnitudes.

    Returns:
        Tuple ``(new_params, changes)``: the perturbed parameter set and a
        list of human-readable descriptions of the adjustments applied.
        (The original annotation claimed a bare ``PipelineParams`` — fixed.)
    """
    new_params = deepcopy(params)
    changes = []

    # Sloppy onsets: shrink the pre-roll and merge gap so segment starts
    # land closer to the true attack.
    if report.mean_onset_error_ms > 20:
        new_params.pre_pad = max(0.001, params.pre_pad * rng.uniform(0.5, 0.9))
        new_params.min_gap = max(0.01, params.min_gap * rng.uniform(0.6, 0.9))
        changes.append(f"onset_error={report.mean_onset_error_ms:.1f}ms → tightened pre_pad/min_gap")

    # Too few hits recovered: be more permissive (lower the energy gate and
    # the minimum hit duration).
    if report.hit_count_accuracy < 0.7:
        new_params.energy_threshold_db = max(-60, params.energy_threshold_db - rng.uniform(2, 8))
        new_params.min_hit_dur = max(0.01, params.min_hit_dur * rng.uniform(0.5, 0.8))
        changes.append(f"hit_acc={report.hit_count_accuracy:.2f} → lowered threshold/min_dur")

    # Over-extraction: far more hits than ground truth, so raise the gate.
    # NOTE(review): this and later branches rebase on the ORIGINAL params, so
    # when two branches touch the same field the last write wins — presumably
    # intentional ("strongest diagnosis wins"), but worth confirming.
    total_ext = sum(m.n_hits_extracted for m in report.matches) if report.matches else 0
    total_gt = sum(m.n_hits_gt for m in report.matches) if report.matches else 1
    if total_ext > total_gt * 1.5:
        new_params.energy_threshold_db = min(-20, params.energy_threshold_db + rng.uniform(2, 5))
        new_params.min_hit_dur = min(0.08, params.min_hit_dur * rng.uniform(1.1, 1.5))
        changes.append(f"over-extraction ({total_ext} vs {total_gt} GT) → raised threshold")

    # Poor separation quality: nudge the overlap threshold, clamped to
    # [0.05, 0.4]; cast keeps the param a plain Python float.
    if report.mean_si_sdr < 5:
        new_params.overlap_energy_threshold = params.overlap_energy_threshold + rng.uniform(-0.05, 0.05)
        new_params.overlap_energy_threshold = float(np.clip(new_params.overlap_energy_threshold, 0.05, 0.4))
        changes.append(f"SI-SDR={report.mean_si_sdr:.1f}dB → adjusted overlap threshold")

    # Weak sample quality: bias selection toward cleanness, then renormalize
    # so the four weights still sum to 1.
    if report.mean_sample_score < 50:
        new_params.w_cleanness = min(0.6, params.w_cleanness + rng.uniform(0, 0.1))
        new_params.w_completeness = max(0.15, params.w_completeness + rng.uniform(-0.05, 0.05))
        total_w = (new_params.w_cleanness + new_params.w_completeness
                   + new_params.w_onset + new_params.w_representativeness)
        new_params.w_cleanness /= total_w
        new_params.w_completeness /= total_w
        new_params.w_onset /= total_w
        new_params.w_representativeness /= total_w
        changes.append(f"sample_score={report.mean_sample_score:.1f} → adjusted selection weights")

    # Envelopes cut short: allow longer hits so decays are captured fully.
    if report.mean_env_corr < 0.7:
        new_params.max_hit_dur = min(1.5, params.max_hit_dur * rng.uniform(1.1, 1.3))
        changes.append(f"env_corr={report.mean_env_corr:.2f} → increased max_hit_dur")

    # Ground-truth drums missed entirely: lower the energy gate further.
    if len(report.unmatched_gt) > 0:
        new_params.energy_threshold_db = max(-60, params.energy_threshold_db - rng.uniform(3, 6))
        changes.append(f"missed {report.unmatched_gt} → lowered energy threshold")

    # Nothing diagnosable: take a small bounded random step to keep exploring.
    if not changes:
        new_params.energy_threshold_db += rng.uniform(-3, 3)
        new_params.pre_pad += rng.uniform(-0.002, 0.002)
        new_params.pre_pad = max(0.001, new_params.pre_pad)
        new_params.min_hit_dur += rng.uniform(-0.01, 0.01)
        new_params.min_hit_dur = max(0.01, new_params.min_hit_dur)
        changes.append("no specific issue → random exploration")

    return new_params, changes
|
|
|
|
| |
| |
| |
|
|
def run_extraction_with_params(song: "SyntheticSong", params: "PipelineParams") -> tuple:
    """Run the extraction pipeline with given params on a song.

    Args:
        song: synthetic test song; ``song.drums_only`` and ``song.sr`` feed
            the onset detector.
        params: pipeline parameters controlling every stage.

    Returns:
        ``(clusters, hits)`` — both empty lists when no onsets are detected.
        Raises whatever the underlying pipeline stages raise on failure.
    """
    # Deferred imports keep the heavy pipeline modules off the module-import
    # path; librosa was previously imported inside the per-hit loop — hoisted.
    import librosa
    from drum_extractor import (
        detect_onsets, classify_and_separate_hits,
        compute_librosa_embeddings, cluster_hits,
        synthesize_from_cluster,
    )

    # 1. Onset detection / segmentation.
    hits = detect_onsets(
        song.drums_only, song.sr,
        pre_pad=params.pre_pad,
        min_hit_dur=params.min_hit_dur,
        max_hit_dur=params.max_hit_dur,
        min_gap=params.min_gap,
        energy_threshold_db=params.energy_threshold_db,
    )

    if len(hits) == 0:
        return [], []

    # 2. Classification and optional overlap separation.
    hits = classify_and_separate_hits(hits, separate_overlaps=params.separate_overlaps)

    # 3. Embedding + clustering.
    embeddings = compute_librosa_embeddings(hits)
    clusters = cluster_hits(hits, embeddings)

    # 4. Representative selection is done inline (rather than via
    # select_best_representatives) so drum_sample_score can see per-cluster
    # centroid distances.
    for cluster in clusters:
        if cluster.count == 1:
            cluster.best_hit_idx = 0
            continue

        scores = []
        # e.g. "kick_2" -> "kick"; strip a trailing numeric suffix if present.
        base_label = cluster.label.rsplit('_', 1)[0]

        # Per-hit feature vectors: 13 MFCC means + energy/centroid/duration.
        hit_features = []
        for hit in cluster.hits:
            feat = np.concatenate([
                librosa.feature.mfcc(y=hit.audio, sr=hit.sr, n_mfcc=13).mean(axis=1),
                [hit.rms_energy, hit.spectral_centroid, hit.duration]
            ])
            hit_features.append(feat)
        hit_features = np.array(hit_features)

        # Z-score normalize so no single feature dominates the distance.
        mean_f = hit_features.mean(axis=0)
        std_f = hit_features.std(axis=0) + 1e-8
        hit_features_norm = (hit_features - mean_f) / std_f
        centroid = hit_features_norm.mean(axis=0)
        dists = np.linalg.norm(hit_features_norm - centroid, axis=1)
        radius = dists.max() + 1e-8

        for i, hit in enumerate(cluster.hits):
            score = drum_sample_score(
                hit.audio, hit.sr, base_label,
                centroid_dist=dists[i],
                cluster_radius=radius,
            )
            scores.append(score['total'])

        cluster.best_hit_idx = int(np.argmax(scores))

    # 5. Optional synthesis from multi-hit clusters.
    if params.synthesize:
        for cluster in clusters:
            if cluster.count >= 2:
                cluster.synthesized = synthesize_from_cluster(cluster)

    return clusters, hits
|
|
|
|
def run_optimization_loop(
    n_iterations: int = 10,
    patterns: list = None,
    initial_params: "PipelineParams" = None,
    seed: int = 42,
    log_callback=None,
) -> "OptimizerState":
    """Run the full autonomous optimization loop.

    Each iteration: generate a synthetic song, extract with the current
    params, evaluate against ground truth, record the result, then let
    diagnose_and_perturb() propose the next parameter set.

    Args:
        n_iterations: number of optimization iterations
        patterns: list of pattern names to test with (cycles through them)
        initial_params: starting pipeline parameters
        seed: random seed
        log_callback: function(str) called with log messages

    Returns:
        OptimizerState holding the full history and the best params/score.
    """
    if patterns is None:
        patterns = ['rock', 'funk', 'halftime']
    if initial_params is None:
        initial_params = PipelineParams()

    rng = np.random.RandomState(seed)
    state = OptimizerState(best_params=initial_params.to_dict())
    current_params = deepcopy(initial_params)

    def log(msg):
        # Always echo to stdout; also forward to the caller's callback if set.
        if log_callback:
            log_callback(msg)
        print(msg)

    log(f"Starting optimization loop: {n_iterations} iterations")
    log(f"Patterns: {patterns}")

    for i in range(n_iterations):
        t0 = time.time()
        pattern_name = patterns[i % len(patterns)]
        # Distinct, reproducible song seed per iteration.
        song_seed = seed + i * 17

        log(f"\n{'='*60}")
        log(f"ITERATION {i+1}/{n_iterations} → pattern={pattern_name}, seed={song_seed}")
        log(f"{'='*60}")

        try:
            # 1. Generate a synthetic song with known ground truth.
            log("  Generating synthetic song...")
            song = generate_test_song(
                pattern_name=pattern_name,
                bars=4,
                bpm=100 + rng.randint(0, 40) * 2,
                variation='medium',
                seed=song_seed,
            )
            log(f"  → {song.duration:.1f}s, {song.bpm}BPM, "
                f"{len(song.hits)} hits, {len(song.samples)} sample types")

            # 2. Run the extraction pipeline with the current parameters.
            log(f"  Running extraction with params: threshold={current_params.energy_threshold_db:.1f}dB, "
                f"pre_pad={current_params.pre_pad:.3f}, min_dur={current_params.min_hit_dur:.3f}")
            clusters, all_hits = run_extraction_with_params(song, current_params)
            log(f"  → {len(clusters)} clusters, {len(all_hits)} total hits")

            # 3. Evaluate the extraction against ground truth.
            log("  Evaluating against ground truth...")
            gt_samples = {name: s.audio for name, s in song.samples.items()}
            gt_hit_map = [
                {'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity}
                for h in song.hits
            ]

            report = evaluate_extraction(
                extracted_clusters=clusters,
                gt_samples=gt_samples,
                gt_hit_map=gt_hit_map,
                sr=song.sr,
                all_hits=all_hits,
                pipeline_params=current_params.to_dict(),
            )

            duration = time.time() - t0

            log("  RESULTS:")
            log(f"    Overall Score: {report.overall_score:.1f}/100")
            log(f"    SI-SDR: {report.mean_si_sdr:.1f} dB")
            log(f"    Sample Score: {report.mean_sample_score:.1f}/100")
            log(f"    Env Corr: {report.mean_env_corr:.3f}")
            log(f"    Onset Error: {report.mean_onset_error_ms:.1f} ms")
            log(f"    Hit Count Acc: {report.hit_count_accuracy:.2f}")
            log(f"    Matched: {len(report.matches)}/{len(song.samples)}")
            if report.unmatched_gt:
                log(f"    ⚠ Unmatched GT: {report.unmatched_gt}")

            # 4. Record this iteration in the history.
            result = IterationResult(
                iteration=i,
                params=current_params.to_dict(),
                eval_report=report_to_dict(report),
                overall_score=report.overall_score,
                duration_seconds=duration,
                test_config={'pattern': pattern_name, 'bpm': song.bpm, 'seed': song_seed},
                timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            )
            state.history.append(result)

            # 5. Track the best-so-far parameter set.
            if report.overall_score > state.best_score:
                state.best_score = report.overall_score
                state.best_params = current_params.to_dict()
                log(f"  ✅ NEW BEST SCORE: {report.overall_score:.1f}")

            # 6. Diagnose weaknesses and propose the next parameter set.
            new_params, changes = diagnose_and_perturb(current_params, report, rng)
            log("  Parameter adjustments:")
            for change in changes:
                log(f"    → {change}")
            current_params = new_params

        except Exception as e:
            # Broad catch is intentional at this loop boundary: one failed
            # iteration is logged, recorded with score 0.0, and the loop
            # continues rather than aborting the whole run.
            log(f"  ✗ ERROR: {e}")
            log(traceback.format_exc())
            # Random kick to escape whatever parameter region caused the crash.
            current_params.energy_threshold_db += rng.uniform(-5, 5)
            state.history.append(IterationResult(
                iteration=i,
                params=current_params.to_dict(),
                eval_report={'error': str(e)},
                overall_score=0.0,
                duration_seconds=time.time() - t0,
                test_config={'pattern': pattern_name},
                timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            ))

        state.iteration = i + 1

    log(f"\n{'='*60}")
    log("OPTIMIZATION COMPLETE")
    log(f"{'='*60}")
    log(f"  Best score: {state.best_score:.1f}/100")
    log(f"  Best params: {json.dumps(state.best_params, indent=2)}")

    return state
|
|