AIvry committed
Commit 1832e16 · verified · 1 Parent(s): 226ddaf

Upload 12 files

Files changed (12)
  1. argshield.py +144 -0
  2. audio.py +61 -0
  3. config.py +33 -0
  4. distortions.py +339 -0
  5. engine.py +455 -0
  6. hf_readme.md +136 -0
  7. hf_requirements.txt +26 -0
  8. __init__.py +4 -0
  9. main.py +24 -0
  10. metrics.py +549 -0
  11. models.py +333 -0
  12. utils.py +231 -0
argshield.py ADDED
@@ -0,0 +1,144 @@
+ from __future__ import annotations
+ import argparse
+ import json
+ from pathlib import Path, PurePath
+ import importlib.util
+ 
+ from config import DEFAULT_ALPHA
+ from models import get_model_config
+ 
+ # Central table for default layers per model (kept identical to original table)
+ MODEL_DEFAULT_LAYER = {
+     "raw": None,
+     "wavlm": 24,
+     "wav2vec2": 24,
+     "hubert": 24,
+     "wavlm_base": 12,
+     "wav2vec2_base": 12,
+     "hubert_base": 12,
+     "wav2vec2_xlsr": 24,
+     "ast": 12,
+ }
+ 
+ def _read_manifest_json(path: Path):
+     text = Path(path).read_text(encoding="utf-8")
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError as e:
+         raise SystemExit(f"Manifest must be JSON. Failed to parse: {e}")
+ 
+ def _read_manifest_py(path: Path):
+     spec = importlib.util.spec_from_file_location("manifest_mod", str(path))
+     if spec is None or spec.loader is None:
+         raise SystemExit(f"Could not load Python manifest: {path}")
+     mod = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(mod)  # executes the .py file
+ 
+     if not hasattr(mod, "MANIFEST"):
+         raise SystemExit(f"Python manifest {path} must define a top-level variable MANIFEST")
+ 
+     manifest = mod.MANIFEST
+ 
+     def _to_str(p):
+         if isinstance(p, (Path, PurePath)):
+             return str(p)
+         if isinstance(p, str):
+             return p
+         raise TypeError(f"Path entry must be str or Path, got {type(p)}: {p}")
+ 
+     normalized = []
+     try:
+         for item in manifest:
+             mix_id = item["mixture_id"]
+             refs = [_to_str(x) for x in item["references"]]
+             systems = {}
+             for sys_name, lst in item["systems"].items():
+                 systems[sys_name] = [_to_str(x) for x in lst]
+             normalized.append({
+                 "mixture_id": mix_id,
+                 "references": refs,
+                 "systems": systems,
+             })
+     except (KeyError, TypeError, ValueError) as e:
+         raise SystemExit(f"Malformed MANIFEST in {path}: {e}")
+     return normalized
+ 
+ def _read_manifest(path: Path):
+     suffix = path.suffix.lower()
+     if suffix in {".py"}:
+         return _read_manifest_py(path)
+     elif suffix in {".json", ".txt"}:
+         return _read_manifest_json(path)
+     else:
+         raise SystemExit(f"Unsupported manifest type '{suffix}'. Use .py, .json, or .txt")
+ 
+ def _parse_args():
+     parser = argparse.ArgumentParser(
+         description="Run PS/PM experiment from a manifest file."
+     )
+     parser.add_argument(
+         "--manifest",
+         type=Path,
+         required=True,
+         help="Path to manifest (.py with MANIFEST or .json/.txt with JSON).",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         required=True,
+         help=("Embedding model. Choices: "
+               "raw, wavlm, wav2vec2, hubert, wavlm_base, wav2vec2_base, "
+               "hubert_base, wav2vec2_xlsr, ast"),
+     )
+     parser.add_argument(
+         "--layer",
+         type=int,
+         default=None,
+         help="Optional layer (validated per model). Omit to use the model default.",
+     )
+     parser.add_argument(
+         "--alpha",
+         type=float,
+         default=None,
+         help="Optional diffusion-maps alpha in [0,1] (default: config DEFAULT_ALPHA).",
+     )
+     parser.add_argument("--verbose", action="store_true", help="Verbose logging.")
+     parser.add_argument("--max-gpus", type=int, default=None, help="Limit GPUs to use (must be >= 0).")
+     return parser.parse_args()
+ 
+ def _validate_and_resolve(model: str, layer_opt: int | None, alpha_opt: float | None):
+     allowed_models = set(get_model_config(0).keys())
+     if model not in allowed_models:
+         raise SystemExit(f"Unknown --model '{model}'. Allowed: {sorted(allowed_models)}")
+ 
+     max_layer = MODEL_DEFAULT_LAYER.get(model)
+     if model == "raw":
+         layer_final = 0 if layer_opt is None else int(layer_opt)
+     else:
+         if layer_opt is None:
+             if max_layer is None:
+                 raise SystemExit(f"--layer must be provided for model '{model}'.")
+             layer_final = max_layer
+         else:
+             layer_final = int(layer_opt)
+             if max_layer is not None and not (0 <= layer_final <= max_layer):
+                 raise SystemExit(
+                     f"--layer {layer_final} is out of range for '{model}'. "
+                     f"Expected 0..{max_layer} (or omit to use default {max_layer})."
+                 )
+ 
+     alpha_final = DEFAULT_ALPHA if alpha_opt is None else float(alpha_opt)
+     if not (0.0 <= alpha_final <= 1.0):
+         raise SystemExit("--alpha must be in [0, 1].")
+     return layer_final, alpha_final
+ 
+ def _validate_gpus(max_gpus_opt):
+     if max_gpus_opt is None:
+         return None
+     try:
+         mg = int(max_gpus_opt)
+     except Exception:
+         raise SystemExit("--max-gpus must be an integer >= 0.")
+     if mg < 0:
+         raise SystemExit("--max-gpus must be >= 0.")
+     return mg
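
For reference, a minimal Python manifest satisfying the contract `_read_manifest_py` enforces could look like the sketch below. File names, the mixture id, and the system name are placeholders; the only requirement is a top-level `MANIFEST` list whose entries carry `mixture_id`, `references`, and `systems` (path entries may be `str` or `pathlib.Path`):

```python
# manifest.py -- minimal sketch of the MANIFEST contract (placeholder paths)
from pathlib import Path

MANIFEST = [
    {
        "mixture_id": "mix_001",
        "references": [Path("refs/speaker1.wav"), Path("refs/speaker2.wav")],
        "systems": {
            "my_separator": ["outs/sep1.wav", "outs/sep2.wav"],
        },
    },
]
```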
audio.py ADDED
@@ -0,0 +1,61 @@
+ import librosa
+ import numpy as np
+ import pyloudnorm as pyln
+ import torch
+ 
+ from config import SILENCE_RATIO, SR
+ from utils import hungarian, safe_corr_np
+ import warnings
+ warnings.filterwarnings("ignore", message="Possible clipped samples in output.")
+ 
+ 
+ def loudness_normalize(wav, sr=SR, target_lufs=-23.0):
+     meter = pyln.Meter(sr)
+     loudness = meter.integrated_loudness(wav)
+     normalized_wav = pyln.normalize.loudness(wav, loudness, target_lufs)
+     peak = np.max(np.abs(normalized_wav))
+     if peak > 1.0:
+         normalized_wav = normalized_wav / max(peak, 1e-12)
+     return np.clip(normalized_wav, -1.0, 1.0)
+ 
+ 
+ def frame_rms_torch(sig, win, hop):
+     dev = sig.device
+     frames = sig.unfold(0, win, hop)
+     if frames.size(0) and (frames.size(0) - 1) * hop == sig.numel() - win:
+         frames = frames[:-1]
+     rms = torch.sqrt((frames**2).mean(1) + 1e-12)
+     return rms.to(dev)
+ 
+ 
+ def make_union_voiced_mask(refs_tensors, win, hop):
+     device = refs_tensors[0].device
+     rms_vecs = [frame_rms_torch(r, win, hop) for r in refs_tensors]
+     lengths = [v.numel() for v in rms_vecs]
+     L_max = max(lengths)
+     silent_union = torch.zeros(L_max, dtype=torch.bool, device=device)
+     for idx, (rms, L) in enumerate(zip(rms_vecs, lengths)):
+         thr = SILENCE_RATIO * torch.sqrt((refs_tensors[idx] ** 2).mean())
+         sil = rms <= thr
+         silent_union[:L] |= sil
+     return ~silent_union
+ 
+ 
+ def assign_outputs_to_refs_by_corr(ref_paths, out_paths):
+     if not out_paths:
+         return [None] * len(ref_paths)
+     refs = [loudness_normalize(librosa.load(str(p), sr=SR)[0]) for p in ref_paths]
+     outs = [loudness_normalize(librosa.load(str(p), sr=SR)[0]) for p in out_paths]
+     n, m = len(refs), len(outs)
+     K = max(n, m)
+     C = np.ones((K, K), dtype=np.float64)
+     for i in range(n):
+         for j in range(m):
+             r = safe_corr_np(refs[i], outs[j])
+             C[i, j] = 1.0 - (r + 1.0) * 0.5  # lower = better
+     ri, cj = hungarian(C)
+     mapping = [None] * n
+     for i, j in zip(ri, cj):
+         if i < n and j < m:
+             mapping[i] = int(j)
+     return mapping
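
Since system outputs can arrive in arbitrary order, `assign_outputs_to_refs_by_corr` pairs each reference with its best-correlated output via the Hungarian assignment. A minimal usage sketch (paths are placeholders; `mapping[i]` is the index of the output matched to reference `i`, or `None` if nothing was assigned):

```python
# Sketch: align arbitrarily ordered separation outputs with references.
from audio import assign_outputs_to_refs_by_corr

ref_paths = ["refs/speaker1.wav", "refs/speaker2.wav"]
out_paths = ["outs/sep_b.wav", "outs/sep_a.wav"]

mapping = assign_outputs_to_refs_by_corr(ref_paths, out_paths)
for i, j in enumerate(mapping):
    if j is not None:
        print(f"{ref_paths[i]} <- {out_paths[j]}")
```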
config.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ import torch
+ 
+ import warnings
+ warnings.filterwarnings(
+     "ignore",
+     category=UserWarning,
+     message=r"^expandable_segments not supported on this platform"
+ )
+ 
+ SR = 16_000
+ RESULTS_ROOT = "results"
+ BATCH_SIZE = 2
+ ENERGY_WIN_MS = 20
+ ENERGY_HOP_MS = 20
+ SILENCE_RATIO = 0.1
+ EPS = 1e-4
+ COV_TOL = 1e-6
+ 
+ DEFAULT_LAYER = 2
+ DEFAULT_ADD_CI = True
+ DEFAULT_DELTA_CI = 0.05
+ DEFAULT_ALPHA = 1.0
+ 
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.6"
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+ 
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.deterministic = False
+ torch.backends.cudnn.enabled = True
+ 
+ if torch.cuda.is_available():
+     torch.cuda.set_per_process_memory_fraction(0.8)
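
The energy-windowing constants are given in milliseconds and converted to sample counts downstream (as engine.py does further down). As a quick sanity check, at the 16 kHz rate both the window and hop come out to 320 samples, i.e. non-overlapping frames:

```python
# Sanity check: millisecond constants -> sample counts at SR = 16 kHz.
SR = 16_000
ENERGY_WIN_MS = 20
ENERGY_HOP_MS = 20

win = int(ENERGY_WIN_MS * SR / 1000)  # 320 samples
hop = int(ENERGY_HOP_MS * SR / 1000)  # 320 samples -> non-overlapping frames
print(win, hop)
```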
distortions.py ADDED
@@ -0,0 +1,339 @@
+ import librosa
+ import numpy as np
+ from numpy.fft import irfft, rfft, rfftfreq
+ from scipy.signal import butter, filtfilt, lfilter
+ 
+ from config import ENERGY_WIN_MS, EPS, SR
+ 
+ 
+ def sig_stats(x):
+     A_pk = max(np.max(np.abs(x)), EPS)
+     A_rms = max(np.sqrt(np.mean(x**2)), EPS)
+     A_95 = max(np.percentile(np.abs(x), 95), EPS)
+     return A_pk, A_rms, A_95
+ 
+ 
+ def frame_distortions(
+     frame,
+     sr,
+     distortion_keys,
+     notch_freqs=None,
+     low_cutoffs=None,
+     high_cutoffs=None,
+     frame_start=0,
+ ):
+     notch_freqs = [] if notch_freqs is None else notch_freqs
+     low_cutoffs = [] if low_cutoffs is None else low_cutoffs
+     high_cutoffs = [] if high_cutoffs is None else high_cutoffs
+     distortions = {}
+ 
+     A_pk, A_rms, A_95 = sig_stats(frame)
+     frame_len = len(frame)
+     X = rfft(frame)
+     freqs = rfftfreq(frame_len, 1 / sr)
+     t = np.arange(frame_len) / sr
+ 
+     if ("notch" in distortion_keys) or distortion_keys == "all":
+         bw = 60.0
+         for f0 in notch_freqs:
+             Y = X.copy()
+             band = (freqs > f0 - bw) & (freqs < f0 + bw)
+             Y[band] = 0
+             distortions[f"Notch_{int(round(f0))}Hz"] = irfft(Y, n=len(frame))
+ 
+     if ("comb" in distortion_keys) or distortion_keys == "all":
+         # NOTE: zip truncates to the shorter list, so the 15 ms delay is never produced.
+         for d_ms, decay in zip([2.5, 5, 7.5, 10, 12.5, 15], [0.4, 0.5, 0.6, 0.7, 0.9]):
+             D = int(sr * d_ms / 1000)
+             if D >= frame_len:
+                 continue
+             out = frame.copy()
+             out[:-D] += decay * frame[D:]
+             distortions[f"Comb_{int(d_ms)}ms"] = out
+ 
+     if ("tremolo" in distortion_keys) or distortion_keys == "all":
+         depth = 1.0
+         t_centre = (frame_start + 0.5 * len(frame)) / sr
+         for r_hz in [1, 2, 4, 6]:
+             mod = (1 - depth) + depth * 0.5 * (1 + np.sin(2 * np.pi * r_hz * t_centre))
+             distortions[f"Tremolo_{r_hz}Hz"] = frame * mod
+ 
+     if ("noise" in distortion_keys) or distortion_keys == "all":
+         nyq = sr / 2
+         low_norm = 20 / nyq
+         high_freq = min(20_000, 0.45 * sr)
+         high_norm = min(high_freq / nyq, 0.99)
+         b_band, a_band = butter(5, [low_norm, high_norm], btype="band")
+ 
+         def add_noise(sig, snr_db, color="white"):
+             nl_target = 10 ** (snr_db / 10)
+             n = np.random.randn(len(sig))
+             if color == "pink":
+                 n = np.cumsum(n)
+                 n /= max(np.max(np.abs(n)), 1e-12)
+             elif color == "brown":
+                 n = np.cumsum(np.cumsum(n))
+                 n /= max(np.max(np.abs(n)), 1e-12)
+             n = lfilter(b_band, a_band, n)
+             rms_sig = np.sqrt(np.mean(sig**2))
+             rms_n = np.sqrt(np.mean(n**2)) + 1e-12
+             noise_rms = rms_sig / np.sqrt(nl_target)
+             noise_rms = max(noise_rms, rms_sig / np.sqrt(10 ** (15 / 10)))
+             n *= noise_rms / rms_n
+             return sig + n
+ 
+         for snr in [-15, -10, -5, 0, 5, 10, 15, 20, 25]:
+             for clr in ["white", "pink", "brown"]:
+                 if (snr in [-15, -10, -5]) and (clr == "white"):
+                     continue
+                 distortions[f"{clr.capitalize()}Noise_{snr}dB"] = add_noise(
+                     frame, snr, clr
+                 )
+ 
+     if ("harmonic" in distortion_keys) or distortion_keys == "all":
+         for f_h, rel_amp in zip([100, 500, 1000, 4000], [0.4, 0.6, 0.8, 1.0]):
+             tone = (rel_amp * A_rms) * np.sin(2 * np.pi * f_h * t)
+             distortions[f"Harmonic_{f_h}Hz"] = frame + tone
+ 
+     if ("reverb" in distortion_keys) or distortion_keys == "all":
+         for tail_ms, decay in zip([50, 100, 200, 400], [0.3, 0.5, 0.7, 0.9]):
+             L = int(sr * tail_ms / 1000)
+             if L >= frame_len:
+                 continue
+             irv = np.exp(-np.linspace(0, 6, L)) * decay
+             reverbed = np.convolve(frame, irv)[:frame_len]
+             distortions[f"Reverb_{tail_ms}ms"] = reverbed
+ 
+     if ("noisegate" in distortion_keys) or distortion_keys == "all":
+         for pct in [0.05, 0.10, 0.20, 0.40]:
+             thr = pct * A_95
+             g = frame.copy()
+             g[np.abs(g) < thr] = 0
+             distortions[f"NoiseGate_{int(pct * 100)}pct"] = g
+ 
+     if ("pitch_shift" in distortion_keys) or distortion_keys == "all":
+         n_fft = min(2048, frame_len // 2)
+         for shift in [-4, -2, 2, 4]:
+             y = librosa.effects.pitch_shift(frame, sr=sr, n_steps=shift, n_fft=n_fft)
+             distortions[f"PitchShift_{shift}st"] = y[:frame_len]
+ 
+     if ("lowpass" in distortion_keys) or distortion_keys == "all":
+         for fc in low_cutoffs:
+             if fc >= sr / 2 * 0.99:
+                 continue
+             b, a = butter(6, fc / (sr / 2), btype="low")
+             distortions[f"Lowpass_{fc}Hz"] = filtfilt(b, a, frame)
+ 
+     if ("highpass" in distortion_keys) or distortion_keys == "all":
+         for fc in high_cutoffs:
+             if fc <= 20:
+                 continue
+             b, a = butter(6, fc / (sr / 2), btype="high")
+             distortions[f"Highpass_{fc}Hz"] = filtfilt(b, a, frame)
+ 
+     if ("echo" in distortion_keys) or distortion_keys == "all":
+         for delay_ms, amp in zip([50, 100, 150], [0.4, 0.5, 0.7]):
+             D = int(sr * delay_ms / 1000)
+             if D >= frame_len:
+                 continue
+             echo = np.pad(frame, (D, 0), "constant")[:-D] * amp
+             distortions[f"Echo_{delay_ms}ms"] = frame + echo
+ 
+     if ("clipping" in distortion_keys) or distortion_keys == "all":
+         for frac in [0.70, 0.50, 0.30]:
+             thr = frac * A_95
+             distortions[f"Clipping_{frac:.2f}p95"] = np.clip(frame, -thr, thr)
+ 
+     if ("vibrato" in distortion_keys) or distortion_keys == "all":
+         n_fft = min(2048, frame_len // 2)
+         base_depth = 0.03 * (A_rms / A_pk)
+         for rate_hz, scale in zip([3, 5, 7], [1.0, 1.3, 1.6]):
+             depth = np.clip(base_depth * scale, 0.01, 0.05)
+             y = librosa.effects.time_stretch(frame, rate=1 + depth, n_fft=n_fft)
+             distortions[f"Vibrato_{rate_hz}Hz"] = librosa.util.fix_length(
+                 y, size=frame_len
+             )
+ 
+     return distortions
+ 
+ 
+ def apply_adv_distortions(ref, distortion_keys, sr=SR):
+     frame_len = int(ENERGY_WIN_MS * sr / 1000)
+     n_frames = int(np.ceil(len(ref) / frame_len))
+     pad_len = n_frames * frame_len - len(ref)
+     ref_padded = (
+         np.concatenate([ref, np.zeros(pad_len, dtype=ref.dtype)]) if pad_len else ref
+     )
+ 
+     X_full = rfft(ref_padded)
+     freqs_f = rfftfreq(len(ref_padded), 1 / sr)
+     mag_full = np.abs(X_full)
+ 
+     valid = (freqs_f > 80) & (freqs_f < 0.45 * sr)
+     cand_indices = np.argsort(mag_full[valid])[-60:]
+     cand_freqs = freqs_f[valid][cand_indices]
+     cand_freqs = cand_freqs[np.argsort(mag_full[valid][cand_indices])[::-1]]
+ 
+     selected_notch_freqs = []
+     for f0 in cand_freqs:
+         if all(abs(f0 - f_sel) > 300 for f_sel in selected_notch_freqs):
+             selected_notch_freqs.append(float(f0))
+         if len(selected_notch_freqs) >= 20:
+             break
+ 
+     mag2 = np.abs(X_full) ** 2
+     total_p = mag2.sum()
+     cum_low = np.cumsum(mag2)
+     q_low = [0.50, 0.70, 0.85, 0.95]
+     lowpass_cutoffs = []
+     for q in q_low:
+         idx = np.searchsorted(cum_low, q * total_p)
+         f_c = float(freqs_f[idx])
+         lowpass_cutoffs.append(round(f_c / 100.0) * 100)
+ 
+     cum_high = np.cumsum(mag2[::-1])
+     q_high = [0.05, 0.15, 0.30, 0.50]
+     highpass_cutoffs = []
+     for q in q_high:
+         idx = np.searchsorted(cum_high, q * total_p)
+         f_c = float(freqs_f[-1 - idx])
+         highpass_cutoffs.append(round(f_c / 100.0) * 100)
+ 
+     lowpass_cutoffs = sorted(set(lowpass_cutoffs))
+     highpass_cutoffs = sorted(set(highpass_cutoffs))
+ 
+     out = {}
+     for f in range(n_frames):
+         start, end = f * frame_len, (f + 1) * frame_len
+         frame = ref_padded[start:end]
+         frame_dists = frame_distortions(
+             frame,
+             sr,
+             distortion_keys,
+             notch_freqs=selected_notch_freqs,
+             low_cutoffs=lowpass_cutoffs,
+             high_cutoffs=highpass_cutoffs,
+             frame_start=start,
+         )
+         for lbl, sig in frame_dists.items():
+             if lbl not in out:
+                 out[lbl] = np.zeros_like(ref_padded)
+             out[lbl][start:end] = sig
+ 
+     return list(out.values())
+ 
+ 
+ def apply_distortions(ref, distortion_keys, sr=SR):
+     distortions = {}
+     X = rfft(ref)
+     freqs = rfftfreq(len(ref), 1 / sr)
+     t = np.arange(len(ref)) / sr
+ 
+     if ("notch" in distortion_keys) or distortion_keys == "all":
+         for c in [500, 1000, 2000, 4000, 8000]:
+             Y = X.copy()
+             Y[(freqs > c - 50) & (freqs < c + 50)] = 0
+             distortions[f"Notch_{c}Hz"] = irfft(Y, n=len(ref))
+ 
+     if ("comb" in distortion_keys) or distortion_keys == "all":
+         # NOTE: zip truncates to the shorter list, so the 15 ms delay is never produced.
+         for d, decay in zip([2.5, 5, 7.5, 10, 12.5, 15], [0.4, 0.5, 0.6, 0.7, 0.9]):
+             D = int(sr * d / 1000)
+             if D >= len(ref):
+                 continue
+             cpy = ref.copy()
+             if len(ref) > D:
+                 cpy[:-D] += decay * ref[D:]
+             distortions[f"Comb_{int(d)}ms"] = cpy
+ 
+     if ("tremolo" in distortion_keys) or distortion_keys == "all":
+         for r, depth in zip([1, 2, 4, 6], [0.3, 0.5, 0.8, 1.0]):
+             mod = (1 - depth) + depth * 0.5 * (1 + np.sin(2 * np.pi * r * t))
+             distortions[f"Tremolo_{r}Hz"] = ref * mod
+ 
+     if ("noise" in distortion_keys) or distortion_keys == "all":
+ 
+         def add_noise(signal, snr_db, color):
+             rms = np.sqrt(np.mean(signal**2))
+             nl = 10 ** (snr_db / 10)
+             noise_rms = rms / np.sqrt(nl)
+             n = np.random.randn(len(signal))
+             if color == "pink":
+                 n = np.cumsum(n)
+                 n /= max(np.max(np.abs(n)), 1e-12)
+             elif color == "brown":
+                 n = np.cumsum(np.cumsum(n))
+                 n /= max(np.max(np.abs(n)), 1e-12)
+             return signal + noise_rms * n
+ 
+         for snr in [-15, -10, -5, 0, 5, 10, 15, 20, 25]:
+             for clr in ["white", "pink", "brown"]:
+                 if snr in [-15, -10, -5] and clr in ["white"]:
+                     continue
+                 distortions[f"{clr.capitalize()}Noise_{snr}dB"] = add_noise(
+                     ref, snr, clr
+                 )
+ 
+     if ("harmonic" in distortion_keys) or distortion_keys == "all":
+         for f_h, amp in zip([100, 500, 1000, 4000], [0.02, 0.03, 0.05, 0.08]):
+             tone = amp * np.sin(2 * np.pi * f_h * t)
+             distortions[f"Harmonic_{f_h}Hz"] = ref + tone
+ 
+     if ("reverb" in distortion_keys) or distortion_keys == "all":
+         # NOTE: zip truncates to four pairs; the 1.1 decay is never used.
+         for tail_ms, decay in zip([5, 10, 15, 20], [0.3, 0.5, 0.7, 0.9, 1.1]):
+             L = int(sr * tail_ms / 1000)
+             if L >= len(ref):
+                 continue
+             irv = np.exp(-np.linspace(0, 3, L)) * decay
+             reverbed = np.convolve(ref, irv)[: len(ref)]
+             distortions[f"Reverb_{tail_ms}ms"] = reverbed
+ 
+     if ("noisegate" in distortion_keys) or distortion_keys == "all":
+         for thr in [0.005, 0.01, 0.02, 0.04]:
+             g = ref.copy()
+             g[np.abs(g) < thr] = 0
+             distortions[f"NoiseGate_{thr}"] = g
+ 
+     if ("pitch_shift" in distortion_keys) or distortion_keys == "all":
+         n_fft = min(2048, len(ref) // 2)
+         for shift in [-4, -2, 2, 4]:
+             shifted = librosa.effects.pitch_shift(
+                 y=ref, sr=sr, n_steps=shift, n_fft=n_fft
+             )
+             distortions[f"PitchShift_{shift}st"] = shifted[: len(ref)]
+ 
+     if ("lowpass" in distortion_keys) or distortion_keys == "all":
+         for freq in [2000, 3000, 4000, 6000]:
+             if freq >= (sr / 2):
+                 continue
+             b, a = butter(4, freq / (sr / 2), "low")
+             distortions[f"Lowpass_{freq}Hz"] = filtfilt(b, a, ref)
+ 
+     if ("highpass" in distortion_keys) or distortion_keys == "all":
+         for freq in [100, 300, 500, 800]:
+             if freq >= (sr / 2):
+                 continue
+             b, a = butter(4, freq / (sr / 2), "high")
+             distortions[f"Highpass_{freq}Hz"] = filtfilt(b, a, ref)
+ 
+     if ("echo" in distortion_keys) or distortion_keys == "all":
+         # NOTE: zip truncates to three pairs; the 20 ms delay is never used.
+         for delay_ms, amp in zip([5, 10, 15, 20], [0.3, 0.5, 0.7]):
+             delay = int(sr * delay_ms / 1000)
+             if delay >= len(ref):
+                 continue
+             echo = np.pad(ref, (delay, 0), "constant")[:-delay] * amp
+             distortions[f"Echo_{delay_ms}ms"] = ref + echo
+ 
+     if ("clipping" in distortion_keys) or distortion_keys == "all":
+         for thr in [0.3, 0.5, 0.7]:
+             distortions[f"Clipping_{thr}"] = np.clip(ref, -thr, thr)
+ 
+     if ("vibrato" in distortion_keys) or distortion_keys == "all":
+         # NOTE: the sine averages to ~0 over the signal, so the stretch rate stays ~1 (near-identity).
+         for rate, depth in zip([3, 5, 7], [0.001, 0.002, 0.003]):
+             vibrato = np.sin(2 * np.pi * rate * t) * depth
+             vibrato_signal = librosa.effects.time_stretch(
+                 ref, rate=1 + float(vibrato.mean()), n_fft=min(2048, len(ref) // 2)
+             )
+             distortions[f"Vibrato_{rate}Hz"] = librosa.util.fix_length(
+                 vibrato_signal, size=len(ref)
+             )
+ 
+     return list(distortions.values())
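
A minimal sketch of generating the PS distortion bank for one signal (the input here is a synthetic placeholder tone; in the pipeline it is a loudness-normalized 16 kHz reference waveform):

```python
# Sketch: build the "all" distortion bank for a synthetic reference.
import numpy as np
from distortions import apply_distortions

sr = 16_000
t = np.arange(sr) / sr                     # one second of audio
ref = 0.1 * np.sin(2 * np.pi * 440 * t)    # placeholder 440 Hz tone

bank = apply_distortions(ref, "all", sr=sr)
print(f"{len(bank)} distorted variants of {len(ref)} samples each")
```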
engine.py ADDED
@@ -0,0 +1,455 @@
+ import json
+ import random
+ from concurrent.futures import ThreadPoolExecutor
+ from datetime import datetime
+ import librosa
+ import pandas as pd
+ from audio import (
+     assign_outputs_to_refs_by_corr,
+     loudness_normalize,
+     make_union_voiced_mask,
+ )
+ from config import *
+ from distortions import apply_adv_distortions, apply_distortions
+ from metrics import (
+     compute_pm,
+     compute_ps,
+     diffusion_map_torch,
+     pm_ci_components_full,
+     ps_ci_components_full,
+ )
+ from models import embed_batch, load_model
+ from utils import *
+ 
+ 
+ def compute_mapss_measures(
+     models,
+     mixtures,
+     *,
+     systems=None,
+     algos=None,
+     experiment_id=None,
+     layer=DEFAULT_LAYER,
+     add_ci=DEFAULT_ADD_CI,
+     alpha=DEFAULT_ALPHA,
+     seed=42,
+     on_missing="skip",
+     verbose=False,
+     max_gpus=None,
+ ):
+     gpu_distributor = GPUWorkDistributor(max_gpus)
+     ngpu = get_gpu_count(max_gpus)
+ 
+     if on_missing not in {"skip", "error"}:
+         raise ValueError("on_missing must be 'skip' or 'error'.")
+ 
+     torch.manual_seed(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+ 
+     canon_mix = canonicalize_mixtures(mixtures, systems=systems)
+ 
+     mixture_entries = []
+     for m in canon_mix:
+         entries = []
+         for i, refp in enumerate(m.refs):
+             sid = m.speaker_ids[i]
+             entries.append(
+                 {"id": sid, "ref": Path(refp), "mixture": m.mixture_id, "outs": {}}
+             )
+         mixture_entries.append(entries)
+ 
+     for m, mix_entries in zip(canon_mix, mixture_entries):
+         for algo, out_list in (m.systems or {}).items():
+             mapping = assign_outputs_to_refs_by_corr(
+                 [e["ref"] for e in mix_entries], out_list
+             )
+             for idx, e in enumerate(mix_entries):
+                 j = mapping[idx]
+                 if j is not None:
+                     e["outs"][algo] = out_list[j]
+ 
+     if algos is None:
+         algos_to_run = sorted(
+             {algo for m in canon_mix for algo in (m.systems or {}).keys()}
+         )
+     else:
+         algos_to_run = list(algos)
+ 
+     exp_id = experiment_id or datetime.now().strftime("%Y%m%d_%H%M%S")
+     exp_root = os.path.join(RESULTS_ROOT, f"experiment_{exp_id}")
+     os.makedirs(exp_root, exist_ok=True)
+ 
+     params = {
+         "models": models,
+         "layer": layer,
+         "add_ci": add_ci,
+         "alpha": alpha,
+         "seed": seed,
+         "batch_size": BATCH_SIZE,
+         "ngpu": ngpu,
+         "max_gpus": max_gpus,
+     }
+ 
+     with open(os.path.join(exp_root, "params.json"), "w") as f:
+         json.dump(params, f, indent=2)
+ 
+     canon_struct = [
+         {
+             "mixture_id": m.mixture_id,
+             "references": [str(p) for p in m.refs],
+             "systems": {
+                 a: [str(p) for p in outs] for a, outs in (m.systems or {}).items()
+             },
+             "speaker_ids": m.speaker_ids,
+         }
+         for m in canon_mix
+     ]
+ 
+     with open(os.path.join(exp_root, "manifest_canonical.json"), "w") as f:
+         json.dump(canon_struct, f, indent=2)
+ 
+     print(f"Starting experiment {exp_id} with {ngpu} GPUs")
+     print(f"Results will be saved to: {exp_root}")
+ 
+     clear_gpu_memory()
+     get_gpu_memory_info(verbose)
+ 
+     flat_entries = [e for mix in mixture_entries for e in mix]
+     all_refs = {}
+ 
+     if verbose:
+         print("Loading reference signals...")
+     for e in flat_entries:
+         wav, _ = librosa.load(str(e["ref"]), sr=SR)
+         all_refs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
+ 
+     if verbose:
+         print("Computing voiced masks...")
+ 
+     win = int(ENERGY_WIN_MS * SR / 1000)
+     hop = int(ENERGY_HOP_MS * SR / 1000)
+     voiced_mask_mix = []
+ 
+     for i, mix in enumerate(mixture_entries):
+         if verbose:
+             print(f"  Computing mask for mixture {i + 1}/{len(mixture_entries)}")
+ 
+         if ngpu > 0:
+             with torch.cuda.device(0):
+                 refs_for_mix = [all_refs[e["id"]].cuda() for e in mix]
+                 mask = make_union_voiced_mask(refs_for_mix, win, hop)
+                 voiced_mask_mix.append(mask.cpu())
+                 # Explicitly delete GPU tensors
+                 for ref in refs_for_mix:
+                     del ref
+                 torch.cuda.empty_cache()
+         else:
+             refs_for_mix = [all_refs[e["id"]].cpu() for e in mix]
+             mask = make_union_voiced_mask(refs_for_mix, win, hop)
+             voiced_mask_mix.append(mask.cpu())
+ 
+     ordered_speakers = [e["id"] for e in flat_entries]
+ 
+     for algo_idx, algo in enumerate(algos_to_run):
+         if verbose:
+             print(f"\nProcessing Algorithm {algo_idx + 1}/{len(algos_to_run)}: {algo}")
+ 
+         algo_dir = os.path.join(exp_root, algo)
+         os.makedirs(algo_dir, exist_ok=True)
+ 
+         all_outs = {}
+         missing = []
+ 
+         for mix_idx, mix in enumerate(mixture_entries):
+             for e in mix:
+                 assigned_path = e.get("outs", {}).get(algo)
+                 if assigned_path is None:
+                     missing.append((e["mixture"], e["id"]))
+                     continue
+ 
+                 wav, _ = librosa.load(str(assigned_path), sr=SR)
+                 all_outs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
+ 
+         if missing:
+             msg = f"[{algo}] missing outputs for {len(missing)} speaker(s)"
+             if on_missing == "error":
+                 raise FileNotFoundError(msg)
+             else:
+                 if verbose:
+                     warnings.warn(msg + " Skipping those speakers.")
+ 
+         if not all_outs:
+             if verbose:
+                 warnings.warn(f"[{algo}] No outputs provided. Skipping algorithm.")
+             continue
+ 
+         ps_ts = {m: {s: [] for s in ordered_speakers} for m in models}
+         pm_ts = {m: {s: [] for s in ordered_speakers} for m in models}
+         ps_bias_ts = {m: {s: [] for s in ordered_speakers} for m in models}
+         ps_prob_ts = {m: {s: [] for s in ordered_speakers} for m in models}
+         pm_bias_ts = {m: {s: [] for s in ordered_speakers} for m in models}
+         pm_prob_ts = {m: {s: [] for s in ordered_speakers} for m in models}
+ 
+         for model_idx, mname in enumerate(models):
+             if verbose:
+                 print(f"  Processing Model {model_idx + 1}/{len(models)}: {mname}")
+ 
+             for metric_type in ["PS", "PM"]:
+                 clear_gpu_memory()
+                 gc.collect()
+ 
+                 model_wrapper, layer_eff = load_model(mname, layer, max_gpus)
+                 get_gpu_memory_info(verbose)
+ 
+                 embs_by_mix = {}
+                 labels_by_mix = {}
+ 
+                 for k, mix in enumerate(mixture_entries):
+                     speakers_this_mix = [e for e in mix if e["id"] in all_outs]
+                     if not speakers_this_mix:
+                         continue
+ 
+                     if verbose:
+                         print(
+                             f"Processing mixture {k + 1}/{len(mixture_entries)} for {metric_type}"
+                         )
+ 
+                     all_signals_mix = []
+                     all_masks_mix = []
+                     all_labels_mix = []
+ 
+                     for e in speakers_this_mix:
+                         s = e["id"]
+ 
+                         if metric_type == "PS":
+                             dists = [
+                                 loudness_normalize(d)
+                                 for d in apply_distortions(all_refs[s].numpy(), "all")
+                             ]
+                         else:
+                             dists = [
+                                 loudness_normalize(d)
+                                 for d in apply_adv_distortions(
+                                     all_refs[s].numpy(), "all"
+                                 )
+                             ]
+ 
+                         sigs = [all_refs[s].numpy(), all_outs[s].numpy()] + dists
+                         lbls = ["ref", "out"] + [f"d{i}" for i in range(len(dists))]
+ 
+                         masks = [voiced_mask_mix[k]] * len(sigs)
+                         all_signals_mix.extend(sigs)
+                         all_masks_mix.extend(masks)
+                         all_labels_mix.extend([f"{s}-{l}" for l in lbls])
+ 
+                     try:
+                         # Process in smaller batches
+                         batch_size = min(2, BATCH_SIZE)
+                         embeddings_list = []
+ 
+                         for i in range(0, len(all_signals_mix), batch_size):
+                             batch_sigs = all_signals_mix[i:i + batch_size]
+                             batch_masks = all_masks_mix[i:i + batch_size]
+ 
+                             batch_embs = embed_batch(
+                                 batch_sigs,
+                                 batch_masks,
+                                 model_wrapper,
+                                 layer_eff,
+                                 use_mlm=False,
+                             )
+ 
+                             if batch_embs.numel() > 0:
+                                 embeddings_list.append(batch_embs.cpu())
+ 
+                             torch.cuda.empty_cache()
+ 
+                         if embeddings_list:
+                             embeddings = torch.cat(embeddings_list, dim=0)
+                             embs_by_mix[k] = embeddings
+                             labels_by_mix[k] = all_labels_mix
+ 
+                     except Exception as ex:
+                         if verbose:
+                             print(f"    ERROR processing mixture {k + 1}: {ex}")
+                         continue
+                     finally:
+                         # Always clean up after processing a mixture
+                         del all_signals_mix, all_masks_mix
+                         if 'embeddings_list' in locals():
+                             del embeddings_list
+                         clear_gpu_memory()
+                         gc.collect()
+ 
+                 if verbose:
+                     print(f"  Computing {metric_type} scores for {mname}...")
+ 
+                 # Process mixtures with their stored embeddings and labels
+                 with ThreadPoolExecutor(
+                     max_workers=min(2, ngpu if ngpu > 0 else 1)
+                 ) as executor:
+                     for k in range(len(mixture_entries)):
+                         if k not in embs_by_mix:
+                             continue
+ 
+                         E, L, D = embs_by_mix[k].shape
+                         if L == 0:
+                             if verbose:
+                                 print(f"    WARNING: mixture {k + 1} produced 0 frames after masking; skipping.")
+                             continue
+ 
+                         # Get the labels for this mixture
+                         labels_for_mix = labels_by_mix[k]
+ 
+                         def process_frame(f, embeddings_mix, labels_mix):
+                             try:
+                                 frame_emb = embeddings_mix[:, f, :].detach().cpu().numpy()
+ 
+                                 if add_ci:
+                                     coords_d, coords_c, eigvals, k_sub_gauss = (
+                                         gpu_distributor.execute_on_gpu(
+                                             diffusion_map_torch,
+                                             frame_emb,
+                                             labels_mix,
+                                             alpha=alpha,
+                                             eig_solver="full",
+                                             return_eigs=True,
+                                             return_complement=True,
+                                             return_cval=add_ci,
+                                         )
+                                     )
+                                 else:
+                                     coords_d = gpu_distributor.execute_on_gpu(
+                                         diffusion_map_torch,
+                                         frame_emb,
+                                         labels_mix,
+                                         alpha=alpha,
+                                         eig_solver="full",
+                                         return_eigs=False,
+                                         return_complement=False,
+                                         return_cval=False,
+                                     )
+                                     coords_c = None
+                                     eigvals = None
+                                     k_sub_gauss = 1
+ 
+                                 if metric_type == "PS":
+                                     score = compute_ps(
+                                         coords_d, labels_mix, max_gpus
+                                     )
+                                     bias = prob = None
+                                     if add_ci:
+                                         bias, prob = ps_ci_components_full(
+                                             coords_d,
+                                             coords_c,
+                                             eigvals,
+                                             labels_mix,
+                                             delta=DEFAULT_DELTA_CI,
+                                         )
+                                     return f, "PS", score, bias, prob
+                                 else:
+                                     score = compute_pm(
+                                         coords_d, labels_mix, "gamma", max_gpus
+                                     )
+                                     bias = prob = None
+                                     if add_ci:
+                                         bias, prob = pm_ci_components_full(
+                                             coords_d,
+                                             coords_c,
+                                             eigvals,
+                                             labels_mix,
+                                             delta=DEFAULT_DELTA_CI,
+                                             K=k_sub_gauss,
+                                         )
+                                     return f, "PM", score, bias, prob
+ 
+                             except Exception as ex:
+                                 if verbose:
+                                     print(f"    ERROR frame {f + 1}: {ex}")
+                                 return None
+ 
+                         futures = [
+                             executor.submit(process_frame, f, embs_by_mix[k], labels_for_mix)
+                             for f in range(L)
+                         ]
+                         for fut in futures:
+                             result = fut.result()
+                             if result is None:
+                                 continue
+ 
+                             f, metric, score, bias, prob = result
+ 
+                             if metric == "PS":
+                                 for sp in score:
+                                     ps_ts[mname][sp].append(score[sp])
+                                     if add_ci and bias is not None:
+                                         ps_bias_ts[mname][sp].append(bias[sp])
+                                         ps_prob_ts[mname][sp].append(prob[sp])
+                             else:
+                                 for sp in score:
+                                     pm_ts[mname][sp].append(score[sp])
+                                     if add_ci and bias is not None:
+                                         pm_bias_ts[mname][sp].append(bias[sp])
+                                         pm_prob_ts[mname][sp].append(prob[sp])
+ 
+                 # Clean up after processing all mixtures for this metric
+                 del embs_by_mix, labels_by_mix
+                 clear_gpu_memory()
+                 gc.collect()
+ 
+                 del model_wrapper
+                 clear_gpu_memory()
+                 gc.collect()
+ 
+         if verbose:
+             print(f"  Saving results for {algo}...")
+ 
+         for m in models:
+ 
+             def _pad(vec, n):
+                 return vec + [np.nan] * (n - len(vec))
+ 
+             max_len = 0
+             for s in ordered_speakers:
+                 max_len = max(max_len, len(ps_ts[m][s]), len(pm_ts[m][s]))
+ 
+             pd.DataFrame(
+                 {s: _pad(ps_ts[m][s], max_len) for s in ordered_speakers}
+             ).to_csv(os.path.join(algo_dir, f"ps_scores_{m}.csv"), index=False)
+ 
+             pd.DataFrame(
+                 {s: _pad(pm_ts[m][s], max_len) for s in ordered_speakers}
+             ).to_csv(os.path.join(algo_dir, f"pm_scores_{m}.csv"), index=False)
+ 
+             if add_ci:
+                 ci_cols = {}
+                 for s in ordered_speakers:
+                     ci_cols[f"{s}_ps_bias"] = _pad(ps_bias_ts[m][s], max_len)
+                     ci_cols[f"{s}_ps_prob"] = _pad(ps_prob_ts[m][s], max_len)
+                     ci_cols[f"{s}_pm_bias"] = _pad(pm_bias_ts[m][s], max_len)
+                     ci_cols[f"{s}_pm_prob"] = _pad(pm_prob_ts[m][s], max_len)
+                 pd.DataFrame(ci_cols).to_csv(
+                     os.path.join(algo_dir, f"ci_{m}.csv"), index=False
+                 )
+ 
+         del all_outs
+         clear_gpu_memory()
+         gc.collect()
+ 
+     print(f"\nEXPERIMENT COMPLETED")
+     print(f"Results saved to: {exp_root}")
+ 
+     del all_refs, voiced_mask_mix
+ 
+     # Import and call the cleanup function
+     from models import cleanup_all_models
+     cleanup_all_models()
+ 
+     clear_gpu_memory()
+     get_gpu_memory_info(verbose)
+     gc.collect()
+ 
+     return exp_root
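
A minimal sketch of invoking the engine directly, mirroring what main.py does. Paths, the mixture id, and the system name are placeholders; the `mixtures` argument is assumed to follow the normalized structure produced by `argshield._read_manifest`:

```python
# Sketch: programmatic use of the engine (placeholder paths and names).
from engine import compute_mapss_measures

manifest = [{
    "mixture_id": "mix_001",
    "references": ["refs/speaker1.wav", "refs/speaker2.wav"],
    "systems": {"my_separator": ["outs/sep1.wav", "outs/sep2.wav"]},
}]

results_dir = compute_mapss_measures(
    models=["wavlm_base"],
    mixtures=manifest,
    layer=12,
    alpha=1.0,
    verbose=True,
)
print(f"Results saved to: {results_dir}")
```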
hf_readme.md ADDED
@@ -0,0 +1,136 @@
+ ---
+ title: MAPSS Multi Source Audio Perceptual Separation Scores
+ emoji: 🎵
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.0.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+ 
+ # MAPSS: Multi-source Audio Perceptual Separation Scores
+ 
+ Evaluate audio source separation quality using Perceptual Similarity (PS) and Perceptual Matching (PM) metrics.
+ 
+ ## Features
+ 
+ - **Perceptual Similarity (PS)**: Measures how similar separated outputs are to reference sources in perceptual embedding space
+ - **Perceptual Matching (PM)**: Evaluates robustness against a comprehensive set of audio distortions
+ - **Multiple embedding models**: Support for WavLM, Wav2Vec2, HuBERT, AST, and more
+ - **Automatic output-to-reference matching**: Uses correlation-based Hungarian algorithm
+ - **GPU-optimized processing**: Efficient batch processing with memory management
+ - **Diffusion maps**: Advanced dimensionality reduction for perceptual space analysis
+ 
+ ## Input Format
+ 
+ Upload a ZIP file containing:
+ ```
+ your_mixture.zip
+ ├── references/          # Original clean sources
+ │   ├── speaker1.wav
+ │   ├── speaker2.wav
+ │   └── ...
+ └── outputs/             # Separated outputs from your algorithm
+     ├── separated1.wav
+     ├── separated2.wav
+     └── ...
+ ```
+ 
+ ### Audio Requirements
+ - Format: WAV files
+ - Sample rate: Any (automatically resampled to 16 kHz)
+ - Channels: Mono or stereo (converted to mono)
+ - Number of files: Equal number of references and outputs
+ 
+ ## Output Format
+ 
+ The tool generates a ZIP file containing:
+ - `ps_scores_{model}.csv`: PS scores for each speaker/source (0-1, higher is better)
+ - `pm_scores_{model}.csv`: PM scores for each speaker/source (0-1, higher is better)
+ - `params.json`: Experiment parameters used
+ - `manifest_canonical.json`: File mapping and processing details
+ 
+ ### Score Interpretation
+ - **PS Score**: Perceptual Similarity
+   - 1.0 = Perfect separation (output identical to reference)
+   - 0.5 = Moderate separation quality
+   - 0.0 = Poor separation (output closer to other sources)
+ 
+ - **PM Score**: Perceptual Matching (robustness)
+   - 1.0 = Highly robust to distortions
+   - 0.5 = Moderate robustness
+   - 0.0 = Not robust (easily confused with distorted versions)
+ 
+ ## Available Models
+ 
+ | Model | Description | Default Layer | Use Case |
+ |-------|-------------|---------------|----------|
+ | `raw` | Raw waveform features | N/A | Baseline comparison |
+ | `wavlm` | WavLM Large | 24 | Best overall performance |
+ | `wav2vec2` | Wav2Vec2 Large | 24 | Strong performance |
+ | `hubert` | HuBERT Large | 24 | Good for speech |
+ | `wavlm_base` | WavLM Base | 12 | Faster, good quality |
+ | `wav2vec2_base` | Wav2Vec2 Base | 12 | Faster processing |
+ | `hubert_base` | HuBERT Base | 12 | Faster for speech |
+ | `wav2vec2_xlsr` | Wav2Vec2 XLSR-53 | 24 | Multilingual |
+ | `ast` | Audio Spectrogram Transformer | 12 | General audio |
+ 
+ ## Parameters
+ 
+ - **Model**: Select the embedding model for feature extraction
+ - **Layer**: Which transformer layer to use (auto-selected by default)
+ - **Alpha**: Diffusion maps parameter (0.0-1.0, default: 1.0)
+   - 0.0 = No normalization
+   - 1.0 = Full normalization (recommended)
+ 
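+ ### Example CLI Usage
+ 
+ The same parameters are exposed by the repository's command-line entry point (file names below are illustrative):
+ 
+ ```bash
+ python main.py --manifest manifest.json --model wavlm_base --alpha 1.0 --verbose
+ ```
+ 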
+ ## How It Works
+ 
+ 1. **Feature Extraction**: Audio signals are processed through pre-trained self-supervised models to extract perceptual embeddings
+ 2. **Voice Activity Detection**: Automatic detection of voiced segments using energy-based masking
+ 3. **Diffusion Maps**: Embeddings are projected using diffusion maps for robust dimensionality reduction
+ 4. **PS Computation**: Measures Mahalanobis distance between separated outputs and references vs. other sources
+ 5. **PM Computation**: Evaluates against comprehensive distortions including:
+    - Noise (white, pink, brown at various SNRs)
+    - Filtering (lowpass, highpass, notch, comb)
+    - Effects (reverb, echo, tremolo, vibrato)
+    - Distortions (clipping, pitch shift, time stretch)
+ 6. **Scoring**: Frame-level scores are computed and aggregated
+ 
+ ## Technical Details
+ 
+ - **Loudness normalization**: ITU-R BS.1770 standard (-23 LUFS)
+ - **Frame-based processing**: 20 ms windows with 20 ms hop
+ - **Correlation-based assignment**: Hungarian algorithm for optimal matching
+ - **Memory optimization**: Batch processing with automatic GPU memory management
+ - **Robust statistics**: Covariance regularization and outlier handling
+ 
+ ## Citation
+ 
+ If you use MAPSS in your research, please cite:
+ 
+ ```bibtex
+ @article{mapss2024,
+   title={MAPSS: Multi-source Audio Perceptual Separation Scores},
+   author={Your Name},
+   journal={arXiv preprint},
+   year={2024}
+ }
+ ```
+ 
+ ## Limitations
+ 
+ - Processing time scales with audio length and model size
+ - Memory requirements depend on number of sources and audio length
+ - Currently optimized for speech separation (music separation support in development)
+ - Maximum recommended sources: 10 per mixture
+ 
+ ## License
+ 
+ Code: MIT License
+ Paper: CC-BY-4.0
+ 
+ ## Support
+ 
+ For issues, questions, or contributions, please visit the [GitHub repository](https://github.com/yourusername/mapss).
hf_requirements.txt ADDED
@@ -0,0 +1,26 @@
+ # Core dependencies
+ gradio>=4.0.0
+ torch>=2.0.0
+ torchaudio>=2.0.0
+ transformers>=4.35.0
+ accelerate>=0.24.0
+ 
+ # Audio processing
+ librosa>=0.10.0
+ soundfile>=0.12.0
+ pyloudnorm>=0.1.0
+ scipy>=1.11.0
+ numpy>=1.24.0
+ 
+ # Data handling
+ pandas>=2.0.0
+ 
+ # Model specific
+ safetensors>=0.4.0
+ sentencepiece>=0.1.99  # For some tokenizers
+ 
+ # Optional optimizations
+ triton>=2.1.0  # For faster attention if available
+ 
+ # Memory management
+ psutil>=5.9.0
__init__.py ADDED
@@ -0,0 +1,4 @@
+ from engine import compute_mapss_measures
+ 
+ __version__ = "1.0.0"
+ __all__ = ["compute_mapss_measures"]
main.py ADDED
@@ -0,0 +1,24 @@
+ from __future__ import annotations
+ from pathlib import Path
+ from engine import compute_mapss_measures
+ from argshield import _parse_args, _read_manifest, _validate_and_resolve, _validate_gpus
+ 
+ def main():
+     args = _parse_args()
+ 
+     manifest = _read_manifest(Path(args.manifest))
+     layer_final, alpha_final = _validate_and_resolve(args.model, args.layer, args.alpha)
+     max_gpus_final = _validate_gpus(args.max_gpus)
+ 
+     results_dir = compute_mapss_measures(
+         models=[args.model],
+         mixtures=manifest,
+         verbose=args.verbose,
+         max_gpus=max_gpus_final,
+         layer=layer_final,
+         alpha=alpha_final,
+     )
+     print(f"Results saved to: {results_dir}")
+ 
+ if __name__ == "__main__":
+     main()
metrics.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ from scipy.special import gammaincc
7
+ from scipy.stats import gamma
8
+
9
+ from config import COV_TOL, DEFAULT_DELTA_CI
10
+ from utils import get_gpu_count, mahalanobis_torch, safe_cov_torch
11
+
12
+
13
+ def pm_tail_gamma(d_out_sq, sq_dists):
14
+ """PM tail gamma exactly as original."""
15
+ mu = sq_dists.mean().item()
16
+ var = sq_dists.var(unbiased=True).item()
17
+ if var == 0.0:
18
+ return 1.0
19
+ k = (mu**2) / var
20
+ theta = var / mu
21
+ return float(1.0 - gamma.cdf(d_out_sq, a=k, scale=theta))
22
+
23
+
24
+ def pm_tail_rank(d_out_sq, sq_dists):
25
+ """PM tail rank exactly as original."""
26
+ rank = int((sq_dists < d_out_sq).sum().item())
27
+ n = sq_dists.numel()
28
+ return 1.0 - (rank + 0.5) / (n + 1.0)
29
+
30
+
31
+ def diffusion_map_torch(
32
+ X_np,
33
+ labels_by_mix,
34
+ *,
35
+ cutoff=0.99,
36
+ tol=1e-3,
37
+ diffusion_time=1,
38
+ alpha=0.0,
39
+ eig_solver="lobpcg",
40
+ k=None,
41
+ device=None,
42
+ return_eigs=False,
43
+ return_complement=False,
44
+ return_cval=False,
45
+ ):
46
+ """Diffusion map computation exactly as original."""
47
+ device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
48
+ X = torch.as_tensor(X_np, dtype=torch.float32, device=device)
49
+ N = X.shape[0]
50
+
51
+ if device != "cpu" and torch.cuda.is_available():
52
+ stream = torch.cuda.Stream(device=device)
53
+ ctx_dev = torch.cuda.device(device)
54
+ ctx_stream = torch.cuda.stream(stream)
55
+ else:
56
+ from contextlib import nullcontext
57
+
58
+ stream = None
59
+ ctx_dev = nullcontext()
60
+ ctx_stream = nullcontext()
61
+
62
+ with ctx_dev:
63
+ with ctx_stream:
64
+ if N > 1000:
65
+ chunk = min(500, N)
66
+ D2 = torch.zeros(N, N, device=device)
67
+ for i in range(0, N, chunk):
68
+ ei = min(i + chunk, N)
69
+ for j in range(0, N, chunk):
70
+ ej = min(j + chunk, N)
71
+ D2[i:ei, j:ej] = torch.cdist(X[i:ei], X[j:ej]).pow_(2)
72
+ else:
73
+ D2 = torch.cdist(X, X).pow_(2)
74
+
75
+ i, j = torch.triu_indices(
76
+ N, N, offset=1, device=None if device == "cpu" else device
77
+ )
78
+ eps = torch.median(D2[i, j])
79
+ K = torch.exp(-D2 / (2 * eps))
80
+ d = K.sum(dim=1)
81
+
82
+ if alpha != 0.0:
83
+ d_alpha_inv = d.pow(-alpha)
84
+ K *= d_alpha_inv[:, None] * d_alpha_inv[None, :]
85
+ d = K.sum(dim=1)
86
+
87
+ D_half_inv = torch.diag(torch.rsqrt(d))
88
+ K_sym = D_half_inv @ K @ D_half_inv
89
+
90
+ if eig_solver == "lobpcg":
91
+ m = k if k is not None else min(N - 1, 50)
92
+ init = torch.randn(N, m, device=device)
93
+ vals, vecs = torch.lobpcg(
94
+ K_sym, k=m, X=init, niter=200, tol=tol, largest=True
95
+ )
96
+ elif eig_solver == "full":
97
+ vals, vecs = torch.linalg.eigh(K_sym)
98
+ vals, vecs = vals.flip(0), vecs.flip(1)
99
+ if k is not None:
100
+ vecs = vecs[:, : k + 1]
101
+ vals = vals[: k + 1]
102
+ else:
103
+ raise ValueError(f"Unknown eig_solver '{eig_solver}'")
104
+
105
+ psi = vecs[:, 1:]
106
+ lam = vals[1:]
107
+ cum = torch.cumsum(lam, dim=0)
108
+ L = int((cum / cum[-1] < cutoff).sum().item()) + 1
109
+ lam_pow = lam.pow(diffusion_time)
110
+ psi_all = psi * lam_pow
111
+ Psi = psi_all[:, :L]
112
+ Psi_rest = psi_all[:, L:]
113
+
114
+ if return_cval:
115
+ indices_with_out = [
116
+ ii for ii, name in enumerate(labels_by_mix) if "out" in name
117
+ ]
118
+ valid_idx = torch.tensor(
119
+ [ii for ii in range(N) if ii not in indices_with_out], device=device
120
+ )
121
+ pi_min = d[valid_idx].min() / d[valid_idx].sum()
122
+ c_val = lam_pow[0] * pi_min.rsqrt() / math.log(2.0)
123
+
124
+ if stream is not None:
125
+ stream.synchronize()
126
+
127
+ if return_complement and return_eigs and return_cval:
128
+ return (
129
+ Psi.cpu().numpy(),
130
+ Psi_rest.cpu().numpy(),
131
+ lam.cpu().numpy(),
132
+ float(c_val),
133
+ )
134
+ if return_complement and return_eigs:
135
+ return Psi.cpu().numpy(), Psi_rest.cpu().numpy(), lam.cpu().numpy()
136
+ if return_complement:
137
+ return Psi.cpu().numpy(), Psi_rest.cpu().numpy()
138
+ if return_eigs:
139
+ return Psi.cpu().numpy(), lam.cpu().numpy()
140
+ return Psi.cpu().numpy()
141
+
142
+
143
+ def compute_ps(coords, labels, max_gpus=None):
144
+ ngpu = get_gpu_count(max_gpus)
145
+
146
+ if ngpu == 0:
147
+ coords_t = torch.tensor(coords)
148
+ spks_here = sorted({l.split("-")[0] for l in labels})
149
+ out = {}
150
+ for s in spks_here:
151
+ idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
152
+ out_i = labels.index(f"{s}-out")
153
+ ref_is = [i for i in idxs if i != out_i]
154
+ mu = coords_t[ref_is].mean(0)
155
+ cov = safe_cov_torch(coords_t[ref_is])
156
+ inv = torch.linalg.inv(cov)
157
+ A = mahalanobis_torch(coords_t[out_i], mu, inv)
158
+ B_list = []
159
+ for o in spks_here:
160
+ if o == s:
161
+ continue
162
+ o_idxs = [
163
+ i
164
+ for i, l in enumerate(labels)
165
+ if l.startswith(o) and not l.endswith("-out")
166
+ ]
167
+ mu_o = coords_t[o_idxs].mean(0)
168
+ inv_o = torch.linalg.inv(safe_cov_torch(coords_t[o_idxs]))
169
+ B_list.append(mahalanobis_torch(coords_t[out_i], mu_o, inv_o))
170
+ B_min = torch.min(torch.stack(B_list)) if B_list else torch.tensor(0.0)
171
+ out[s] = (1 - A / (A + B_min + 1e-6)).item()
172
+ return out
173
+
174
+ # GPU version
175
+ device = min(ngpu - 1, 1) # Use second GPU if available
176
+ device_str = f"cuda:{device}"
177
+ coords_t = torch.tensor(coords, device=device_str)
178
+ spks_here = sorted({l.split("-")[0] for l in labels})
179
+ out = {}
180
+
181
+ stream = torch.cuda.Stream(device=device_str)
182
+ with torch.cuda.device(device):
183
+ with torch.cuda.stream(stream):
184
+ for s in spks_here:
185
+ idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
186
+ out_i = labels.index(f"{s}-out")
187
+ ref_is = [i for i in idxs if i != out_i]
188
+ mu = coords_t[ref_is].mean(0)
189
+ cov = safe_cov_torch(coords_t[ref_is])
190
+ inv = torch.linalg.inv(cov)
191
+ A = mahalanobis_torch(coords_t[out_i], mu, inv)
192
+ B_list = []
193
+ for o in spks_here:
194
+ if o == s:
195
+ continue
196
+ o_idxs = [
197
+ i
198
+ for i, l in enumerate(labels)
199
+ if l.startswith(o) and not l.endswith("-out")
200
+ ]
201
+ mu_o = coords_t[o_idxs].mean(0)
202
+ inv_o = torch.linalg.inv(safe_cov_torch(coords_t[o_idxs]))
203
+ B_list.append(mahalanobis_torch(coords_t[out_i], mu_o, inv_o))
204
+ B_min = (
205
+ torch.min(torch.stack(B_list))
206
+ if B_list
207
+ else torch.tensor(0.0, device=device_str)
208
+ )
209
+ out[s] = (1 - A / (A + B_min + 1e-6)).item()
210
+ stream.synchronize()
211
+ return out
212
+
213
+
214
+ def compute_pm(coords, labels, pm_method, max_gpus=None):
215
+ ngpu = get_gpu_count(max_gpus)
216
+
217
+ if ngpu == 0:
218
+ coords_t = torch.tensor(coords)
219
+ spks_here = sorted({l.split("-")[0] for l in labels})
220
+ out = {}
221
+ for s in spks_here:
222
+ idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
223
+ ref_i = labels.index(f"{s}-ref")
224
+ out_i = labels.index(f"{s}-out")
225
+ d_idx = [i for i in idxs if i not in {ref_i, out_i}]
226
+ if len(d_idx) < 2:
227
+ out[s] = 0.0
228
+ continue
229
+ ref_v = coords_t[ref_i]
230
+ dist = coords_t[d_idx] - ref_v
231
+ N, D = dist.shape
232
+ cov = dist.T @ dist / (N - 1)
233
+ if torch.linalg.matrix_rank(cov) < D:
234
+ cov += torch.eye(D) * COV_TOL
235
+ inv = torch.linalg.inv(cov)
236
+ sq_dists = torch.stack(
237
+ [mahalanobis_torch(coords_t[i], ref_v, inv) ** 2 for i in d_idx]
238
+ )
239
+ d_out_sq = float(mahalanobis_torch(coords_t[out_i], ref_v, inv) ** 2)
240
+ pm_score = (
241
+ pm_tail_rank(d_out_sq, sq_dists)
242
+ if pm_method == "rank"
243
+ else pm_tail_gamma(d_out_sq, sq_dists)
244
+ )
245
+ out[s] = float(np.clip(pm_score, 0.0, 1.0))
246
+ return out
247
+
248
+ # GPU version
249
+ device = min(ngpu - 1, 1)
250
+ device_str = f"cuda:{device}"
251
+ coords_t = torch.tensor(coords, device=device_str)
252
+ spks_here = sorted({l.split("-")[0] for l in labels})
253
+ out = {}
254
+
255
+ stream = torch.cuda.Stream(device=device_str)
256
+ with torch.cuda.device(device):
257
+ with torch.cuda.stream(stream):
258
+ for s in spks_here:
259
+ idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
260
+ ref_i = labels.index(f"{s}-ref")
261
+ out_i = labels.index(f"{s}-out")
262
+ d_idx = [i for i in idxs if i not in {ref_i, out_i}]
263
+ if len(d_idx) < 2:
264
+ out[s] = 0.0
265
+ continue
266
+ ref_v = coords_t[ref_i]
267
+ dist = coords_t[d_idx] - ref_v
268
+ N, D = dist.shape
269
+ cov = dist.T @ dist / (N - 1)
270
+ if torch.linalg.matrix_rank(cov) < D:
271
+ cov += torch.eye(D, device=device_str) * COV_TOL
272
+ inv = torch.linalg.inv(cov)
273
+ sq_dists = torch.stack(
274
+ [mahalanobis_torch(coords_t[i], ref_v, inv) ** 2 for i in d_idx]
275
+ )
276
+ d_out_sq = float(mahalanobis_torch(coords_t[out_i], ref_v, inv) ** 2)
277
+ pm_score = (
278
+ pm_tail_rank(d_out_sq, sq_dists)
279
+ if pm_method == "rank"
280
+ else pm_tail_gamma(d_out_sq, sq_dists)
281
+ )
282
+ out[s] = float(np.clip(pm_score, 0.0, 1.0))
283
+ stream.synchronize()
284
+ return out
285
+
286
+
def pm_ci_components_full(
    coords_d, coords_rest, eigvals, labels, *, delta=0.05, K=1.0, C1=1.0, C2=1.0
):
    """Per-speaker PM confidence-interval components.

    Returns (bias_ci, prob_ci): the deterministic bias half-width and the
    probabilistic half-width of the PM score for each speaker. K, C1 and C2
    are retained for API compatibility.
    """
    _EPS = 1e-12

    def _safe_x(a, theta):
        return a / max(theta, _EPS)

    D = coords_d.shape[1]
    m = coords_rest.shape[1]
    if m == 0:
        z = {s: 0.0 for s in {l.split("-")[0] for l in labels}}
        return z.copy(), z.copy()

    X_d = torch.tensor(
        coords_d, device="cuda:0" if torch.cuda.is_available() else "cpu"
    )
    X_c = torch.tensor(
        coords_rest, device="cuda:0" if torch.cuda.is_available() else "cpu"
    )
    spk_ids = sorted({l.split("-")[0] for l in labels})
    bias_ci = {}
    prob_ci = {}

    for s in spk_ids:
        idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
        ref_i = labels.index(f"{s}-ref")
        out_i = labels.index(f"{s}-out")
        dist_is = [i for i in idxs if i not in {ref_i, out_i}]
        n_p = len(dist_is)

        if n_p < 2:
            bias_ci[s] = 0.0
            prob_ci[s] = 0.0
            continue

        ref_d = X_d[ref_i]
        ref_c = X_c[ref_i]
        D_mat = X_d[dist_is] - ref_d
        C_mat = X_c[dist_is] - ref_c
        Sigma_d = safe_cov_torch(D_mat)
        Sigma_c = safe_cov_torch(C_mat)
        C_dc = D_mat.T @ C_mat / (n_p - 1)
        inv_Sigma_d = torch.linalg.inv(Sigma_d)

        # Schur complement of the dominant block, ridge-stabilised.
        S_i = (
            Sigma_c
            - C_dc.T @ inv_Sigma_d @ C_dc
            + torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
        )
        S_inv = torch.linalg.inv(S_i)

        diff_out_d = X_d[out_i] - ref_d
        diff_out_c = X_c[out_i] - ref_c
        r_out = diff_out_c - C_dc.T @ inv_Sigma_d @ diff_out_d
        delta_Gi_a = float(r_out @ S_inv @ r_out)

        r_list = []
        for p in dist_is:
            d_p = X_d[p] - ref_d
            c_p = X_c[p] - ref_c
            r_p = c_p - C_dc.T @ inv_Sigma_d @ d_p
            r_list.append(r_p)
        R_p = torch.stack(r_list, dim=0)
        delta_Gi_p = torch.sum(R_p @ S_inv * R_p, dim=1)
        delta_Gi_mu_max = float(delta_Gi_p.max())

        mah_sq = torch.stack(
            [(X_d[i] - ref_d) @ inv_Sigma_d @ (X_d[i] - ref_d) for i in dist_is]
        )
        mu_g = float(mah_sq.mean())
        sigma2_g = float(mah_sq.var(unbiased=True) + 1e-12)
        sigma_g = math.sqrt(sigma2_g)

        full_sq = mah_sq + delta_Gi_p
        mu_full = float(full_sq.mean())
        sigma2_full = float(full_sq.var(unbiased=True) + 1e-12)

        if sigma2_g == 0.0:
            delta_Gi_k = delta_Gi_theta = 0.0
        else:
            factor = delta_Gi_mu_max * n_p / (n_p - 1)
            delta_Gi_k = factor * (mu_full + mu_g) / sigma2_g
            delta_Gi_theta = factor * (sigma2_full + sigma2_g) / (mu_g**2 + 1e-9)

        # Method-of-moments Gamma fit to the squared distances.
        k_d = (mu_g**2) / max(sigma2_g, 1e-12)
        theta_d = sigma2_g / max(mu_g, 1e-12)
        a_d = float(diff_out_d @ inv_Sigma_d @ diff_out_d)

        pm_center = gammaincc(k_d, _safe_x(a_d, theta_d))

        corner_vals = []
        for s_k in (-1, 1):
            for s_theta in (-1, 1):
                for s_a in (-1, 1):
                    k_c = max(k_d + s_k * delta_Gi_k, 1e-6)
                    theta_c = max(theta_d + s_theta * delta_Gi_theta, 1e-6)
                    a_c = max(a_d + s_a * delta_Gi_a, 1e-8)
                    corner_vals.append(gammaincc(k_c, _safe_x(a_c, theta_c)))

        bias_ci[s] = max(abs(v - pm_center) for v in corner_vals)

        # Probabilistic half-width (Bernstein-type concentration terms).
        R_sq = float(mah_sq.max()) + 1e-12
        log_term = math.log(6.0 / delta)
        eps_mu = math.sqrt(2 * sigma2_g * log_term / n_p) + 3 * R_sq * log_term / n_p
        eps_sigma = (
            math.sqrt(2 * R_sq**2 * log_term / n_p) + 3 * R_sq**2 * log_term / n_p
        )

        g1_x = 2.0 * mu_g / (sigma2_g + 1e-9)
        g1_y = -2.0 * mu_g**2 / (sigma_g**3 + 1e-9)
        g2_x = -sigma2_g / (mu_g**2 + 1e-9)
        g2_y = 2.0 * sigma_g / (mu_g + 1e-9)

        delta_k = min(abs(g1_x) * eps_mu + abs(g1_y) * eps_sigma, 0.5 * k_d)
        delta_theta = min(abs(g2_x) * eps_mu + abs(g2_y) * eps_sigma, 0.5 * theta_d)
        delta_a = min(R_sq * math.sqrt(2 * log_term / n_p), 0.5 * a_d + 1e-12)

        pm_corners = []
        for s_k in (-1, 1):
            for s_theta in (-1, 1):
                for s_a in (-1, 1):
                    k_c = k_d + s_k * delta_k
                    theta_c = theta_d + s_theta * delta_theta
                    a_c = max(a_d + s_a * delta_a, 1e-8)
                    pm_corners.append(gammaincc(k_c, _safe_x(a_c, theta_c)))

        prob_ci[s] = max(abs(pm - pm_center) for pm in pm_corners)

    return bias_ci, prob_ci


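# Usage sketch (hypothetical shapes): rows of coords_d / coords_rest are
# embedding coordinates, and labels mark each row as "<spk>-ref", "<spk>-out",
# or a distortion of <spk>; every speaker needs >= 2 distortion rows for a
# non-degenerate interval:
#   labels = ["a-ref", "a-d1", "a-d2", "a-d3", "a-out"]
#   bias_ci, prob_ci = pm_ci_components_full(
#       coords_d, coords_rest, eigvals, labels, delta=0.05
#   )

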
def ps_ci_components_full(coords_d, coords_rest, eigvals, labels, *, delta=0.05):
    """Per-speaker PS confidence-interval components.

    Returns (bias, prob): the deterministic bias half-width and the
    probabilistic half-width of the PS score for each speaker.
    """

    def _mean_dev(lam_max, delta, n_eff):
        return math.sqrt(2 * lam_max * math.log(2 / delta) / n_eff)

    def _rel_cov_dev(lam_max, trace, delta, n_eff, C=1.0):
        r = trace / lam_max
        abs_dev = (
            C * lam_max * (math.sqrt(r / n_eff) + (r + math.log(2 / delta)) / n_eff)
        )
        return abs_dev / lam_max

    def _maha_eps_m(a_hat, lam_min, lam_max, mean_dev, rel_cov_dev):
        term1 = 2 * math.sqrt(a_hat) * mean_dev * math.sqrt(lam_max / lam_min)
        term2 = a_hat * rel_cov_dev
        return term1 + term2

    D = coords_d.shape[1]
    m = coords_rest.shape[1]
    if m == 0:
        z = {s: 0.0 for s in {l.split("-")[0] for l in labels}}
        return z.copy(), z.copy()

    X_d = torch.tensor(
        coords_d, device="cuda:0" if torch.cuda.is_available() else "cpu"
    )
    X_c = torch.tensor(
        coords_rest, device="cuda:0" if torch.cuda.is_available() else "cpu"
    )
    spk_ids = sorted({l.split("-")[0] for l in labels})
    bias = {}
    prob = {}

    for s in spk_ids:
        idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
        out_i = labels.index(f"{s}-out")
        ref_is = [i for i in idxs if i != out_i]

        mu_d = X_d[ref_is].mean(0)
        mu_c = X_c[ref_is].mean(0)
        Sigma_d = safe_cov_torch(X_d[ref_is])
        Sigma_c = safe_cov_torch(X_c[ref_is])
        C_dc = (X_d[ref_is] - mu_d).T @ (X_c[ref_is] - mu_c) / (len(ref_is) - 1)
        inv_Sd = torch.linalg.inv(Sigma_d)

        lam_min = torch.linalg.eigvalsh(Sigma_d).min().clamp_min(1e-9).item()
        lam_max = torch.linalg.eigvalsh(Sigma_d).max()
        trace = torch.trace(Sigma_d).item()

        diff_d = X_d[out_i] - mu_d
        diff_c = X_c[out_i] - mu_c
        A_d = float(mahalanobis_torch(X_d[out_i], mu_d, inv_Sd))

        r_i = diff_c - C_dc.T @ inv_Sd @ diff_d
        S_i = (
            Sigma_c
            - C_dc.T @ inv_Sd @ C_dc
            + torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
        )
        term_i = math.sqrt(float(r_i @ torch.linalg.solve(S_i, r_i)))

        # Nearest competing speaker under the Mahalanobis metric.
        B_d, term_j = float("inf"), 0.0
        Sig_o = None
        for o in spk_ids:
            if o == s:
                continue
            o_idxs = [
                i
                for i, l in enumerate(labels)
                if l.startswith(o) and not l.endswith("-out")
            ]
            muo_d = X_d[o_idxs].mean(0)
            muo_c = X_c[o_idxs].mean(0)
            Sig_o_tmp = safe_cov_torch(X_d[o_idxs])
            inv_So = torch.linalg.inv(Sig_o_tmp)
            this_B = float(mahalanobis_torch(X_d[out_i], muo_d, inv_So))

            if this_B < B_d:
                B_d = this_B
                Sig_o = Sig_o_tmp
                diff_do = X_d[out_i] - muo_d
                diff_co = X_c[out_i] - muo_c
                C_oc = (
                    (X_d[o_idxs] - muo_d).T @ (X_c[o_idxs] - muo_c) / (len(o_idxs) - 1)
                )
                r_j = diff_co - C_oc.T @ inv_So @ diff_do
                S_j = (
                    safe_cov_torch(X_c[o_idxs])
                    - C_oc.T @ inv_So @ C_oc
                    + torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
                )
                term_j = math.sqrt(float(r_j @ torch.linalg.solve(S_j, r_j)))

        denom = A_d + B_d
        bias[s] = (B_d * term_i + A_d * term_j) / (denom**2)

        if Sig_o is not None:
            lam_min_o = torch.linalg.eigvalsh(Sig_o).min().clamp_min(1e-9).item()
            lam_max_o = torch.linalg.eigvalsh(Sig_o).max().item()
            trace_o = torch.trace(Sig_o).item()

            n_eff = max(int(0.7 * len(ref_is)), 3)
            RIDGE = 0.05
            lam_min_eff = max(lam_min, RIDGE * lam_max.item())
            lam_min_o_eff = max(lam_min_o, RIDGE * lam_max_o)

            eps_i_sg = _maha_eps_m(
                A_d,
                lam_min_eff,
                lam_max.item(),
                _mean_dev(lam_max.item(), delta / 2, n_eff),
                _rel_cov_dev(lam_max.item(), trace, delta / 2, n_eff),
            )
            eps_j_sg = _maha_eps_m(
                B_d,
                lam_min_o_eff,
                lam_max_o,
                _mean_dev(lam_max_o, delta / 2, n_eff),
                _rel_cov_dev(lam_max_o, trace_o, delta / 2, n_eff),
            )

            grad_l2 = math.hypot(A_d, B_d) / (A_d + B_d) ** 2
            ps_radius = grad_l2 * math.hypot(eps_i_sg, eps_j_sg)
            prob[s] = min(1.0, ps_radius)
        else:
            prob[s] = 0.0

    return bias, prob
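

# Usage sketch for the PS counterpart (assumes >= 2 speakers so a competing
# class exists; arrays are NumPy, shapes (n_rows, D) and (n_rows, m)):
#   labels = ["a-ref", "a-d1", "a-d2", "a-out",
#             "b-ref", "b-d1", "b-d2", "b-out"]
#   bias, prob = ps_ci_components_full(coords_d, coords_rest, eigvals, labels)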
models.py ADDED
@@ -0,0 +1,333 @@
import queue
import threading
import gc

import torch
import torch.nn.functional as F
from transformers import (
    HubertModel,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
    WavLMModel,
    ASTModel,
    AutoFeatureExtractor,
)

from config import BATCH_SIZE, ENERGY_HOP_MS, ENERGY_WIN_MS, SR
from utils import get_gpu_count


class BalancedDualGPUModel:
    """Replicates one HF model on up to two GPUs and feeds them from
    per-device worker threads so sub-batches are embedded in parallel."""

    def __init__(self, model_name, layer, max_gpus=None):
        self.layer = layer
        self.models = []
        self.extractors = []
        self.devices = []
        ngpu = get_gpu_count(max_gpus)

        for gpu_id in range(min(ngpu, 2)):
            device = f"cuda:{gpu_id}"
            self.devices.append(device)
            ckpt, cls, _ = get_model_config(layer)[model_name]
            if cls is ASTModel:
                extractor = AutoFeatureExtractor.from_pretrained(ckpt)
            else:
                extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)

            attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
            model = cls.from_pretrained(
                ckpt,
                output_hidden_states=True,
                use_safetensors=True,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                attn_implementation=attn_impl,
            )
            model.eval()
            model = model.to(device)

            for param in model.parameters():
                param.requires_grad = False

            self.extractors.append(extractor)
            self.models.append(model)

        self.gpu_queues = [queue.Queue() for _ in range(len(self.devices))]
        self.result_queue = queue.Queue()
        self.workers = []
        for i in range(len(self.devices)):
            worker = threading.Thread(target=self._gpu_worker, args=(i,))
            worker.daemon = True
            worker.start()
            self.workers.append(worker)

    def _gpu_worker(self, gpu_id):
        device = self.devices[gpu_id]
        model = self.models[gpu_id]
        extractor = self.extractors[gpu_id]
        while True:
            task = self.gpu_queues[gpu_id].get()
            if task is None:  # sentinel from cleanup()
                break
            signals, masks, use_mlm, task_id = task
            try:
                inputs = extractor(
                    signals, sampling_rate=SR, return_tensors="pt", padding=True
                )
                input_values = inputs.input_values.to(device, non_blocking=True)

                torch.cuda.empty_cache()

                orig_mode = model.training
                model.train() if use_mlm else model.eval()
                with torch.no_grad():
                    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                        hs = model(
                            input_values, output_hidden_states=True
                        ).hidden_states[self.layer]
                model.train(orig_mode)

                B, T, D = hs.shape
                keep = []
                for b in range(B):
                    # Resample the audio-domain mask to the frame rate T.
                    mask_b = masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
                    mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
                    keep.append(hs[b][mask_t].cpu())

                # Aggressive cleanup
                del hs, input_values, inputs
                torch.cuda.empty_cache()

                if keep:
                    L_max = max(x.shape[0] for x in keep)
                    keep_padded = [
                        F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in keep
                    ]
                    result = torch.stack(keep_padded, dim=0)
                else:
                    result = torch.empty(0, 0, 0)
                self.result_queue.put((task_id, result))
            except Exception as e:
                self.result_queue.put((task_id, e))
            finally:
                # Always clear cache after processing
                torch.cuda.empty_cache()

    def process_batch(self, signals, masks, use_mlm=False):
        if not signals:
            return torch.empty(0, 0, 0)
        batch_size = len(signals)
        split = (batch_size + len(self.devices) - 1) // len(self.devices)
        results = {}
        task_id = 0
        for i in range(0, batch_size, split):
            end = min(i + split, batch_size)
            gpu_id = (i // split) % len(self.devices)
            self.gpu_queues[gpu_id].put(
                (signals[i:end], masks[i:end], use_mlm, task_id)
            )
            task_id += 1
        for _ in range(task_id):
            tid, result = self.result_queue.get()
            if isinstance(result, Exception):
                raise result
            results[tid] = result
        parts = [results[i] for i in range(task_id) if results[i].numel() > 0]
        if not parts:
            return torch.empty(0, 0, 0)
        # Pad sub-batches to a common frame length so concatenation is valid
        # even when tasks produced different numbers of kept frames.
        T_max = max(p.shape[1] for p in parts)
        parts = [F.pad(p, (0, 0, 0, T_max - p.shape[1])) for p in parts]
        return torch.cat(parts, dim=0)

    def cleanup(self):
        """Explicit cleanup: stop workers, drop model references, free VRAM."""
        for q in self.gpu_queues:
            q.put(None)
        for w in self.workers:
            w.join(timeout=5.0)
        for model in self.models:
            del model
        for extractor in self.extractors:
            del extractor
        self.models.clear()
        self.extractors.clear()
        torch.cuda.empty_cache()
        gc.collect()

    def __del__(self):
        self.cleanup()


# NO CACHE - we need to clean up models properly between runs
def get_model_config(layer):
    return {
        "raw": (None, None, None),
        "wavlm": ("microsoft/wavlm-large", WavLMModel, layer),
        "wav2vec2": ("facebook/wav2vec2-large-lv60", Wav2Vec2Model, layer),
        "hubert": ("facebook/hubert-large-ll60k", HubertModel, layer),
        "wavlm_base": ("microsoft/wavlm-base", WavLMModel, layer),
        "wav2vec2_base": ("facebook/wav2vec2-base", Wav2Vec2Model, layer),
        "hubert_base": ("facebook/hubert-base-ls960", HubertModel, layer),
        "wav2vec2_xlsr": ("facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model, layer),
        "ast": ("MIT/ast-finetuned-audioset-10-10-0.4593", ASTModel, layer),
    }


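# Lookup sketch: the table maps a model key to (checkpoint, model class,
# hidden-state layer); e.g. with layer 24 (the last hidden state of the
# large models):
#   ckpt, cls, lyr = get_model_config(24)["wavlm"]
#   # -> ("microsoft/wavlm-large", WavLMModel, 24)

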
# Store loaded models globally to properly manage them
_loaded_models = {}


def load_model(name, layer, max_gpus=None):
    global _loaded_models

    # Clean up any previously loaded models first
    if _loaded_models:
        for key, model_data in _loaded_models.items():
            if isinstance(model_data, tuple) and len(model_data) == 2:
                if isinstance(model_data[0], BalancedDualGPUModel):
                    model_data[0].cleanup()
                elif isinstance(model_data[0], tuple):
                    # Single-GPU model: drop the local reference; the dict
                    # entry itself is released by clear() below.
                    _, model = model_data[0]
                    del model
        _loaded_models.clear()
        torch.cuda.empty_cache()
        gc.collect()

    if name.lower() in {"raw", "waveform"}:
        return "raw", layer

    ngpu = get_gpu_count(max_gpus)
    if ngpu > 1:
        model = BalancedDualGPUModel(name, layer, max_gpus)
        _loaded_models[name] = (model, layer)
        return model, layer
    else:
        ckpt, cls, layer_eff = get_model_config(layer)[name]
        if cls is ASTModel:
            extractor = AutoFeatureExtractor.from_pretrained(ckpt)
        else:
            extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
        model = cls.from_pretrained(
            ckpt,
            output_hidden_states=True,
            use_safetensors=True,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            attn_implementation=attn_impl,
        )
        model.eval()
        model = model.to(device)

        for param in model.parameters():
            param.requires_grad = False

        model_tuple = ((extractor, model), layer_eff)
        _loaded_models[name] = model_tuple
        return (extractor, model), layer_eff


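# Usage sketch (assumes a CUDA host; "hubert" / layer 24 are example
# arguments):
#   wrapper, layer = load_model("hubert", 24, max_gpus=1)
#   emb = embed_batch(signals, masks_audio, wrapper, layer)  # (B, T_max, D)
# load_model("raw", None) short-circuits to the energy-frame path below.

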
def cleanup_all_models():
    """Call this at the end of each experiment to ensure complete cleanup"""
    global _loaded_models
    if _loaded_models:
        for key, model_data in _loaded_models.items():
            if isinstance(model_data, tuple) and len(model_data) == 2:
                if isinstance(model_data[0], BalancedDualGPUModel):
                    model_data[0].cleanup()
                elif isinstance(model_data[0], tuple):
                    # Single GPU model
                    _, model = model_data[0]
                    del model
        _loaded_models.clear()
        torch.cuda.empty_cache()
        gc.collect()


def embed_batch_raw(signals, masks_audio):
    win = int(ENERGY_WIN_MS * SR / 1000)
    hop = int(ENERGY_HOP_MS * SR / 1000)
    reps, L_max = [], 0
    for sig_np, mask_np in zip(signals, masks_audio):
        x = torch.as_tensor(sig_np[:-1], dtype=torch.float32)
        frames = x.unfold(0, win, hop)
        mask = torch.as_tensor(mask_np[: len(frames)], dtype=torch.bool)
        keep = frames[mask] if mask.any() else frames[:1]
        reps.append(keep)
        L_max = max(L_max, keep.size(0))
    reps = [F.pad(r, (0, 0, 0, L_max - r.size(0))) for r in reps]
    return torch.stack(reps, dim=0)


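# Frame-count sketch: with SR = 16000 and, say, ENERGY_WIN_MS = 25 /
# ENERGY_HOP_MS = 10 (the actual values live in config.py), win = 400 and
# hop = 160, so a 1 s signal yields floor((15999 - 400) / 160) + 1 = 98 rows
# of raw samples before masking.

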
def embed_batch_single_gpu(
    signals, masks_audio, extractor, model, layer, use_mlm=False
):
    if not signals:
        return torch.empty(0, 0, 0)
    device = next(model.parameters()).device

    # Small micro-batches keep peak VRAM low on a single device.
    max_batch = 2
    all_keeps = []

    for i in range(0, len(signals), max_batch):
        batch_signals = signals[i:i + max_batch]
        batch_masks = masks_audio[i:i + max_batch]

        inputs = extractor(
            batch_signals, sampling_rate=SR, return_tensors="pt", padding=True
        )
        input_values = inputs.input_values.to(device, non_blocking=True)

        orig_mode = model.training
        model.train() if use_mlm else model.eval()

        with torch.no_grad():
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                hs = model(input_values, output_hidden_states=True).hidden_states[layer]
        model.train(orig_mode)

        B, T, D = hs.shape
        for b in range(B):
            mask_b = batch_masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
            mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
            all_keeps.append(hs[b][mask_t].cpu())

        # Aggressive cleanup
        del hs, input_values, inputs
        torch.cuda.empty_cache()

    if all_keeps:
        L_max = max(x.shape[0] for x in all_keeps)
        keep_padded = [F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in all_keeps]
        result = torch.stack(keep_padded, dim=0)
        # Clean up intermediate lists
        del all_keeps, keep_padded
        return result
    else:
        return torch.empty(0, 0, 0)


def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
    if model_wrapper == "raw":
        return embed_batch_raw(signals, masks_audio)
    if isinstance(model_wrapper, BalancedDualGPUModel):
        all_embeddings = []
        batch_size = min(BATCH_SIZE, 2)
        for i in range(0, len(signals), batch_size):
            batch_emb = model_wrapper.process_batch(
                signals[i: i + batch_size], masks_audio[i: i + batch_size], use_mlm
            )
            if batch_emb.numel() > 0:
                all_embeddings.append(batch_emb)
            # Clear cache after each batch
            torch.cuda.empty_cache()

        if all_embeddings:
            # Pad per-call results to a common frame length before
            # concatenating across batches.
            T_max = max(e.shape[1] for e in all_embeddings)
            all_embeddings = [
                F.pad(e, (0, 0, 0, T_max - e.shape[1])) for e in all_embeddings
            ]
            return torch.cat(all_embeddings, dim=0)
        else:
            return torch.empty(0, 0, 0)
    else:
        extractor, model = model_wrapper
        return embed_batch_single_gpu(
            signals, masks_audio, extractor, model, layer, use_mlm=use_mlm
        )
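

# Dispatch summary: embed_batch routes "raw" to energy frames,
# BalancedDualGPUModel wrappers to the threaded multi-GPU path, and
# (extractor, model) tuples to the single-GPU path; each returns a
# zero-padded (B, T_max, D) tensor on CPU.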
utils.py ADDED
@@ -0,0 +1,231 @@
import gc
import threading
import warnings
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
try:
    from scipy.optimize import linear_sum_assignment as _lsa
except Exception:
    _lsa = None

warnings.filterwarnings("ignore", message="Some weights of Wav2Vec2Model")


def get_gpu_count(max_gpus=None):
    ngpu = torch.cuda.device_count()
    if max_gpus is not None:
        ngpu = min(ngpu, max_gpus)
    return ngpu


def clear_gpu_memory():
    """Enhanced GPU memory clearing"""
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            with torch.cuda.device(i):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()


def get_gpu_memory_info(verbose=False):
    if not verbose:
        return
    for i in range(torch.cuda.device_count()):
        try:
            free_b, total_b = torch.cuda.mem_get_info(i)  # type: ignore[attr-defined]
            free_gb = free_b / 1024**3
            total_gb = total_b / 1024**3
        except Exception:
            total_gb = torch.cuda.get_device_properties(i).total_memory / 1024**3
            free_gb = total_gb - (torch.cuda.memory_reserved(i) / 1024**3)
        mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
        print(
            f"GPU {i}: {mem_allocated:.2f}GB allocated, "
            f"{free_gb:.2f}GB free / {total_gb:.2f}GB total"
        )


def write_wav_16bit(path, x, sr=16000):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    try:
        import soundfile as sf

        sf.write(str(path), x.astype(np.float32), sr)
    except Exception:
        from scipy.io.wavfile import write

        write(str(path), sr, (np.clip(x, -1, 1) * 32767).astype(np.int16))


def safe_corr_np(a, b):
    L = min(len(a), len(b))
    if L <= 1:
        return 0.0
    a = a[:L].astype(np.float64)
    b = b[:L].astype(np.float64)
    a -= a.mean()
    b -= b.mean()
    da = a.std()
    db = b.std()
    if da <= 1e-12 or db <= 1e-12:
        return 0.0
    r = float((a * b).mean() / (da * db))
    return max(-1.0, min(1.0, r))


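# Sketch: safe_corr_np is a clipped Pearson correlation that returns 0.0
# (rather than NaN) for constant or length-<=1 inputs, e.g.
#   safe_corr_np(np.array([0.0, 1.0, 2.0]), np.array([0.0, 2.0, 4.0]))  # 1.0

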
def hungarian(cost):
    try:
        if _lsa is not None:
            return _lsa(cost)
        raise RuntimeError("scipy.optimize.linear_sum_assignment unavailable")
    except Exception:
        # Greedy row-by-row fallback: not globally optimal, but dependency-free.
        used = set()
        rows, cols = [], []
        for i in range(cost.shape[0]):
            j = int(
                np.argmin(
                    [
                        cost[i, k] if k not in used else 1e12
                        for k in range(cost.shape[1])
                    ]
                )
            )
            used.add(j)
            rows.append(i)
            cols.append(j)
        return np.asarray(rows), np.asarray(cols)


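# Sketch: with SciPy present the assignment is exact, e.g.
#   hungarian(np.array([[4.0, 1.0], [2.0, 0.0]]))
#   # -> (array([0, 1]), array([1, 0])), total cost 3
# The greedy fallback happens to agree here but is only an approximation.

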
class GPUWorkDistributor:

    def __init__(self, max_gpus=None):
        ngpu = get_gpu_count(max_gpus)
        self.gpu_locks = [threading.Lock() for _ in range(max(1, min(ngpu, 2)))]
        self.gpu_load = [0 for _ in range(max(1, min(ngpu, 2)))]
        self.ngpu = ngpu

    def get_least_loaded_gpu(self):
        return int(np.argmin(self.gpu_load))

    def execute_on_gpu(self, func, *args, **kwargs):
        if self.ngpu == 0:
            kwargs.pop("device", None)
            return func(*args, **kwargs)
        gid = self.get_least_loaded_gpu()
        with self.gpu_locks[gid]:
            self.gpu_load[gid] += 1
            try:
                with torch.cuda.device(gid):
                    kwargs["device"] = f"cuda:{gid}"
                    result = func(*args, **kwargs)
                    # Clear cache after execution
                    torch.cuda.empty_cache()
                    return result
            finally:
                self.gpu_load[gid] -= 1


@dataclass
class Mixture:

    mixture_id: str
    refs: list[Path]
    systems: dict[str, list[Path]]
    speaker_ids: list[str]


def canonicalize_mixtures(mixtures, systems=None):
    canon = []
    if not mixtures:
        return canon

    def as_paths(seq):
        return [p if isinstance(p, Path) else Path(str(p)) for p in seq]

    def speaker_id_from_ref(ref_path, idx, mixture_id):
        stem = (ref_path.stem or "").strip()
        if not stem:
            stem = f"spk{idx:02d}"
        return f"{mixture_id}__{stem}"

    if isinstance(mixtures[0], dict):
        for m in mixtures:
            mid = str(m.get("mixture_id") or m.get("id") or "").strip()
            if not mid:
                raise ValueError("Each mixture must include 'mixture_id'.")
            refs = as_paths(m.get("references", []))
            if not refs:
                raise ValueError(f"Mixture {mid}: 'references' must be non-empty.")
            sysmap = {}
            if isinstance(m.get("systems"), dict):
                for algo, outs in m["systems"].items():
                    sysmap[str(algo)] = as_paths(outs)
            spk_ids = [speaker_id_from_ref(r, i, mid) for i, r in enumerate(refs)]
            canon.append(Mixture(mid, refs, sysmap, spk_ids))
        return canon

    if isinstance(mixtures[0], list):
        for i, group in enumerate(mixtures):
            mid = f"mix_{i:03d}"
            refs, spk_ids = [], []
            for d in group:
                if not isinstance(d, dict) or "ref" not in d or "id" not in d:
                    raise ValueError(
                        "Legacy mixtures expect dicts with 'id' and 'ref'."
                    )
                rp = Path(d["ref"])
                refs.append(rp)
                spk_ids.append(f"{mid}__{str(d['id']).strip()}")
            sysmap = {}
            if systems:
                for algo, per_mix in systems.items():
                    if mid in per_mix:
                        sysmap[algo] = as_paths(per_mix[mid])
            canon.append(Mixture(mid, refs, sysmap, spk_ids))
        return canon

    raise ValueError("Unsupported 'mixtures' format.")


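# Sketch: dict-form input (paths are hypothetical) ->
#   canonicalize_mixtures([{
#       "mixture_id": "mix_000",
#       "references": ["ref/spk1.wav", "ref/spk2.wav"],
#       "systems": {"sysA": ["out/spk1.wav", "out/spk2.wav"]},
#   }])
# yields one Mixture with speaker_ids ["mix_000__spk1", "mix_000__spk2"].

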
def random_misalign(sig, sr, max_ms, mode="single", rng=None):
    import random

    if rng is None:
        rng = random
    max_samples = int(sr * max_ms / 1000)
    if max_samples == 0:
        return sig
    shift = (
        rng.randint(-max_samples, max_samples) if mode == "range" else int(max_samples)
    )
    if shift == 0:
        return sig
    if isinstance(sig, torch.Tensor):
        z = torch.zeros(abs(shift), dtype=sig.dtype, device=sig.device)
        return (
            torch.cat([z, sig[:-shift]]) if shift > 0 else torch.cat([sig[-shift:], z])
        )
    else:
        z = np.zeros(abs(shift), dtype=sig.dtype)
        return (
            np.concatenate([z, sig[:-shift]])
            if shift > 0
            else np.concatenate([sig[-shift:], z])
        )


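# Sketch: shift a 16 kHz signal right by exactly 10 ms (160 samples),
# zero-padding the head; mode="range" instead draws the shift uniformly
# from [-160, 160]:
#   y = random_misalign(x, sr=16000, max_ms=10, mode="single")

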
def safe_cov_torch(X):
    Xc = X - X.mean(dim=0, keepdim=True)
    cov = Xc.T @ Xc / (Xc.shape[0] - 1)
    if torch.linalg.matrix_rank(cov) < cov.shape[0]:
        cov += torch.eye(cov.shape[0], device=cov.device) * 1e-6
    return cov


def mahalanobis_torch(x, mu, inv):
    diff = x - mu
    diff_T = diff.transpose(-1, -2) if diff.ndim >= 2 else diff
    return torch.sqrt(diff @ inv @ diff_T + 1e-6)
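

if __name__ == "__main__":
    # Minimal smoke test (a sketch, CPU-only): Mahalanobis distance of one
    # point from a small Gaussian cloud using the two helpers above.
    X = torch.randn(32, 4, dtype=torch.float64)
    inv = torch.linalg.inv(safe_cov_torch(X))
    d = mahalanobis_torch(X[0], X.mean(dim=0), inv)
    print(f"mahalanobis distance: {float(d):.3f}")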