rikhoffbauer2 committed on
Commit
3806379
·
verified ·
1 Parent(s): 9aaf4a2

Add drum sample extractor pipeline

Browse files
Files changed (1) hide show
  1. drum_extractor.py +843 -0
drum_extractor.py ADDED
@@ -0,0 +1,843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Drum Sample Extractor Pipeline
4
+ ===============================
5
+ Extracts individual drum samples from an audio file through:
6
+
7
+ 1. STEM SEPARATION — HTDemucs (v4 fine-tuned) isolates the drum track
8
+ 2. ONSET DETECTION — librosa detects individual hit boundaries
9
+ 3. INTRA-DRUM SEP — Spectral band splitting + optional AudioSep for overlapping sounds
10
+ 4. CLUSTERING — librosa features (or CLAP embeddings with --clap) + auto-K KMeans groups identical hits
11
+ 5. SELECTION — Best representative per cluster (centroid-nearest + highest energy)
12
+ 6. SYNTHESIS (opt) — Weighted average of cluster members for an "ideal" sample
13
+
14
+ Usage:
15
+ python drum_extractor.py input.mp3 --output-dir ./samples
16
+ python drum_extractor.py input.wav --output-dir ./samples --no-gpu
17
+ python drum_extractor.py input.mp3 --output-dir ./samples --clap
18
+
19
+ Dependencies:
20
+ pip install demucs librosa soundfile scikit-learn numpy torch transformers
21
+ """
22
+
23
+ import argparse
24
+ import json
25
+ import os
26
+ import sys
27
+ import warnings
28
+ from collections import defaultdict
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Optional
32
+
33
+ import librosa
34
+ import numpy as np
35
+ import soundfile as sf
36
+ import torch
37
+
38
+ warnings.filterwarnings("ignore", category=FutureWarning)
39
+ warnings.filterwarnings("ignore", category=UserWarning)
40
+
41
+
42
+ # ─────────────────────────────────────────────────────────────────────────────
43
+ # Data structures
44
+ # ─────────────────────────────────────────────────────────────────────────────
45
+
46
@dataclass(eq=False)
class DrumHit:
    """A single detected drum hit.

    Instances compare (and hash) by identity. The field-wise ``__eq__`` that
    ``@dataclass`` generates by default would compare the ``audio`` arrays
    element-wise and raise ``ValueError`` ("truth value of an array ... is
    ambiguous") as soon as two distinct hits were compared.
    """
    audio: np.ndarray                       # mono waveform
    sr: int                                 # sample rate
    onset_time: float                       # onset time in seconds (in the drum stem)
    duration: float                         # duration in seconds
    index: int                              # sequential index
    rms_energy: float = 0.0
    spectral_centroid: float = 0.0
    rough_label: str = ""                   # spectral rough label: kick/snare/hihat/other
    embedding: Optional[np.ndarray] = None
    cluster_id: int = -1                    # defaults to -1

    def save(self, path: str) -> None:
        """Write ``audio`` to ``path`` at ``sr`` using the PCM_24 subtype."""
        sf.write(path, self.audio, self.sr, subtype='PCM_24')
62
+
63
+
64
@dataclass(eq=False)
class DrumCluster:
    """A cluster of similar drum hits.

    Instances compare (and hash) by identity. The default field-wise
    ``__eq__`` would compare the ``synthesized`` arrays element-wise and
    raise ``ValueError`` when two clusters holding distinct arrays are
    compared.
    """
    cluster_id: int
    label: str                              # e.g. "kick_0", "snare_1"
    hits: list = field(default_factory=list)
    best_hit_idx: int = 0                   # index into self.hits
    synthesized: Optional[np.ndarray] = None

    @property
    def best_hit(self) -> "DrumHit":
        """Return the member of ``hits`` at position ``best_hit_idx``."""
        return self.hits[self.best_hit_idx]

    @property
    def count(self) -> int:
        """Return the number of hits in the cluster."""
        return len(self.hits)
80
+
81
+
82
+ # ─────────────────────────────────────────────────────────────────────────────
83
+ # Stage 1: Drum stem extraction via Demucs
84
+ # ─────────────────────────────────────────────────────────────────────────────
85
+
86
def extract_drums_demucs(audio_path: str, device: str = "cpu") -> tuple[np.ndarray, int]:
    """Extract the drum stem of an audio file with HTDemucs.

    Tries the ``htdemucs_ft`` model first and falls back to ``htdemucs``.
    The input is loaded with librosa at the model's sample rate, forced to
    exactly two channels, separated, and the ``drums`` source is averaged
    down to mono.

    Args:
        audio_path: Path of the audio file to separate.
        device: Torch device the model and the audio are moved to.

    Returns:
        ``(drums_mono, sample_rate)``: the mono drum stem as a NumPy array
        and the model's sample rate.

    Raises:
        RuntimeError: If neither Demucs model can be loaded.
    """
    from demucs.apply import apply_model
    from demucs.pretrained import get_model

    print("=" * 60)
    print("STAGE 1: Extracting drum stem with HTDemucs")
    print("=" * 60)

    # Try htdemucs_ft first (better drums), fall back to htdemucs
    for model_name in ("htdemucs_ft", "htdemucs"):
        try:
            model = get_model(model_name)
        except Exception as e:  # best effort: any failure means "try the next model"
            print(f" Could not load {model_name}: {e}")
        else:
            print(f" Loaded model: {model_name}")
            break
    else:
        raise RuntimeError("Could not load any Demucs model")

    model.eval()
    model.to(device)
    target_sr = model.samplerate  # 44100

    # Load audio using librosa (works without FFmpeg system libs)
    audio_np, _ = librosa.load(audio_path, sr=target_sr, mono=False)
    if audio_np.ndim == 1:
        audio_np = np.stack([audio_np, audio_np])  # mono → stereo
    elif audio_np.shape[0] == 1:
        audio_np = np.concatenate([audio_np, audio_np], axis=0)
    elif audio_np.shape[0] > 2:
        audio_np = audio_np[:2]  # keep only the first two channels

    wav = torch.from_numpy(audio_np).float()  # [2, T]
    wav = wav.unsqueeze(0).to(device)  # [1, 2, T]
    print(f" Audio: {wav.shape[-1] / target_sr:.1f}s, {target_sr}Hz")

    # Separate
    with torch.no_grad():
        sources = apply_model(model, wav, device=device, shifts=1,
                              split=True, overlap=0.25, progress=True)

    # sources: [1, n_sources, 2, T]
    stem_names = model.sources  # e.g. ['drums', 'bass', 'other', 'vocals']
    drums_wav = sources[0, stem_names.index('drums')]  # [2, T]

    # Convert to mono numpy
    drums_mono = drums_wav.mean(dim=0).cpu().numpy()
    print(f" ✓ Extracted drums: {len(drums_mono) / target_sr:.1f}s")

    return drums_mono, target_sr
138
+
139
+
140
+ # ─────────────────────────────────────────────────────────────────────────────
141
+ # Stage 2: Onset detection & hit segmentation
142
+ # ─────────────────────────────────────────────────────────────────────────────
143
+
144
def detect_onsets(y: np.ndarray, sr: int,
                  pre_pad: float = 0.005,
                  min_hit_dur: float = 0.03,
                  max_hit_dur: float = 0.8,
                  min_gap: float = 0.02,
                  energy_threshold_db: float = -40.0) -> list[DrumHit]:
    """Detect drum hit onsets and segment the signal into individual hits.

    Onset strength is computed in three frequency bands, each band is
    normalized by its own maximum, and the element-wise maximum of the
    three is used for onset picking. Each onset is then cut into a segment
    that runs until the next onset or ``max_hit_dur``, whichever is shorter.

    Args:
        y: Signal to analyse.
        sr: Sample rate of ``y``.
        pre_pad: Seconds kept before each onset.
        min_hit_dur: Segments shorter than this (seconds) are dropped.
        max_hit_dur: Upper bound on segment length in seconds.
        min_gap: Minimum spacing between onsets in seconds, converted to
            the ``wait`` frame count of the onset picker.
        energy_threshold_db: Segments whose RMS is below this level (dB,
            converted to linear amplitude) are dropped.

    Returns:
        The accepted hits, indexed sequentially from 0. Empty when ``y`` is
        empty.
    """
    print("\n" + "=" * 60)
    print("STAGE 2: Detecting drum hit onsets")
    print("=" * 60)

    if len(y) == 0:
        # Nothing to analyse; avoid calling the onset detector on no data.
        print(" Raw onsets detected: 0")
        print(" ✓ Valid hits after filtering: 0")
        return []

    # One hop length for envelope, wait computation and frame→time conversion
    # (512 is also librosa's default, so results are unchanged).
    hop_length = 512

    def band_envelope(fmin, fmax):
        """Onset strength restricted to [fmin, fmax], scaled to a peak of 1."""
        env = librosa.onset.onset_strength(
            y=y, sr=sr, hop_length=hop_length,
            fmin=fmin, fmax=fmax, aggregate=np.median
        )
        peak = env.max()
        return env / peak if peak > 0 else env

    # Multi-band onset detection for better drum coverage:
    # take the max across the normalized bands.
    onset_env = np.maximum.reduce([
        band_envelope(20, 250),
        band_envelope(250, 4000),
        band_envelope(4000, sr // 2),
    ])

    # Detect onsets
    wait_frames = max(1, int(min_gap * sr / hop_length))
    onsets_frames = librosa.onset.onset_detect(
        onset_envelope=onset_env,
        sr=sr,
        hop_length=hop_length,
        wait=wait_frames,
        pre_avg=3,
        post_avg=3,
        pre_max=3,
        post_max=5,
        backtrack=True,
        units='frames'
    )
    onset_times = librosa.frames_to_time(onsets_frames, sr=sr, hop_length=hop_length)

    print(f" Raw onsets detected: {len(onset_times)}")

    # Segment into hits
    hits = []
    energy_threshold = 10 ** (energy_threshold_db / 20)
    max_len = int(max_hit_dur * sr)
    min_len = int(min_hit_dur * sr)

    for i, t in enumerate(onset_times):
        start_sample = max(0, int((t - pre_pad) * sr))

        # A hit ends at the next onset (or the end of the signal), capped
        # at max_hit_dur.
        if i + 1 < len(onset_times):
            limit = int(onset_times[i + 1] * sr)
        else:
            limit = len(y)
        end_sample = min(limit, start_sample + max_len)

        segment = y[start_sample:end_sample]

        if len(segment) < min_len:
            continue
        rms = np.sqrt(np.mean(segment ** 2))
        if rms < energy_threshold:
            continue

        # Fade-out to avoid clicks (on a copy, so ``y`` is left untouched)
        fade_len = min(int(0.005 * sr), len(segment) // 4)
        if fade_len > 0:
            segment = segment.copy()
            segment[-fade_len:] *= np.linspace(1, 0, fade_len)

        spectral_centroid = float(librosa.feature.spectral_centroid(
            y=segment, sr=sr
        ).mean())

        hits.append(DrumHit(
            audio=segment,
            sr=sr,
            onset_time=float(t),  # plain float rather than a NumPy scalar
            duration=len(segment) / sr,
            index=len(hits),
            rms_energy=float(rms),
            spectral_centroid=spectral_centroid,
        ))

    print(f" ✓ Valid hits after filtering: {len(hits)}")
    return hits
235
+
236
+
237
+ # ─────────────────────────────────────────────────────────────────────────────
238
+ # Stage 3: Rough spectral classification + intra-drum separation
239
+ # ─────────────────────────────────────────────────────────────────────────────
240
+
241
def rough_spectral_label(hit: DrumHit) -> str:
    """Return a coarse drum-type label for ``hit``.

    The decision uses the share of STFT energy in three bands (20-200 Hz,
    200-4000 Hz, 4000 Hz and above), the hit's stored spectral centroid,
    its mean zero-crossing rate and, for hi-hats, its duration. Possible
    results: ``kick``, ``hihat_closed``, ``hihat_open``, ``cymbal``,
    ``snare``, ``tom``, ``perc_high``, ``perc_low``.
    """
    signal = hit.audio
    centroid = hit.spectral_centroid

    magnitude = np.abs(librosa.stft(signal, n_fft=2048))
    bin_freqs = librosa.fft_frequencies(sr=hit.sr, n_fft=2048)

    def band_energy(lo, hi=None):
        """Summed squared magnitude of the bins in [lo, hi) (open-ended if hi is None)."""
        in_band = bin_freqs >= lo
        if hi is not None:
            in_band = in_band & (bin_freqs < hi)
        return np.sum(magnitude[in_band] ** 2)

    low_energy = band_energy(20, 200)
    mid_energy = band_energy(200, 4000)
    high_energy = band_energy(4000)
    # The small constant keeps the ratios defined for an all-zero spectrum.
    total = low_energy + mid_energy + high_energy + 1e-10

    low_ratio = low_energy / total
    mid_ratio = mid_energy / total
    high_ratio = high_energy / total
    zcr = float(librosa.feature.zero_crossing_rate(y=signal).mean())

    # Rules are evaluated in order; the first match wins.
    if low_ratio > 0.5 and centroid < 800:
        return "kick"
    if high_ratio > 0.35 and centroid > 4000:
        if hit.duration < 0.15:
            return "hihat_closed"
        return "hihat_open"
    if high_ratio > 0.25 and centroid > 3000:
        return "cymbal"
    if mid_ratio > 0.4 and zcr > 0.1 and centroid > 1000:
        return "snare"
    if low_ratio > 0.3 and mid_ratio > 0.3:
        return "tom"
    if centroid > 2500:
        return "perc_high"
    return "perc_low"
273
+
274
+
275
def spectral_separate_hit(hit: DrumHit) -> dict[str, np.ndarray]:
    """Decompose a single hit into spectral bands (kick/snare/hihat ranges).

    The STFT of the hit is masked per band and inverted back to the time
    domain. Bands are half-open (``fmin <= f < fmax``) so that a bin lying
    exactly on a band edge belongs to one band only; the top band is
    open-ended and therefore always includes the highest bin.

    Args:
        hit: The hit to decompose; ``hit.audio`` and ``hit.sr`` are read.

    Returns:
        Mapping from band name (``"low"``, ``"mid"``, ``"high"``) to the
        band-limited signal, each with the same length as ``hit.audio``.
        Bands whose RMS is not above 0.001 are omitted.
    """
    y, sr = hit.audio, hit.sr
    D = librosa.stft(y, n_fft=2048)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)

    bands = {
        "low": (20, 250),        # kick range
        "mid": (250, 4000),      # snare/tom range
        "high": (4000, np.inf),  # hihat/cymbal range
    }

    results = {}
    for name, (fmin, fmax) in bands.items():
        mask = (freqs >= fmin) & (freqs < fmax)
        D_band = np.zeros_like(D)
        D_band[mask] = D[mask]
        audio_band = librosa.istft(D_band, length=len(y))

        # Keep only bands that carry audible energy.
        if np.sqrt(np.mean(audio_band ** 2)) > 0.001:
            results[name] = audio_band

    return results
298
+
299
+
300
def classify_and_separate_hits(hits: list[DrumHit],
                               separate_overlaps: bool = True) -> list[DrumHit]:
    """Label every hit and, if requested, split multi-band hits into sub-hits.

    Each hit first receives a rough spectral label. With
    ``separate_overlaps`` enabled, a hit whose band decomposition has at
    least two bands with RMS above 15% of the strongest band is replaced by
    one sub-hit per such band (labelled kick/snare/hihat by band); all other
    hits are kept as they are. Indices are renumbered to match positions in
    the returned list.
    """
    print("\n" + "=" * 60)
    print("STAGE 3: Spectral classification & separation")
    print("=" * 60)

    band_to_label = {"low": "kick", "mid": "snare", "high": "hihat"}
    collected = []
    n_decomposed = 0

    def rms_of(signal):
        return np.sqrt(np.mean(signal ** 2))

    for hit in hits:
        hit.rough_label = rough_spectral_label(hit)

        # Bands strong enough to count as a separate sound in this hit.
        strong_bands = {}
        if separate_overlaps:
            band_signals = spectral_separate_hit(hit)
            if len(band_signals) >= 2:
                band_rms = {name: rms_of(sig) for name, sig in band_signals.items()}
                cutoff = 0.15 * max(band_rms.values())
                strong_bands = {name: sig for name, sig in band_signals.items()
                                if band_rms[name] > cutoff}

        if len(strong_bands) < 2:
            # Not an overlap: keep the original hit, renumbered.
            hit.index = len(collected)
            collected.append(hit)
            continue

        n_decomposed += 1
        for band_name, band_audio in strong_bands.items():
            centroid = librosa.feature.spectral_centroid(y=band_audio, sr=hit.sr)
            collected.append(DrumHit(
                audio=band_audio,
                sr=hit.sr,
                onset_time=hit.onset_time,
                duration=hit.duration,
                index=len(collected),
                rms_energy=float(rms_of(band_audio)),
                spectral_centroid=float(centroid.mean()),
                rough_label=band_to_label.get(band_name, "other"),
            ))

    tally = defaultdict(int)
    for entry in collected:
        tally[entry.rough_label] += 1

    print(f" Overlapping hits decomposed: {n_decomposed}")
    print(f" Total hits after separation: {len(collected)}")
    print(" Label distribution:")
    # Most frequent label first; ties keep first-seen order.
    for name, total in sorted(tally.items(), key=lambda kv: kv[1], reverse=True):
        print(f" {name}: {total}")

    return collected
355
+
356
+
357
+ # ─────────────────────────────────────────────────────────────────────────────
358
+ # Stage 4: Embedding & Clustering
359
+ # ─────────────────────────────────────────────────────────────────────────────
360
+
361
def compute_librosa_embeddings(hits: list[DrumHit]) -> np.ndarray:
    """Build a 58-dimensional librosa feature vector per hit.

    Each vector concatenates MFCC means and standard deviations (20 each),
    spectral centroid and bandwidth mean/std, mean rolloff, mean spectral
    contrast per band, mean flatness, mean zero-crossing rate, mean RMS,
    four onset-envelope statistics and the hit duration. The resulting
    table is standardized column-wise (zero mean, unit variance) before it
    is returned as float32.
    """

    def describe(hit):
        """Return the raw (unstandardized) feature vector of one hit."""
        signal, rate = hit.audio, hit.sr

        # Zero-pad very short hits up to 50 ms before feature extraction.
        shortfall = int(0.05 * rate) - len(signal)
        if shortfall > 0:
            signal = np.pad(signal, (0, shortfall))

        mfcc = librosa.feature.mfcc(y=signal, sr=rate, n_mfcc=20)
        centroid = librosa.feature.spectral_centroid(y=signal, sr=rate)
        bandwidth = librosa.feature.spectral_bandwidth(y=signal, sr=rate)
        rolloff = librosa.feature.spectral_rolloff(y=signal, sr=rate)
        contrast = librosa.feature.spectral_contrast(y=signal, sr=rate, n_bands=4)
        flatness = librosa.feature.spectral_flatness(y=signal)
        zcr = librosa.feature.zero_crossing_rate(y=signal)
        rms = librosa.feature.rms(y=signal)

        # Envelope statistics: mean, spread, relative peak position, final value.
        envelope = librosa.onset.onset_strength(y=signal, sr=rate)
        if len(envelope) > 1:
            scaled = envelope / (envelope.max() + 1e-10)
            attack = [
                scaled.mean(),
                scaled.std(),
                float(np.argmax(scaled)) / len(scaled),
                scaled[-1],
            ]
        else:
            attack = [0, 0, 0, 0]

        return np.concatenate([
            mfcc.mean(axis=1),                    # 20
            mfcc.std(axis=1),                     # 20
            [centroid.mean(), centroid.std()],    # 2
            [bandwidth.mean(), bandwidth.std()],  # 2
            [rolloff.mean()],                     # 1
            contrast.mean(axis=1),                # 5
            [flatness.mean()],                    # 1
            [zcr.mean()],                         # 1
            [rms.mean()],                         # 1
            attack,                               # 4
            [hit.duration],                       # 1
        ])

    table = np.array([describe(hit) for hit in hits], dtype=np.float32)
    column_mean = table.mean(axis=0)
    column_std = table.std(axis=0) + 1e-8  # avoid division by zero on constant columns
    return (table - column_mean) / column_std
416
+
417
+
418
def compute_clap_embeddings(hits: list[DrumHit], device: str = "cpu") -> np.ndarray:
    """Embed every hit with the ``laion/larger_clap_general`` CLAP model.

    Each hit is resampled to 48 kHz, zero-padded to at least half a second,
    and passed through the model's audio tower. Progress is printed every
    50 hits. Returns a float32 array with one embedding row per hit.
    """
    from transformers import ClapModel, ClapProcessor

    checkpoint = "laion/larger_clap_general"
    print(" Loading CLAP model (laion/larger_clap_general)...")
    model = ClapModel.from_pretrained(checkpoint).to(device)
    processor = ClapProcessor.from_pretrained(checkpoint)
    model.eval()

    clap_sr = 48000
    min_samples = int(0.5 * clap_sr)
    vectors = []

    for done, hit in enumerate(hits, start=1):
        resampled = librosa.resample(hit.audio, orig_sr=hit.sr, target_sr=clap_sr)
        shortfall = min_samples - len(resampled)
        if shortfall > 0:
            resampled = np.pad(resampled, (0, shortfall))

        batch = processor(audios=resampled, sampling_rate=clap_sr, return_tensors="pt")
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        with torch.no_grad():
            audio_embed = model.get_audio_features(**batch)
        vectors.append(audio_embed.squeeze().cpu().numpy())

        if done % 50 == 0:
            print(f" Embedded {done}/{len(hits)}")

    return np.array(vectors, dtype=np.float32)
447
+
448
+
449
def cluster_hits(hits: list[DrumHit],
                 embeddings: np.ndarray,
                 min_clusters: int = 2,
                 max_clusters: int = 30) -> list[DrumCluster]:
    """Cluster hits by embedding similarity, auto-selecting K via silhouette.

    Hits are first grouped by ``rough_label``. Within every group of two or
    more hits, KMeans is run for each candidate K and the K with the best
    silhouette score wins; if no candidate can be scored, the whole group
    becomes a single cluster. Every hit's ``cluster_id`` is set to the id
    of the cluster it ends up in.

    Args:
        hits: Hits to cluster; ``rough_label`` must already be set.
        embeddings: One row per hit, in the same order as ``hits``.
        min_clusters: Smallest K tried within a label group (never below 2).
        max_clusters: Largest K tried within a label group. K is further
            capped at 15, at roughly a third of the group size, and at the
            group size minus one (the largest K silhouette can score).

    Returns:
        All clusters, with ``cluster_id`` equal to their position in the
        list and labels of the form ``"<rough_label>_<n>"``.
    """
    from sklearn.cluster import KMeans
    from sklearn.metrics import silhouette_score

    print("\n" + "=" * 60)
    print("STAGE 4: Clustering similar drum hits")
    print("=" * 60)

    # First group by rough label, then sub-cluster within each group
    label_groups = defaultdict(list)
    for i, hit in enumerate(hits):
        label_groups[hit.rough_label].append(i)

    k_min = max(2, min_clusters)
    all_clusters = []

    def add_cluster(cluster_label, members):
        """Append a new cluster and stamp its id onto every member hit."""
        cluster = DrumCluster(
            cluster_id=len(all_clusters),
            label=cluster_label,
            hits=members,
        )
        for member in members:
            member.cluster_id = cluster.cluster_id
        all_clusters.append(cluster)

    for label, indices in label_groups.items():
        if len(indices) < 2:
            add_cluster(f"{label}_0", [hits[i] for i in indices])
            continue

        group_embeddings = embeddings[indices]
        k_max = min(max(2, len(indices) // 3), 15, max_clusters, len(indices) - 1)
        best_k, best_score, best_labels = 1, -1, None

        for k in range(k_min, k_max + 1):
            try:
                km = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300)
                sub_labels = km.fit_predict(group_embeddings)
                score = silhouette_score(group_embeddings, sub_labels)
            except ValueError:
                # This K cannot be fitted or scored for this group; skip it.
                continue
            if score > best_score:
                # Keep the labels so the winning K need not be refitted.
                best_k, best_score, best_labels = k, score, sub_labels

        if best_labels is None:
            best_labels = np.zeros(len(indices), dtype=int)

        for sub_id in range(int(best_labels.max()) + 1):
            members = [hits[i] for i, assigned in zip(indices, best_labels)
                       if assigned == sub_id]
            if not members:
                # Never emit an empty cluster; later stages index into hits.
                continue
            add_cluster(f"{label}_{sub_id}", members)

        print(f" {label}: {len(indices)} hits → {best_k} sub-clusters "
              f"(silhouette={best_score:.3f})")

    print(f"\n ✓ Total clusters: {len(all_clusters)}")
    for c in all_clusters:
        print(f" {c.label}: {c.count} hits")

    return all_clusters
521
+
522
+
523
+ # ─────────────────────────────────────────────────────────────────────────────
524
+ # Stage 5: Best representative selection
525
+ # ─────────────────────────────────────────────────────────────────────────────
526
+
527
def select_best_representatives(clusters: list[DrumCluster],
                                embeddings_dict: dict = None):
    """Pick the best representative hit of every cluster, in place.

    For clusters with more than one hit, each member is described by its 13
    mean MFCCs plus RMS energy, spectral centroid and duration. Features are
    standardized within the cluster, and a member's score is
    0.6 * closeness to the cluster mean + 0.4 * relative RMS energy. The
    index of the top-scoring member is stored in ``best_hit_idx``.
    Single-hit clusters get index 0. ``embeddings_dict`` is not used.
    """
    print("\n" + "=" * 60)
    print("STAGE 5: Selecting best representatives")
    print("=" * 60)

    for cluster in clusters:
        if cluster.count == 1:
            cluster.best_hit_idx = 0
            continue

        rows = [
            np.concatenate([
                librosa.feature.mfcc(y=member.audio, sr=member.sr, n_mfcc=13).mean(axis=1),
                [member.rms_energy, member.spectral_centroid, member.duration],
            ])
            for member in cluster.hits
        ]
        features = np.array(rows)

        # Standardize per feature so that no single scale dominates the distance.
        standardized = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)

        middle = standardized.mean(axis=0)
        distances = np.linalg.norm(standardized - middle, axis=1)
        closeness = 1.0 - (distances / (distances.max() + 1e-8))

        loudness = np.array([member.rms_energy for member in cluster.hits])
        loudness_scores = loudness / (loudness.max() + 1e-8)

        combined = 0.6 * closeness + 0.4 * loudness_scores
        winner = int(np.argmax(combined))
        cluster.best_hit_idx = winner

        print(f" {cluster.label}: selected hit {winner} "
              f"(score={combined[winner]:.3f}, "
              f"energy={cluster.hits[winner].rms_energy:.4f})")
568
+
569
+
570
+ # ─────────────────────────────────────────────────────────────────────────────
571
+ # Stage 6 (optional): Synthesize optimal sample from cluster
572
+ # ─────────────────────────────────────────────────────────────────────────────
573
+
574
def synthesize_from_cluster(cluster: "DrumCluster") -> np.ndarray:
    """Synthesize an 'optimal' sample by weighted-averaging cluster members.

    Every member is shifted so that its absolute peak lines up with the
    peak position of the first member, cut or zero-padded to the median
    member length, and peak-normalized. The members are then averaged in
    the time domain, with the member at ``best_hit_idx`` counted twice, and
    the result is scaled to a peak of 0.95.

    Args:
        cluster: Cluster to synthesize from; ``hits`` and ``best_hit_idx``
            are read.

    Returns:
        The synthesized waveform. For a single-hit cluster, a copy of that
        hit's audio.

    Raises:
        ValueError: If the cluster contains no hits.
    """
    if cluster.count == 0:
        raise ValueError("Cannot synthesize a sample from an empty cluster")
    if cluster.count == 1:
        return cluster.hits[0].audio.copy()

    target_len = int(np.median([len(h.audio) for h in cluster.hits]))
    # The first member's peak position is the alignment reference.
    peak_pos_target = int(np.argmax(np.abs(cluster.hits[0].audio)))

    aligned = []
    for hit in cluster.hits:
        audio = hit.audio
        # Shift to align peaks: pad in front, or drop leading samples.
        shift = peak_pos_target - int(np.argmax(np.abs(audio)))
        if shift > 0:
            audio = np.pad(audio, (shift, 0))
        elif shift < 0:
            audio = audio[-shift:]

        # Force exact length
        if len(audio) >= target_len:
            audio = audio[:target_len]
        else:
            audio = np.pad(audio, (0, target_len - len(audio)))

        # Normalize amplitude so loud members do not dominate the average
        peak = np.abs(audio).max()
        if peak > 0:
            audio = audio / peak

        aligned.append(audio)

    # Double weight for the best sample
    weights = np.array([2.0 if i == cluster.best_hit_idx else 1.0
                        for i in range(len(aligned))])

    synthesized = np.average(np.array(aligned), axis=0, weights=weights)

    peak = np.abs(synthesized).max()
    if peak > 0:
        synthesized = synthesized * (0.95 / peak)

    return synthesized
636
+
637
+
638
+ # ─────────────────────────────────────────────────────────────────────────────
639
+ # Main pipeline
640
+ # ─────────────────────────────────────────────────────────────────────────────
641
+
642
def run_pipeline(
    audio_path: str,
    output_dir: str = "./drum_samples",
    use_gpu: bool = True,
    use_clap: bool = False,
    separate_overlaps: bool = True,
    synthesize: bool = True,
    min_hit_dur: float = 0.03,
    max_hit_dur: float = 0.8,
    energy_threshold_db: float = -40.0,
    save_intermediates: bool = True,
):
    """Run the full drum sample extraction pipeline.

    Stages: drum stem extraction, onset detection, spectral classification
    (with optional overlap separation), embedding + clustering, best
    representative selection, optional synthesis, and export of the samples
    together with a ``manifest.json``.

    Args:
        audio_path: Input audio file.
        output_dir: Directory that receives all outputs (created if needed).
        use_gpu: Use CUDA when it is available; otherwise the CPU is used.
        use_clap: Cluster on CLAP embeddings instead of librosa features.
        separate_overlaps: Split hits with several strong bands into sub-hits.
        synthesize: Also write an averaged sample for clusters of 2+ hits.
        min_hit_dur: Passed to ``detect_onsets``.
        max_hit_dur: Passed to ``detect_onsets``.
        energy_threshold_db: Passed to ``detect_onsets``.
        save_intermediates: Also write the drum stem and every single hit.

    Returns:
        The list of clusters, or ``None`` when no drum hits were detected.
    """
    device = "cuda" if (use_gpu and torch.cuda.is_available()) else "cpu"
    print(f"Device: {device}")
    print(f"Input: {audio_path}")
    print(f"Output: {output_dir}")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # ── Stage 1: Extract drums ──
    drums_audio, drums_sr = extract_drums_demucs(audio_path, device=device)

    drums_path = output_dir / "drums_stem.wav"
    if save_intermediates:
        sf.write(str(drums_path), drums_audio, drums_sr, subtype='PCM_24')
        print(f" Saved drum stem: {drums_path}")

    # ── Stage 2: Detect onsets & segment ──
    hits = detect_onsets(
        drums_audio, drums_sr,
        min_hit_dur=min_hit_dur,
        max_hit_dur=max_hit_dur,
        energy_threshold_db=energy_threshold_db,
    )

    if not hits:
        print("\n⚠ No drum hits detected! Try lowering energy_threshold_db.")
        return

    # ── Stage 3: Classify & optionally separate overlaps ──
    hits = classify_and_separate_hits(hits, separate_overlaps=separate_overlaps)

    if save_intermediates:
        hits_dir = output_dir / "all_hits"
        hits_dir.mkdir(exist_ok=True)
        for hit in hits:
            hit_path = hits_dir / f"hit_{hit.index:04d}_{hit.rough_label}_{hit.onset_time:.3f}s.wav"
            hit.save(str(hit_path))

    # ── Stage 4: Embed & cluster ──
    print("\n" + "=" * 60)
    print("STAGE 4a: Computing embeddings")
    print("=" * 60)

    if use_clap:
        embeddings = compute_clap_embeddings(hits, device=device)
        print(f" ✓ CLAP embeddings: {embeddings.shape}")
    else:
        embeddings = compute_librosa_embeddings(hits)
        print(f" ✓ Librosa embeddings: {embeddings.shape}")

    for hit, embedding in zip(hits, embeddings):
        hit.embedding = embedding

    clusters = cluster_hits(hits, embeddings)

    # ── Stage 5: Select best representatives ──
    select_best_representatives(clusters)

    # ── Stage 6: Optional synthesis ──
    if synthesize:
        print("\n" + "=" * 60)
        print("STAGE 6: Synthesizing optimal samples")
        print("=" * 60)
        for cluster in clusters:
            if cluster.count >= 2:
                cluster.synthesized = synthesize_from_cluster(cluster)
                print(f" {cluster.label}: synthesized from {cluster.count} hits")

    # ── Export ──
    print("\n" + "=" * 60)
    print("EXPORT: Saving results")
    print("=" * 60)

    samples_dir = output_dir / "samples"
    samples_dir.mkdir(exist_ok=True)

    if synthesize:
        synth_dir = output_dir / "synthesized"
        synth_dir.mkdir(exist_ok=True)

    manifest = []
    for cluster in clusters:
        best = cluster.best_hit

        sample_path = samples_dir / f"{cluster.label}__best.wav"
        best.save(str(sample_path))

        entry = {
            "cluster_id": cluster.cluster_id,
            "label": cluster.label,
            "count": cluster.count,
            "best_sample": str(sample_path),
            "best_onset_time": best.onset_time,
            "best_duration": best.duration,
            "best_rms_energy": best.rms_energy,
            "best_spectral_centroid": best.spectral_centroid,
        }

        if synthesize and cluster.synthesized is not None:
            synth_path = synth_dir / f"{cluster.label}__synthesized.wav"
            sf.write(str(synth_path), cluster.synthesized, best.sr, subtype='PCM_24')
            entry["synthesized_sample"] = str(synth_path)

        manifest.append(entry)
        print(f" ✓ {cluster.label}: {cluster.count} hits → {sample_path.name}")

    # Save manifest (explicit encoding so the file does not depend on the locale)
    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    print(f"\n Manifest saved: {manifest_path}")

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f" Input: {audio_path}")
    if save_intermediates:
        # Only report the stem path when the file was actually written.
        print(f" Drum stem: {drums_path}")
    print(f" Total hits: {len(hits)}")
    print(f" Clusters: {len(clusters)}")
    print(f" Samples saved: {samples_dir}")
    if synthesize:
        print(f" Synthesized: {synth_dir}")
    print(f" Manifest: {manifest_path}")

    return clusters
783
+
784
+
785
+ # ─────────────────────────────────────────────────────────────────────────────
786
+ # CLI
787
+ # ─────────────────────────────────────────────────────────────────────────────
788
+
789
def main():
    """Parse the command line and run the extraction pipeline.

    Prints an error and exits with status 1 when the input file does not
    exist. The ``--no-*`` switches are inverted before being passed on to
    ``run_pipeline``.
    """
    usage_examples = (
        "\n"
        "Examples:\n"
        "%(prog)s song.mp3 -o ./my_samples\n"
        "%(prog)s drums.wav -o ./samples --no-gpu\n"
        "%(prog)s song.wav -o ./samples --clap # Use CLAP for semantic clustering\n"
        "%(prog)s song.wav -o ./samples --no-separate # Don't decompose overlaps\n"
        "%(prog)s song.wav -o ./samples --no-synthesize # Skip synthesis step\n"
    )
    parser = argparse.ArgumentParser(
        description="Extract individual drum samples from an audio file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    parser.add_argument("input", help="Input audio file (mp3, wav, flac, etc.)")
    parser.add_argument("-o", "--output-dir", default="./drum_samples",
                        help="Output directory (default: ./drum_samples)")

    # Boolean switches, in the order they appear in --help.
    switches = (
        ("--no-gpu", "Force CPU-only processing"),
        ("--clap", "Use CLAP embeddings for clustering (slower, more semantic)"),
        ("--no-separate", "Don't separate overlapping drum sounds"),
        ("--no-synthesize", "Don't synthesize optimal samples from clusters"),
        ("--no-intermediates",
         "Don't save intermediate files (drum stem, individual hits)"),
    )
    for flag, description in switches:
        parser.add_argument(flag, action="store_true", help=description)

    # Float-valued tuning options.
    tunables = (
        ("--min-hit-dur", 0.03, "Minimum hit duration in seconds (default: 0.03)"),
        ("--max-hit-dur", 0.8, "Maximum hit duration in seconds (default: 0.8)"),
        ("--energy-threshold", -40.0,
         "Energy threshold in dB for hit filtering (default: -40)"),
    )
    for flag, default, description in tunables:
        parser.add_argument(flag, type=float, default=default, help=description)

    args = parser.parse_args()

    if not os.path.exists(args.input):
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)

    options = {
        "audio_path": args.input,
        "output_dir": args.output_dir,
        "use_gpu": not args.no_gpu,
        "use_clap": args.clap,
        "separate_overlaps": not args.no_separate,
        "synthesize": not args.no_synthesize,
        "min_hit_dur": args.min_hit_dur,
        "max_hit_dur": args.max_hit_dur,
        "energy_threshold_db": args.energy_threshold,
        "save_intermediates": not args.no_intermediates,
    }
    run_pipeline(**options)


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()