# drum-sample-extractor / optimizer.py
# (repo page-listing residue kept as a comment: uploaded by rikhoffbauer2,
#  commit d34b37f verified, "Add optimizer.py")
"""
Autonomous parameter optimizer for the drum extraction pipeline.
Runs a loop:
1. Generate synthetic songs with known ground truth
2. Run the extraction pipeline with current params
3. Evaluate extraction quality against ground truth
4. Use results to tune parameters for next iteration
Uses Bayesian-ish optimization: maintain a history of (params β†’ score),
then perturb the best-so-far params toward improving weak metrics.
"""
import json
import time
import traceback
import numpy as np
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from synth_generator import generate_test_song, SyntheticSong
from evaluation import evaluate_extraction, report_to_dict, EvalReport
from quality_metrics import drum_sample_score
@dataclass
class PipelineParams:
    """All tunable parameters of the extraction pipeline."""

    # Onset detection
    pre_pad: float = 0.005
    min_hit_dur: float = 0.03
    max_hit_dur: float = 0.8
    min_gap: float = 0.02
    energy_threshold_db: float = -40.0

    # Overlap separation
    separate_overlaps: bool = True
    overlap_energy_threshold: float = 0.15  # band energy ratio to count as significant

    # Clustering
    use_clap: bool = False

    # Selection weights (must sum to 1.0)
    w_completeness: float = 0.30
    w_cleanness: float = 0.40
    w_onset: float = 0.20
    w_representativeness: float = 0.10

    # Synthesis
    synthesize: bool = True
    synth_best_weight: float = 2.0  # weight multiplier for best sample in cluster

    def to_dict(self) -> dict:
        """Return all parameter values as a fresh, flat dict."""
        return dict(self.__dict__)

    @classmethod
    def from_dict(cls, d: dict) -> 'PipelineParams':
        """Build params from *d*, silently dropping keys that are not fields."""
        known = cls.__dataclass_fields__
        return cls(**{key: val for key, val in d.items() if key in known})
@dataclass
class IterationResult:
    """Result of one optimization iteration."""
    iteration: int  # 0-based iteration index
    params: dict  # PipelineParams.to_dict() snapshot used for this run
    eval_report: dict  # serialized EvalReport, or {'error': ...} when the run failed
    overall_score: float  # evaluation headline score (0.0 on failure)
    duration_seconds: float  # wall-clock time spent on the iteration
    test_config: dict  # which synthetic song was used (pattern / bpm / seed)
    timestamp: str  # local time, formatted '%Y-%m-%d %H:%M:%S'
@dataclass
class OptimizerState:
    """Persistent state of the optimizer."""
    history: list = field(default_factory=list)  # [IterationResult], one per iteration
    best_params: dict = field(default_factory=dict)  # params dict of the best-scoring run
    best_score: float = 0.0  # highest overall_score observed so far
    iteration: int = 0  # number of iterations completed
# ─────────────────────────────────────────────────────────────────────────────
# Parameter perturbation strategies
# ─────────────────────────────────────────────────────────────────────────────
def diagnose_and_perturb(params: "PipelineParams", report: "EvalReport",
                         rng: np.random.RandomState) -> tuple:
    """Analyze evaluation report and intelligently perturb parameters.

    Instead of random search, we diagnose specific failure modes from the
    evaluation metrics and adjust the relevant parameters.

    Args:
        params: parameters used for the evaluated run (never mutated).
        report: evaluation metrics produced with ``params``.
        rng: random source for perturbation magnitudes.

    Returns:
        ``(new_params, changes)`` — a perturbed deep copy of ``params`` and a
        list of human-readable strings describing each adjustment made.
        (The original return annotation incorrectly claimed a bare
        ``PipelineParams``.)
    """
    new_params = deepcopy(params)
    changes = []

    # NOTE(review): every diagnosis below perturbs from the ORIGINAL params,
    # so when several diagnoses touch the same field (e.g. energy_threshold_db
    # in diagnoses 2, 3 and 7) the last write wins rather than compounding —
    # presumably intentional damping; confirm before changing.

    # Diagnosis 1: poor onset precision (>20 ms mean error)
    if report.mean_onset_error_ms > 20:
        # Reduce pre_pad to tighten onset capture
        new_params.pre_pad = max(0.001, params.pre_pad * rng.uniform(0.5, 0.9))
        # Reduce min_gap to catch faster sequences
        new_params.min_gap = max(0.01, params.min_gap * rng.uniform(0.6, 0.9))
        changes.append(f"onset_error={report.mean_onset_error_ms:.1f}ms β†’ tightened pre_pad/min_gap")

    # Diagnosis 2: missing hits (low hit count accuracy)
    if report.hit_count_accuracy < 0.7:
        # Lower energy threshold to catch quieter hits
        new_params.energy_threshold_db = max(-60, params.energy_threshold_db - rng.uniform(2, 8))
        # Reduce min_hit_dur to catch shorter sounds
        new_params.min_hit_dur = max(0.01, params.min_hit_dur * rng.uniform(0.5, 0.8))
        changes.append(f"hit_acc={report.hit_count_accuracy:.2f} β†’ lowered threshold/min_dur")

    # Diagnosis 3: too many false hits (extracted >> ground truth)
    total_ext = sum(m.n_hits_extracted for m in report.matches) if report.matches else 0
    total_gt = sum(m.n_hits_gt for m in report.matches) if report.matches else 1
    if total_ext > total_gt * 1.5:
        # Raise energy threshold (capped at -20 dB)
        new_params.energy_threshold_db = min(-20, params.energy_threshold_db + rng.uniform(2, 5))
        new_params.min_hit_dur = min(0.08, params.min_hit_dur * rng.uniform(1.1, 1.5))
        changes.append(f"over-extraction ({total_ext} vs {total_gt} GT) β†’ raised threshold")

    # Diagnosis 4: low SI-SDR (poor sample quality)
    if report.mean_si_sdr < 5:
        # The extracted samples don't match GT well;
        # try adjusting the overlap separation threshold.
        new_params.overlap_energy_threshold = params.overlap_energy_threshold + rng.uniform(-0.05, 0.05)
        new_params.overlap_energy_threshold = np.clip(new_params.overlap_energy_threshold, 0.05, 0.4)
        changes.append(f"SI-SDR={report.mean_si_sdr:.1f}dB β†’ adjusted overlap threshold")

    # Diagnosis 5: low sample scores (poor completeness/cleanness)
    if report.mean_sample_score < 50:
        # More weight on cleanness if we're getting bleed-heavy samples
        new_params.w_cleanness = min(0.6, params.w_cleanness + rng.uniform(0, 0.1))
        new_params.w_completeness = max(0.15, params.w_completeness + rng.uniform(-0.05, 0.05))
        # Renormalize so the four weights sum to 1.0 again
        total_w = new_params.w_cleanness + new_params.w_completeness + new_params.w_onset + new_params.w_representativeness
        new_params.w_cleanness /= total_w
        new_params.w_completeness /= total_w
        new_params.w_onset /= total_w
        new_params.w_representativeness /= total_w
        changes.append(f"sample_score={report.mean_sample_score:.1f} β†’ adjusted selection weights")

    # Diagnosis 6: low envelope correlation (transient mismatch)
    if report.mean_env_corr < 0.7:
        new_params.max_hit_dur = min(1.5, params.max_hit_dur * rng.uniform(1.1, 1.3))
        changes.append(f"env_corr={report.mean_env_corr:.2f} β†’ increased max_hit_dur")

    # Diagnosis 7: unmatched GT samples (some drums never found)
    if len(report.unmatched_gt) > 0:
        new_params.energy_threshold_db = max(-60, params.energy_threshold_db - rng.uniform(3, 6))
        changes.append(f"missed {report.unmatched_gt} β†’ lowered energy threshold")

    # If no specific diagnosis triggered, apply a small random perturbation
    # to explore the nearby parameter space.
    if not changes:
        new_params.energy_threshold_db += rng.uniform(-3, 3)
        new_params.pre_pad += rng.uniform(-0.002, 0.002)
        new_params.pre_pad = max(0.001, new_params.pre_pad)
        new_params.min_hit_dur += rng.uniform(-0.01, 0.01)
        new_params.min_hit_dur = max(0.01, new_params.min_hit_dur)
        changes.append("no specific issue β†’ random exploration")

    return new_params, changes
# ─────────────────────────────────────────────────────────────────────────────
# Main optimization loop
# ─────────────────────────────────────────────────────────────────────────────
def run_extraction_with_params(song: "SyntheticSong", params: "PipelineParams") -> tuple:
    """Run the extraction pipeline with given params on a song.

    Args:
        song: synthetic song providing ``drums_only`` audio and ``sr``.
        params: pipeline parameters driving every stage.

    Returns:
        ``(clusters, all_hits)``; both empty lists when no onsets are found.
        Any exception from a pipeline stage propagates to the caller.
    """
    from drum_extractor import (
        detect_onsets, classify_and_separate_hits,
        compute_librosa_embeddings, cluster_hits,
        select_best_representatives, synthesize_from_cluster,
    )
    # Fix: hoisted out of the per-hit loop below, where it was re-executed
    # for every hit of every cluster.
    import librosa

    # Stage 2: Onset detection
    hits = detect_onsets(
        song.drums_only, song.sr,
        pre_pad=params.pre_pad,
        min_hit_dur=params.min_hit_dur,
        max_hit_dur=params.max_hit_dur,
        min_gap=params.min_gap,
        energy_threshold_db=params.energy_threshold_db,
    )
    if len(hits) == 0:
        return [], []

    # Stage 3: Classify & separate
    hits = classify_and_separate_hits(hits, separate_overlaps=params.separate_overlaps)

    # Stage 4: Embed & cluster
    embeddings = compute_librosa_embeddings(hits)
    clusters = cluster_hits(hits, embeddings)

    # Stage 5: Select best (using our improved scoring).
    # NOTE(review): select_best_representatives is imported but selection is
    # reimplemented here with drum_sample_score — presumably intentional.
    for cluster in clusters:
        if cluster.count == 1:
            cluster.best_hit_idx = 0
            continue
        scores = []
        base_label = cluster.label.rsplit('_', 1)[0]
        # Compute cluster radius for representativeness scoring:
        # MFCC means + energy/centroid/duration per hit, z-scored per cluster.
        hit_features = []
        for hit in cluster.hits:
            feat = np.concatenate([
                librosa.feature.mfcc(y=hit.audio, sr=hit.sr, n_mfcc=13).mean(axis=1),
                [hit.rms_energy, hit.spectral_centroid, hit.duration]
            ])
            hit_features.append(feat)
        hit_features = np.array(hit_features)
        mean_f = hit_features.mean(axis=0)
        std_f = hit_features.std(axis=0) + 1e-8  # epsilon: avoid divide-by-zero
        hit_features_norm = (hit_features - mean_f) / std_f
        centroid = hit_features_norm.mean(axis=0)
        dists = np.linalg.norm(hit_features_norm - centroid, axis=1)
        radius = dists.max() + 1e-8
        for i, hit in enumerate(cluster.hits):
            score = drum_sample_score(
                hit.audio, hit.sr, base_label,
                centroid_dist=dists[i],
                cluster_radius=radius,
            )
            scores.append(score['total'])
        cluster.best_hit_idx = int(np.argmax(scores))

    # Stage 6: Synthesis (only meaningful for clusters with >= 2 members)
    if params.synthesize:
        for cluster in clusters:
            if cluster.count >= 2:
                cluster.synthesized = synthesize_from_cluster(cluster)
    return clusters, hits
def run_optimization_loop(
    n_iterations: int = 10,
    patterns: list = None,
    initial_params: PipelineParams = None,
    seed: int = 42,
    log_callback=None,
) -> OptimizerState:
    """Run the full autonomous optimization loop.

    Args:
        n_iterations: number of optimization iterations
        patterns: list of pattern names to test with (cycles through them)
        initial_params: starting pipeline parameters
        seed: random seed
        log_callback: function(str) called with log messages

    Returns:
        OptimizerState with the full iteration history, best params and score.
    """
    if patterns is None:
        patterns = ['rock', 'funk', 'halftime']
    if initial_params is None:
        initial_params = PipelineParams()
    rng = np.random.RandomState(seed)
    state = OptimizerState(best_params=initial_params.to_dict())
    current_params = deepcopy(initial_params)

    def log(msg):
        # Mirror every message to the optional callback AND stdout.
        if log_callback:
            log_callback(msg)
        print(msg)

    log(f"Starting optimization loop: {n_iterations} iterations")
    log(f"Patterns: {patterns}")
    for i in range(n_iterations):
        t0 = time.time()
        pattern_name = patterns[i % len(patterns)]
        song_seed = seed + i * 17  # different song each iteration
        log(f"\n{'='*60}")
        log(f"ITERATION {i+1}/{n_iterations} β€” pattern={pattern_name}, seed={song_seed}")
        log(f"{'='*60}")
        try:
            # 1. Generate synthetic song with known ground truth
            log(" Generating synthetic song...")
            song = generate_test_song(
                pattern_name=pattern_name,
                bars=4,
                bpm=100 + rng.randint(0, 40) * 2,  # vary BPM
                variation='medium',
                seed=song_seed,
            )
            log(f" β†’ {song.duration:.1f}s, {song.bpm}BPM, "
                f"{len(song.hits)} hits, {len(song.samples)} sample types")
            # 2. Run extraction with the current parameter set
            log(f" Running extraction with params: threshold={current_params.energy_threshold_db:.1f}dB, "
                f"pre_pad={current_params.pre_pad:.3f}, min_dur={current_params.min_hit_dur:.3f}")
            clusters, all_hits = run_extraction_with_params(song, current_params)
            log(f" β†’ {len(clusters)} clusters, {len(all_hits)} total hits")
            # 3. Evaluate extraction against the known ground truth
            log(" Evaluating against ground truth...")
            gt_samples = {name: s.audio for name, s in song.samples.items()}
            gt_hit_map = [
                {'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity}
                for h in song.hits
            ]
            report = evaluate_extraction(
                extracted_clusters=clusters,
                gt_samples=gt_samples,
                gt_hit_map=gt_hit_map,
                sr=song.sr,
                all_hits=all_hits,
                pipeline_params=current_params.to_dict(),
            )
            duration = time.time() - t0
            log(f" RESULTS:")
            log(f" Overall Score: {report.overall_score:.1f}/100")
            log(f" SI-SDR: {report.mean_si_sdr:.1f} dB")
            log(f" Sample Score: {report.mean_sample_score:.1f}/100")
            log(f" Env Corr: {report.mean_env_corr:.3f}")
            log(f" Onset Error: {report.mean_onset_error_ms:.1f} ms")
            log(f" Hit Count Acc: {report.hit_count_accuracy:.2f}")
            log(f" Matched: {len(report.matches)}/{len(song.samples)}")
            if report.unmatched_gt:
                log(f" ⚠ Unmatched GT: {report.unmatched_gt}")
            # Record the iteration in the persistent history
            result = IterationResult(
                iteration=i,
                params=current_params.to_dict(),
                eval_report=report_to_dict(report),
                overall_score=report.overall_score,
                duration_seconds=duration,
                test_config={'pattern': pattern_name, 'bpm': song.bpm, 'seed': song_seed},
                timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            )
            state.history.append(result)
            # Update best-so-far
            if report.overall_score > state.best_score:
                state.best_score = report.overall_score
                state.best_params = current_params.to_dict()
                log(f" β˜… NEW BEST SCORE: {report.overall_score:.1f}")
            # 4. Tune parameters for the next iteration based on the diagnosis
            new_params, changes = diagnose_and_perturb(current_params, report, rng)
            log(f" Parameter adjustments:")
            for change in changes:
                log(f" β†’ {change}")
            current_params = new_params
        except Exception as e:
            # Log the failure, record a zero-score iteration, and keep going.
            log(f" βœ— ERROR: {e}")
            log(traceback.format_exc())
            # On error, try random perturbation of the energy threshold
            current_params.energy_threshold_db += rng.uniform(-5, 5)
            state.history.append(IterationResult(
                iteration=i,
                params=current_params.to_dict(),
                eval_report={'error': str(e)},
                overall_score=0.0,
                duration_seconds=time.time() - t0,
                test_config={'pattern': pattern_name},
                timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            ))
        state.iteration = i + 1
    log(f"\n{'='*60}")
    log(f"OPTIMIZATION COMPLETE")
    log(f"{'='*60}")
    log(f" Best score: {state.best_score:.1f}/100")
    log(f" Best params: {json.dumps(state.best_params, indent=2)}")
    return state