Spaces:
Running
on
Zero
Running
on
Zero
Upload 12 files
Browse files- argshield.py +144 -0
- audio.py +61 -0
- config.py +33 -0
- distortions.py +339 -0
- engine.py +455 -0
- hf_readme.md +136 -0
- hf_requirements.txt +26 -0
- init.py +4 -0
- main.py +24 -0
- metrics.py +549 -0
- models.py +333 -0
- utils.py +231 -0
argshield.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path, PurePath
|
| 5 |
+
import importlib.util
|
| 6 |
+
|
| 7 |
+
from config import DEFAULT_ALPHA
|
| 8 |
+
from models import get_model_config
|
| 9 |
+
|
| 10 |
+
# Central table for default layers per model (kept identical to original table)
|
| 11 |
+
MODEL_DEFAULT_LAYER = {
|
| 12 |
+
"raw": None,
|
| 13 |
+
"wavlm": 24,
|
| 14 |
+
"wav2vec2": 24,
|
| 15 |
+
"hubert": 24,
|
| 16 |
+
"wavlm_base": 12,
|
| 17 |
+
"wav2vec2_base": 12,
|
| 18 |
+
"hubert_base": 12,
|
| 19 |
+
"wav2vec2_xlsr": 24,
|
| 20 |
+
"ast": 12,
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
def _read_manifest_json(path: Path):
|
| 24 |
+
text = Path(path).read_text(encoding="utf-8")
|
| 25 |
+
try:
|
| 26 |
+
return json.loads(text)
|
| 27 |
+
except json.JSONDecodeError as e:
|
| 28 |
+
raise SystemExit(f"Manifest must be JSON. Failed to parse: {e}")
|
| 29 |
+
|
| 30 |
+
def _read_manifest_py(path: Path):
|
| 31 |
+
spec = importlib.util.spec_from_file_location("manifest_mod", str(path))
|
| 32 |
+
if spec is None or spec.loader is None:
|
| 33 |
+
raise SystemExit(f"Could not load Python manifest: {path}")
|
| 34 |
+
mod = importlib.util.module_from_spec(spec)
|
| 35 |
+
spec.loader.exec_module(mod) # executes the .py file
|
| 36 |
+
|
| 37 |
+
if not hasattr(mod, "MANIFEST"):
|
| 38 |
+
raise SystemExit(f"Python manifest {path} must define a top-level variable MANIFEST")
|
| 39 |
+
|
| 40 |
+
manifest = mod.MANIFEST
|
| 41 |
+
|
| 42 |
+
def _to_str(p):
|
| 43 |
+
if isinstance(p, (Path, PurePath)):
|
| 44 |
+
return str(p)
|
| 45 |
+
if isinstance(p, str):
|
| 46 |
+
return p
|
| 47 |
+
raise TypeError(f"Path entry must be str or Path, got {type(p)}: {p}")
|
| 48 |
+
|
| 49 |
+
normalized = []
|
| 50 |
+
try:
|
| 51 |
+
for item in manifest:
|
| 52 |
+
mix_id = item["mixture_id"]
|
| 53 |
+
refs = [_to_str(x) for x in item["references"]]
|
| 54 |
+
systems = {}
|
| 55 |
+
for sys_name, lst in item["systems"].items():
|
| 56 |
+
systems[sys_name] = [_to_str(x) for x in lst]
|
| 57 |
+
normalized.append({
|
| 58 |
+
"mixture_id": mix_id,
|
| 59 |
+
"references": refs,
|
| 60 |
+
"systems": systems,
|
| 61 |
+
})
|
| 62 |
+
except (KeyError, TypeError, ValueError) as e:
|
| 63 |
+
raise SystemExit(f"Malformed MANIFEST in {path}: {e}")
|
| 64 |
+
return normalized
|
| 65 |
+
|
| 66 |
+
def _read_manifest(path: Path):
|
| 67 |
+
suffix = path.suffix.lower()
|
| 68 |
+
if suffix in {".py"}:
|
| 69 |
+
return _read_manifest_py(path)
|
| 70 |
+
elif suffix in {".json", ".txt"}:
|
| 71 |
+
return _read_manifest_json(path)
|
| 72 |
+
else:
|
| 73 |
+
raise SystemExit(f"Unsupported manifest type '{suffix}'. Use .py, .json, or .txt")
|
| 74 |
+
|
| 75 |
+
def _parse_args():
|
| 76 |
+
parser = argparse.ArgumentParser(
|
| 77 |
+
description="Run PS/PM experiment from a manifest file."
|
| 78 |
+
)
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--manifest",
|
| 81 |
+
type=Path,
|
| 82 |
+
required=True,
|
| 83 |
+
help="Path to manifest (.py with MANIFEST or .json/.txt with JSON).",
|
| 84 |
+
)
|
| 85 |
+
parser.add_argument(
|
| 86 |
+
"--model",
|
| 87 |
+
type=str,
|
| 88 |
+
required=True,
|
| 89 |
+
help=("Embedding model. Choices: "
|
| 90 |
+
"raw, wavlm, wav2vec2, hubert, wavlm_base, wav2vec2_base, "
|
| 91 |
+
"hubert_base, wav2vec2_xlsr, ast"),
|
| 92 |
+
)
|
| 93 |
+
parser.add_argument(
|
| 94 |
+
"--layer",
|
| 95 |
+
type=int,
|
| 96 |
+
default=None,
|
| 97 |
+
help="Optional layer (validated per model). Omit to use the model default.",
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--alpha",
|
| 101 |
+
type=float,
|
| 102 |
+
default=None,
|
| 103 |
+
help="Optional diffusion-maps alpha in [0,1] (default: config DEFAULT_ALPHA).",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument("--verbose", action="store_true", help="Verbose logging.")
|
| 106 |
+
parser.add_argument("--max-gpus", type=int, default=None, help="Limit GPUs to use (must be >= 0).")
|
| 107 |
+
return parser.parse_args()
|
| 108 |
+
|
| 109 |
+
def _validate_and_resolve(model: str, layer_opt: int|None, alpha_opt: float|None):
|
| 110 |
+
allowed_models = set(get_model_config(0).keys())
|
| 111 |
+
if model not in allowed_models:
|
| 112 |
+
raise SystemExit(f"Unknown --model '{model}'. Allowed: {sorted(allowed_models)}")
|
| 113 |
+
|
| 114 |
+
max_layer = MODEL_DEFAULT_LAYER.get(model)
|
| 115 |
+
if model == "raw":
|
| 116 |
+
layer_final = 0 if layer_opt is None else int(layer_opt)
|
| 117 |
+
else:
|
| 118 |
+
if layer_opt is None:
|
| 119 |
+
if max_layer is None:
|
| 120 |
+
raise SystemExit(f"--layer must be provided for model '{model}'.")
|
| 121 |
+
layer_final = max_layer
|
| 122 |
+
else:
|
| 123 |
+
layer_final = int(layer_opt)
|
| 124 |
+
if max_layer is not None and not (0 <= layer_final <= max_layer):
|
| 125 |
+
raise SystemExit(
|
| 126 |
+
f"--layer {layer_final} is out of range for '{model}'. "
|
| 127 |
+
f"Expected 0..{max_layer} (or omit to use default {max_layer})."
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
alpha_final = DEFAULT_ALPHA if alpha_opt is None else float(alpha_opt)
|
| 131 |
+
if not (0.0 <= alpha_final <= 1.0):
|
| 132 |
+
raise SystemExit("--alpha must be in [0, 1].")
|
| 133 |
+
return layer_final, alpha_final
|
| 134 |
+
|
| 135 |
+
def _validate_gpus(max_gpus_opt):
|
| 136 |
+
if max_gpus_opt is None:
|
| 137 |
+
return None
|
| 138 |
+
try:
|
| 139 |
+
mg = int(max_gpus_opt)
|
| 140 |
+
except Exception:
|
| 141 |
+
raise SystemExit("--max-gpus must be an integer >= 0.")
|
| 142 |
+
if mg < 0:
|
| 143 |
+
raise SystemExit("--max-gpus must be >= 0.")
|
| 144 |
+
return mg
|
audio.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pyloudnorm as pyln
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from config import SILENCE_RATIO, SR
|
| 7 |
+
from utils import hungarian, safe_corr_np
|
| 8 |
+
import warnings
|
| 9 |
+
warnings.filterwarnings("ignore", message="Possible clipped samples in output.")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def loudness_normalize(wav, sr=SR, target_lufs=-23.0):
|
| 13 |
+
meter = pyln.Meter(sr)
|
| 14 |
+
loudness = meter.integrated_loudness(wav)
|
| 15 |
+
normalized_wav = pyln.normalize.loudness(wav, loudness, target_lufs)
|
| 16 |
+
peak = np.max(np.abs(normalized_wav))
|
| 17 |
+
if peak > 1.0:
|
| 18 |
+
normalized_wav = normalized_wav / max(peak, 1e-12)
|
| 19 |
+
return np.clip(normalized_wav, -1.0, 1.0)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def frame_rms_torch(sig, win, hop):
|
| 23 |
+
dev = sig.device
|
| 24 |
+
frames = sig.unfold(0, win, hop)
|
| 25 |
+
if frames.size(0) and (frames.size(0) - 1) * hop == sig.numel() - win:
|
| 26 |
+
frames = frames[:-1]
|
| 27 |
+
rms = torch.sqrt((frames**2).mean(1) + 1e-12)
|
| 28 |
+
return rms.to(dev)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def make_union_voiced_mask(refs_tensors, win, hop):
|
| 32 |
+
device = refs_tensors[0].device
|
| 33 |
+
rms_vecs = [frame_rms_torch(r, win, hop) for r in refs_tensors]
|
| 34 |
+
lengths = [v.numel() for v in rms_vecs]
|
| 35 |
+
L_max = max(lengths)
|
| 36 |
+
silent_union = torch.zeros(L_max, dtype=torch.bool, device=device)
|
| 37 |
+
for idx, (rms, L) in enumerate(zip(rms_vecs, lengths)):
|
| 38 |
+
thr = SILENCE_RATIO * torch.sqrt((refs_tensors[idx] ** 2).mean())
|
| 39 |
+
sil = rms <= thr
|
| 40 |
+
silent_union[:L] |= sil
|
| 41 |
+
return ~silent_union
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def assign_outputs_to_refs_by_corr(ref_paths, out_paths):
|
| 45 |
+
if not out_paths:
|
| 46 |
+
return [None] * len(ref_paths)
|
| 47 |
+
refs = [loudness_normalize(librosa.load(str(p), sr=SR)[0]) for p in ref_paths]
|
| 48 |
+
outs = [loudness_normalize(librosa.load(str(p), sr=SR)[0]) for p in out_paths]
|
| 49 |
+
n, m = len(refs), len(outs)
|
| 50 |
+
K = max(n, m)
|
| 51 |
+
C = np.ones((K, K), dtype=np.float64)
|
| 52 |
+
for i in range(n):
|
| 53 |
+
for j in range(m):
|
| 54 |
+
r = safe_corr_np(refs[i], outs[j])
|
| 55 |
+
C[i, j] = 1.0 - (r + 1.0) * 0.5 # lower = better
|
| 56 |
+
ri, cj = hungarian(C)
|
| 57 |
+
mapping = [None] * n
|
| 58 |
+
for i, j in zip(ri, cj):
|
| 59 |
+
if i < n and j < m:
|
| 60 |
+
mapping[i] = int(j)
|
| 61 |
+
return mapping
|
config.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
import warnings
|
| 5 |
+
warnings.filterwarnings(
|
| 6 |
+
"ignore",
|
| 7 |
+
category=UserWarning,
|
| 8 |
+
message=r"^expandable_segments not supported on this platform"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
SR = 16_000
|
| 12 |
+
RESULTS_ROOT = "results"
|
| 13 |
+
BATCH_SIZE = 2
|
| 14 |
+
ENERGY_WIN_MS = 20
|
| 15 |
+
ENERGY_HOP_MS = 20
|
| 16 |
+
SILENCE_RATIO = 0.1
|
| 17 |
+
EPS = 1e-4
|
| 18 |
+
COV_TOL = 1e-6
|
| 19 |
+
|
| 20 |
+
DEFAULT_LAYER = 2
|
| 21 |
+
DEFAULT_ADD_CI = True
|
| 22 |
+
DEFAULT_DELTA_CI = 0.05
|
| 23 |
+
DEFAULT_ALPHA = 1.0
|
| 24 |
+
|
| 25 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.6"
|
| 26 |
+
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
|
| 27 |
+
|
| 28 |
+
torch.backends.cudnn.benchmark = True
|
| 29 |
+
torch.backends.cudnn.deterministic = False
|
| 30 |
+
torch.backends.cudnn.enabled = True
|
| 31 |
+
|
| 32 |
+
if torch.cuda.is_available():
|
| 33 |
+
torch.cuda.set_per_process_memory_fraction(0.8)
|
distortions.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
import numpy as np
|
| 3 |
+
from numpy.fft import irfft, rfft, rfftfreq
|
| 4 |
+
from scipy.signal import butter, filtfilt, lfilter
|
| 5 |
+
|
| 6 |
+
from config import ENERGY_WIN_MS, EPS, SR
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def sig_stats(x):
|
| 10 |
+
A_pk = max(np.max(np.abs(x)), EPS)
|
| 11 |
+
A_rms = max(np.sqrt(np.mean(x**2)), EPS)
|
| 12 |
+
A_95 = max(np.percentile(np.abs(x), 95), EPS)
|
| 13 |
+
return A_pk, A_rms, A_95
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def frame_distortions(
|
| 17 |
+
frame,
|
| 18 |
+
sr,
|
| 19 |
+
distortion_keys,
|
| 20 |
+
notch_freqs=None,
|
| 21 |
+
low_cutoffs=None,
|
| 22 |
+
high_cutoffs=None,
|
| 23 |
+
frame_start=0,
|
| 24 |
+
):
|
| 25 |
+
notch_freqs = [] if notch_freqs is None else notch_freqs
|
| 26 |
+
low_cutoffs = [] if low_cutoffs is None else low_cutoffs
|
| 27 |
+
high_cutoffs = [] if high_cutoffs is None else high_cutoffs
|
| 28 |
+
distortions = {}
|
| 29 |
+
|
| 30 |
+
A_pk, A_rms, A_95 = sig_stats(frame)
|
| 31 |
+
frame_len = len(frame)
|
| 32 |
+
X = rfft(frame)
|
| 33 |
+
freqs = rfftfreq(frame_len, 1 / sr)
|
| 34 |
+
t = np.arange(frame_len) / sr
|
| 35 |
+
|
| 36 |
+
if ("notch" in distortion_keys) or distortion_keys == "all":
|
| 37 |
+
bw = 60.0
|
| 38 |
+
for f0 in notch_freqs:
|
| 39 |
+
Y = X.copy()
|
| 40 |
+
band = (freqs > f0 - bw) & (freqs < f0 + bw)
|
| 41 |
+
Y[band] = 0
|
| 42 |
+
distortions[f"Notch_{int(round(f0))}Hz"] = irfft(Y, n=len(frame))
|
| 43 |
+
|
| 44 |
+
if ("comb" in distortion_keys) or distortion_keys == "all":
|
| 45 |
+
for d_ms, decay in zip([2.5, 5, 7.5, 10, 12.5, 15], [0.4, 0.5, 0.6, 0.7, 0.9]):
|
| 46 |
+
D = int(sr * d_ms / 1000)
|
| 47 |
+
if D >= frame_len:
|
| 48 |
+
continue
|
| 49 |
+
out = frame.copy()
|
| 50 |
+
out[:-D] += decay * frame[D:]
|
| 51 |
+
distortions[f"Comb_{int(d_ms)}ms"] = out
|
| 52 |
+
|
| 53 |
+
if ("tremolo" in distortion_keys) or distortion_keys == "all":
|
| 54 |
+
depth = 1.0
|
| 55 |
+
t_centre = (frame_start + 0.5 * len(frame)) / sr
|
| 56 |
+
for r_hz in [1, 2, 4, 6]:
|
| 57 |
+
mod = (1 - depth) + depth * 0.5 * (1 + np.sin(2 * np.pi * r_hz * t_centre))
|
| 58 |
+
distortions[f"Tremolo_{r_hz}Hz"] = frame * mod
|
| 59 |
+
|
| 60 |
+
if ("noise" in distortion_keys) or distortion_keys == "all":
|
| 61 |
+
nyq = sr / 2
|
| 62 |
+
low_norm = 20 / nyq
|
| 63 |
+
high_freq = min(20_000, 0.45 * sr)
|
| 64 |
+
high_norm = min(high_freq / nyq, 0.99)
|
| 65 |
+
b_band, a_band = butter(5, [low_norm, high_norm], btype="band")
|
| 66 |
+
|
| 67 |
+
def add_noise(sig, snr_db, color="white"):
|
| 68 |
+
nl_target = 10 ** (snr_db / 10)
|
| 69 |
+
n = np.random.randn(len(sig))
|
| 70 |
+
if color == "pink":
|
| 71 |
+
n = np.cumsum(n)
|
| 72 |
+
n /= max(np.max(np.abs(n)), 1e-12)
|
| 73 |
+
elif color == "brown":
|
| 74 |
+
n = np.cumsum(np.cumsum(n))
|
| 75 |
+
n /= max(np.max(np.abs(n)), 1e-12)
|
| 76 |
+
n = lfilter(b_band, a_band, n)
|
| 77 |
+
rms_sig = np.sqrt(np.mean(sig**2))
|
| 78 |
+
rms_n = np.sqrt(np.mean(n**2)) + 1e-12
|
| 79 |
+
noise_rms = rms_sig / np.sqrt(nl_target)
|
| 80 |
+
noise_rms = max(noise_rms, rms_sig / np.sqrt(10 ** (15 / 10)))
|
| 81 |
+
n *= noise_rms / rms_n
|
| 82 |
+
return sig + n
|
| 83 |
+
|
| 84 |
+
for snr in [-15, -10, -5, 0, 5, 10, 15, 20, 25]:
|
| 85 |
+
for clr in ["white", "pink", "brown"]:
|
| 86 |
+
if (snr in [-15, -10, -5]) and (clr == "white"):
|
| 87 |
+
continue
|
| 88 |
+
distortions[f"{clr.capitalize()}Noise_{snr}dB"] = add_noise(
|
| 89 |
+
frame, snr, clr
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
if ("harmonic" in distortion_keys) or distortion_keys == "all":
|
| 93 |
+
for f_h, rel_amp in zip([100, 500, 1000, 4000], [0.4, 0.6, 0.8, 1.0]):
|
| 94 |
+
tone = (rel_amp * A_rms) * np.sin(2 * np.pi * f_h * t)
|
| 95 |
+
distortions[f"Harmonic_{f_h}Hz"] = frame + tone
|
| 96 |
+
|
| 97 |
+
if ("reverb" in distortion_keys) or distortion_keys == "all":
|
| 98 |
+
for tail_ms, decay in zip([50, 100, 200, 400], [0.3, 0.5, 0.7, 0.9]):
|
| 99 |
+
L = int(sr * tail_ms / 1000)
|
| 100 |
+
if L >= frame_len:
|
| 101 |
+
continue
|
| 102 |
+
irv = np.exp(-np.linspace(0, 6, L)) * decay
|
| 103 |
+
reverbed = np.convolve(frame, irv)[:frame_len]
|
| 104 |
+
distortions[f"Reverb_{tail_ms}ms"] = reverbed
|
| 105 |
+
|
| 106 |
+
if ("noisegate" in distortion_keys) or distortion_keys == "all":
|
| 107 |
+
for pct in [0.05, 0.10, 0.20, 0.40]:
|
| 108 |
+
thr = pct * A_95
|
| 109 |
+
g = frame.copy()
|
| 110 |
+
g[np.abs(g) < thr] = 0
|
| 111 |
+
distortions[f"NoiseGate_{int(pct * 100)}pct"] = g
|
| 112 |
+
|
| 113 |
+
if ("pitch_shift" in distortion_keys) or distortion_keys == "all":
|
| 114 |
+
n_fft = min(2048, frame_len // 2)
|
| 115 |
+
for shift in [-4, -2, 2, 4]:
|
| 116 |
+
y = librosa.effects.pitch_shift(frame, sr=sr, n_steps=shift, n_fft=n_fft)
|
| 117 |
+
distortions[f"PitchShift_{shift}st"] = y[:frame_len]
|
| 118 |
+
|
| 119 |
+
if ("lowpass" in distortion_keys) or distortion_keys == "all":
|
| 120 |
+
for fc in low_cutoffs:
|
| 121 |
+
if fc >= sr / 2 * 0.99:
|
| 122 |
+
continue
|
| 123 |
+
b, a = butter(6, fc / (sr / 2), btype="low")
|
| 124 |
+
distortions[f"Lowpass_{fc}Hz"] = filtfilt(b, a, frame)
|
| 125 |
+
|
| 126 |
+
if ("highpass" in distortion_keys) or distortion_keys == "all":
|
| 127 |
+
for fc in high_cutoffs:
|
| 128 |
+
if fc <= 20:
|
| 129 |
+
continue
|
| 130 |
+
b, a = butter(6, fc / (sr / 2), btype="high")
|
| 131 |
+
distortions[f"Highpass_{fc}Hz"] = filtfilt(b, a, frame)
|
| 132 |
+
|
| 133 |
+
if ("echo" in distortion_keys) or distortion_keys == "all":
|
| 134 |
+
for delay_ms, amp in zip([50, 100, 150], [0.4, 0.5, 0.7]):
|
| 135 |
+
D = int(sr * delay_ms / 1000)
|
| 136 |
+
if D >= frame_len:
|
| 137 |
+
continue
|
| 138 |
+
echo = np.pad(frame, (D, 0), "constant")[:-D] * amp
|
| 139 |
+
distortions[f"Echo_{delay_ms}ms"] = frame + echo
|
| 140 |
+
|
| 141 |
+
if ("clipping" in distortion_keys) or distortion_keys == "all":
|
| 142 |
+
for frac in [0.70, 0.50, 0.30]:
|
| 143 |
+
thr = frac * A_95
|
| 144 |
+
distortions[f"Clipping_{frac:.2f}p95"] = np.clip(frame, -thr, thr)
|
| 145 |
+
|
| 146 |
+
if ("vibrato" in distortion_keys) or distortion_keys == "all":
|
| 147 |
+
n_fft = min(2048, frame_len // 2)
|
| 148 |
+
base_depth = 0.03 * (A_rms / A_pk)
|
| 149 |
+
for rate_hz, scale in zip([3, 5, 7], [1.0, 1.3, 1.6]):
|
| 150 |
+
depth = np.clip(base_depth * scale, 0.01, 0.05)
|
| 151 |
+
y = librosa.effects.time_stretch(frame, rate=1 + depth, n_fft=n_fft)
|
| 152 |
+
distortions[f"Vibrato_{rate_hz}Hz"] = librosa.util.fix_length(
|
| 153 |
+
y, size=frame_len
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
return distortions
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def apply_adv_distortions(ref, distortion_keys, sr=SR):
|
| 160 |
+
frame_len = int(ENERGY_WIN_MS * sr / 1000)
|
| 161 |
+
n_frames = int(np.ceil(len(ref) / frame_len))
|
| 162 |
+
pad_len = n_frames * frame_len - len(ref)
|
| 163 |
+
ref_padded = (
|
| 164 |
+
np.concatenate([ref, np.zeros(pad_len, dtype=ref.dtype)]) if pad_len else ref
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
X_full = rfft(ref_padded)
|
| 168 |
+
freqs_f = rfftfreq(len(ref_padded), 1 / sr)
|
| 169 |
+
mag_full = np.abs(X_full)
|
| 170 |
+
|
| 171 |
+
valid = (freqs_f > 80) & (freqs_f < 0.45 * sr)
|
| 172 |
+
cand_indices = np.argsort(mag_full[valid])[-60:]
|
| 173 |
+
cand_freqs = freqs_f[valid][cand_indices]
|
| 174 |
+
cand_freqs = cand_freqs[np.argsort(mag_full[valid][cand_indices])[::-1]]
|
| 175 |
+
|
| 176 |
+
selected_notch_freqs = []
|
| 177 |
+
for f0 in cand_freqs:
|
| 178 |
+
if all(abs(f0 - f_sel) > 300 for f_sel in selected_notch_freqs):
|
| 179 |
+
selected_notch_freqs.append(float(f0))
|
| 180 |
+
if len(selected_notch_freqs) >= 20:
|
| 181 |
+
break
|
| 182 |
+
|
| 183 |
+
mag2 = np.abs(X_full) ** 2
|
| 184 |
+
total_p = mag2.sum()
|
| 185 |
+
cum_low = np.cumsum(mag2)
|
| 186 |
+
q_low = [0.50, 0.70, 0.85, 0.95]
|
| 187 |
+
lowpass_cutoffs = []
|
| 188 |
+
for q in q_low:
|
| 189 |
+
idx = np.searchsorted(cum_low, q * total_p)
|
| 190 |
+
f_c = float(freqs_f[idx])
|
| 191 |
+
lowpass_cutoffs.append(round(f_c / 100.0) * 100)
|
| 192 |
+
|
| 193 |
+
cum_high = np.cumsum(mag2[::-1])
|
| 194 |
+
q_high = [0.05, 0.15, 0.30, 0.50]
|
| 195 |
+
highpass_cutoffs = []
|
| 196 |
+
for q in q_high:
|
| 197 |
+
idx = np.searchsorted(cum_high, q * total_p)
|
| 198 |
+
f_c = float(freqs_f[-1 - idx])
|
| 199 |
+
highpass_cutoffs.append(round(f_c / 100.0) * 100)
|
| 200 |
+
|
| 201 |
+
lowpass_cutoffs = sorted(set(lowpass_cutoffs))
|
| 202 |
+
highpass_cutoffs = sorted(set(highpass_cutoffs))
|
| 203 |
+
|
| 204 |
+
out = {}
|
| 205 |
+
for f in range(n_frames):
|
| 206 |
+
start, end = f * frame_len, (f + 1) * frame_len
|
| 207 |
+
frame = ref_padded[start:end]
|
| 208 |
+
frame_dists = frame_distortions(
|
| 209 |
+
frame,
|
| 210 |
+
sr,
|
| 211 |
+
distortion_keys,
|
| 212 |
+
notch_freqs=selected_notch_freqs,
|
| 213 |
+
low_cutoffs=lowpass_cutoffs,
|
| 214 |
+
high_cutoffs=highpass_cutoffs,
|
| 215 |
+
frame_start=start,
|
| 216 |
+
)
|
| 217 |
+
for lbl, sig in frame_dists.items():
|
| 218 |
+
if lbl not in out:
|
| 219 |
+
out[lbl] = np.zeros_like(ref_padded)
|
| 220 |
+
out[lbl][start:end] = sig
|
| 221 |
+
|
| 222 |
+
return list(out.values())
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def apply_distortions(ref, distortion_keys, sr=SR):
|
| 226 |
+
distortions = {}
|
| 227 |
+
X = rfft(ref)
|
| 228 |
+
freqs = rfftfreq(len(ref), 1 / sr)
|
| 229 |
+
t = np.arange(len(ref)) / sr
|
| 230 |
+
|
| 231 |
+
if ("notch" in distortion_keys) or distortion_keys == "all":
|
| 232 |
+
for c in [500, 1000, 2000, 4000, 8000]:
|
| 233 |
+
Y = X.copy()
|
| 234 |
+
Y[(freqs > c - 50) & (freqs < c + 50)] = 0
|
| 235 |
+
distortions[f"Notch_{c}Hz"] = irfft(Y, n=len(ref))
|
| 236 |
+
|
| 237 |
+
if ("comb" in distortion_keys) or distortion_keys == "all":
|
| 238 |
+
for d, decay in zip([2.5, 5, 7.5, 10, 12.5, 15], [0.4, 0.5, 0.6, 0.7, 0.9]):
|
| 239 |
+
D = int(sr * d / 1000)
|
| 240 |
+
if D >= len(ref):
|
| 241 |
+
continue
|
| 242 |
+
cpy = ref.copy()
|
| 243 |
+
if len(ref) > D:
|
| 244 |
+
cpy[:-D] += decay * ref[D:]
|
| 245 |
+
distortions[f"Comb_{int(d)}ms"] = cpy
|
| 246 |
+
|
| 247 |
+
if ("tremolo" in distortion_keys) or distortion_keys == "all":
|
| 248 |
+
for r, depth in zip([1, 2, 4, 6], [0.3, 0.5, 0.8, 1.0]):
|
| 249 |
+
mod = (1 - depth) + depth * 0.5 * (1 + np.sin(2 * np.pi * r * t))
|
| 250 |
+
distortions[f"Tremolo_{r}Hz"] = ref * mod
|
| 251 |
+
|
| 252 |
+
if ("noise" in distortion_keys) or distortion_keys == "all":
|
| 253 |
+
|
| 254 |
+
def add_noise(signal, snr_db, color):
|
| 255 |
+
rms = np.sqrt(np.mean(signal**2))
|
| 256 |
+
nl = 10 ** (snr_db / 10)
|
| 257 |
+
noise_rms = rms / np.sqrt(nl)
|
| 258 |
+
n = np.random.randn(len(signal))
|
| 259 |
+
if color == "pink":
|
| 260 |
+
n = np.cumsum(n)
|
| 261 |
+
n /= max(np.max(np.abs(n)), 1e-12)
|
| 262 |
+
elif color == "brown":
|
| 263 |
+
n = np.cumsum(np.cumsum(n))
|
| 264 |
+
n /= max(np.max(np.abs(n)), 1e-12)
|
| 265 |
+
return signal + noise_rms * n
|
| 266 |
+
|
| 267 |
+
for snr in [-15, -10, -5, 0, 5, 10, 15, 20, 25]:
|
| 268 |
+
for clr in ["white", "pink", "brown"]:
|
| 269 |
+
if snr in [-15, -10, -5] and clr in ["white"]:
|
| 270 |
+
continue
|
| 271 |
+
distortions[f"{clr.capitalize()}Noise_{snr}dB"] = add_noise(
|
| 272 |
+
ref, snr, clr
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
if ("harmonic" in distortion_keys) or distortion_keys == "all":
|
| 276 |
+
for f_h, amp in zip([100, 500, 1000, 4000], [0.02, 0.03, 0.05, 0.08]):
|
| 277 |
+
tone = amp * np.sin(2 * np.pi * f_h * t)
|
| 278 |
+
distortions[f"Harmonic_{f_h}Hz"] = ref + tone
|
| 279 |
+
|
| 280 |
+
if ("reverb" in distortion_keys) or distortion_keys == "all":
|
| 281 |
+
for tail_ms, decay in zip([5, 10, 15, 20], [0.3, 0.5, 0.7, 0.9, 1.1]):
|
| 282 |
+
L = int(sr * tail_ms / 1000)
|
| 283 |
+
if L >= len(ref):
|
| 284 |
+
continue
|
| 285 |
+
irv = np.exp(-np.linspace(0, 3, L)) * decay
|
| 286 |
+
reverbed = np.convolve(ref, irv)[: len(ref)]
|
| 287 |
+
distortions[f"Reverb_{tail_ms}ms"] = reverbed
|
| 288 |
+
|
| 289 |
+
if ("noisegate" in distortion_keys) or distortion_keys == "all":
|
| 290 |
+
for thr in [0.005, 0.01, 0.02, 0.04]:
|
| 291 |
+
g = ref.copy()
|
| 292 |
+
g[np.abs(g) < thr] = 0
|
| 293 |
+
distortions[f"NoiseGate_{thr}"] = g
|
| 294 |
+
|
| 295 |
+
if ("pitch_shift" in distortion_keys) or distortion_keys == "all":
|
| 296 |
+
n_fft = min(2048, len(ref) // 2)
|
| 297 |
+
for shift in [-4, -2, 2, 4]:
|
| 298 |
+
shifted = librosa.effects.pitch_shift(
|
| 299 |
+
y=ref, sr=sr, n_steps=shift, n_fft=n_fft
|
| 300 |
+
)
|
| 301 |
+
distortions[f"PitchShift_{shift}st"] = shifted[: len(ref)]
|
| 302 |
+
|
| 303 |
+
if ("lowpass" in distortion_keys) or distortion_keys == "all":
|
| 304 |
+
for freq in [2000, 3000, 4000, 6000]:
|
| 305 |
+
if freq >= (sr / 2):
|
| 306 |
+
continue
|
| 307 |
+
b, a = butter(4, freq / (sr / 2), "low")
|
| 308 |
+
distortions[f"Lowpass_{freq}Hz"] = filtfilt(b, a, ref)
|
| 309 |
+
|
| 310 |
+
if ("highpass" in distortion_keys) or distortion_keys == "all":
|
| 311 |
+
for freq in [100, 300, 500, 800]:
|
| 312 |
+
if freq >= (sr / 2):
|
| 313 |
+
continue
|
| 314 |
+
b, a = butter(4, freq / (sr / 2), "high")
|
| 315 |
+
distortions[f"Highpass_{freq}Hz"] = filtfilt(b, a, ref)
|
| 316 |
+
|
| 317 |
+
if ("echo" in distortion_keys) or distortion_keys == "all":
|
| 318 |
+
for delay_ms, amp in zip([5, 10, 15, 20], [0.3, 0.5, 0.7]):
|
| 319 |
+
delay = int(sr * delay_ms / 1000)
|
| 320 |
+
if delay >= len(ref):
|
| 321 |
+
continue
|
| 322 |
+
echo = np.pad(ref, (delay, 0), "constant")[:-delay] * amp
|
| 323 |
+
distortions[f"Echo_{delay_ms}ms"] = ref + echo
|
| 324 |
+
|
| 325 |
+
if ("clipping" in distortion_keys) or distortion_keys == "all":
|
| 326 |
+
for thr in [0.3, 0.5, 0.7]:
|
| 327 |
+
distortions[f"Clipping_{thr}"] = np.clip(ref, -thr, thr)
|
| 328 |
+
|
| 329 |
+
if ("vibrato" in distortion_keys) or distortion_keys == "all":
|
| 330 |
+
for rate, depth in zip([3, 5, 7], [0.001, 0.002, 0.003]):
|
| 331 |
+
vibrato = np.sin(2 * np.pi * rate * t) * depth
|
| 332 |
+
vibrato_signal = librosa.effects.time_stretch(
|
| 333 |
+
ref, rate=1 + float(vibrato.mean()), n_fft=min(2048, len(ref) // 2)
|
| 334 |
+
)
|
| 335 |
+
distortions[f"Vibrato_{rate}Hz"] = librosa.util.fix_length(
|
| 336 |
+
vibrato_signal, size=len(ref)
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
return list(distortions.values())
|
engine.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import librosa
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from audio import (
|
| 8 |
+
assign_outputs_to_refs_by_corr,
|
| 9 |
+
loudness_normalize,
|
| 10 |
+
make_union_voiced_mask,
|
| 11 |
+
)
|
| 12 |
+
from config import *
|
| 13 |
+
from distortions import apply_adv_distortions, apply_distortions
|
| 14 |
+
from metrics import (
|
| 15 |
+
compute_pm,
|
| 16 |
+
compute_ps,
|
| 17 |
+
diffusion_map_torch,
|
| 18 |
+
pm_ci_components_full,
|
| 19 |
+
ps_ci_components_full,
|
| 20 |
+
)
|
| 21 |
+
from models import embed_batch, load_model
|
| 22 |
+
from utils import *
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def compute_mapss_measures(
|
| 26 |
+
models,
|
| 27 |
+
mixtures,
|
| 28 |
+
*,
|
| 29 |
+
systems=None,
|
| 30 |
+
algos=None,
|
| 31 |
+
experiment_id=None,
|
| 32 |
+
layer=DEFAULT_LAYER,
|
| 33 |
+
add_ci=DEFAULT_ADD_CI,
|
| 34 |
+
alpha=DEFAULT_ALPHA,
|
| 35 |
+
seed=42,
|
| 36 |
+
on_missing="skip",
|
| 37 |
+
verbose=False,
|
| 38 |
+
max_gpus=None,
|
| 39 |
+
):
|
| 40 |
+
gpu_distributor = GPUWorkDistributor(max_gpus)
|
| 41 |
+
ngpu = get_gpu_count(max_gpus)
|
| 42 |
+
|
| 43 |
+
if on_missing not in {"skip", "error"}:
|
| 44 |
+
raise ValueError("on_missing must be 'skip' or 'error'.")
|
| 45 |
+
|
| 46 |
+
torch.manual_seed(seed)
|
| 47 |
+
random.seed(seed)
|
| 48 |
+
np.random.seed(seed)
|
| 49 |
+
if torch.cuda.is_available():
|
| 50 |
+
torch.cuda.manual_seed_all(seed)
|
| 51 |
+
|
| 52 |
+
canon_mix = canonicalize_mixtures(mixtures, systems=systems)
|
| 53 |
+
|
| 54 |
+
mixture_entries = []
|
| 55 |
+
for m in canon_mix:
|
| 56 |
+
entries = []
|
| 57 |
+
for i, refp in enumerate(m.refs):
|
| 58 |
+
sid = m.speaker_ids[i]
|
| 59 |
+
entries.append(
|
| 60 |
+
{"id": sid, "ref": Path(refp), "mixture": m.mixture_id, "outs": {}}
|
| 61 |
+
)
|
| 62 |
+
mixture_entries.append(entries)
|
| 63 |
+
|
| 64 |
+
for m, mix_entries in zip(canon_mix, mixture_entries):
|
| 65 |
+
for algo, out_list in (m.systems or {}).items():
|
| 66 |
+
mapping = assign_outputs_to_refs_by_corr(
|
| 67 |
+
[e["ref"] for e in mix_entries], out_list
|
| 68 |
+
)
|
| 69 |
+
for idx, e in enumerate(mix_entries):
|
| 70 |
+
j = mapping[idx]
|
| 71 |
+
if j is not None:
|
| 72 |
+
e["outs"][algo] = out_list[j]
|
| 73 |
+
|
| 74 |
+
if algos is None:
|
| 75 |
+
algos_to_run = sorted(
|
| 76 |
+
{algo for m in canon_mix for algo in (m.systems or {}).keys()}
|
| 77 |
+
)
|
| 78 |
+
else:
|
| 79 |
+
algos_to_run = list(algos)
|
| 80 |
+
|
| 81 |
+
exp_id = experiment_id or datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 82 |
+
exp_root = os.path.join(RESULTS_ROOT, f"experiment_{exp_id}")
|
| 83 |
+
os.makedirs(exp_root, exist_ok=True)
|
| 84 |
+
|
| 85 |
+
params = {
|
| 86 |
+
"models": models,
|
| 87 |
+
"layer": layer,
|
| 88 |
+
"add_ci": add_ci,
|
| 89 |
+
"alpha": alpha,
|
| 90 |
+
"seed": seed,
|
| 91 |
+
"batch_size": BATCH_SIZE,
|
| 92 |
+
"ngpu": ngpu,
|
| 93 |
+
"max_gpus": max_gpus,
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
with open(os.path.join(exp_root, "params.json"), "w") as f:
|
| 97 |
+
json.dump(params, f, indent=2)
|
| 98 |
+
|
| 99 |
+
canon_struct = [
|
| 100 |
+
{
|
| 101 |
+
"mixture_id": m.mixture_id,
|
| 102 |
+
"references": [str(p) for p in m.refs],
|
| 103 |
+
"systems": {
|
| 104 |
+
a: [str(p) for p in outs] for a, outs in (m.systems or {}).items()
|
| 105 |
+
},
|
| 106 |
+
"speaker_ids": m.speaker_ids,
|
| 107 |
+
}
|
| 108 |
+
for m in canon_mix
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
with open(os.path.join(exp_root, "manifest_canonical.json"), "w") as f:
|
| 112 |
+
json.dump(canon_struct, f, indent=2)
|
| 113 |
+
|
| 114 |
+
print(f"Starting experiment {exp_id} with {ngpu} GPUs")
|
| 115 |
+
print(f"Results will be saved to: {exp_root}")
|
| 116 |
+
|
| 117 |
+
clear_gpu_memory()
|
| 118 |
+
get_gpu_memory_info(verbose)
|
| 119 |
+
|
| 120 |
+
flat_entries = [e for mix in mixture_entries for e in mix]
|
| 121 |
+
all_refs = {}
|
| 122 |
+
|
| 123 |
+
if verbose:
|
| 124 |
+
print("Loading reference signals...")
|
| 125 |
+
for e in flat_entries:
|
| 126 |
+
wav, _ = librosa.load(str(e["ref"]), sr=SR)
|
| 127 |
+
all_refs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
|
| 128 |
+
|
| 129 |
+
if verbose:
|
| 130 |
+
print("Computing voiced masks...")
|
| 131 |
+
|
| 132 |
+
win = int(ENERGY_WIN_MS * SR / 1000)
|
| 133 |
+
hop = int(ENERGY_HOP_MS * SR / 1000)
|
| 134 |
+
voiced_mask_mix = []
|
| 135 |
+
|
| 136 |
+
for i, mix in enumerate(mixture_entries):
|
| 137 |
+
if verbose:
|
| 138 |
+
print(f" Computing mask for mixture {i + 1}/{len(mixture_entries)}")
|
| 139 |
+
|
| 140 |
+
if ngpu > 0:
|
| 141 |
+
with torch.cuda.device(0):
|
| 142 |
+
refs_for_mix = [all_refs[e["id"]].cuda() for e in mix]
|
| 143 |
+
mask = make_union_voiced_mask(refs_for_mix, win, hop)
|
| 144 |
+
voiced_mask_mix.append(mask.cpu())
|
| 145 |
+
# Explicitly delete GPU tensors
|
| 146 |
+
for ref in refs_for_mix:
|
| 147 |
+
del ref
|
| 148 |
+
torch.cuda.empty_cache()
|
| 149 |
+
else:
|
| 150 |
+
refs_for_mix = [all_refs[e["id"]].cpu() for e in mix]
|
| 151 |
+
mask = make_union_voiced_mask(refs_for_mix, win, hop)
|
| 152 |
+
voiced_mask_mix.append(mask.cpu())
|
| 153 |
+
|
| 154 |
+
ordered_speakers = [e["id"] for e in flat_entries]
|
| 155 |
+
|
| 156 |
+
for algo_idx, algo in enumerate(algos_to_run):
|
| 157 |
+
if verbose:
|
| 158 |
+
print(f"\nProcessing Algorithm {algo_idx + 1}/{len(algos_to_run)}: {algo}")
|
| 159 |
+
|
| 160 |
+
algo_dir = os.path.join(exp_root, algo)
|
| 161 |
+
os.makedirs(algo_dir, exist_ok=True)
|
| 162 |
+
|
| 163 |
+
all_outs = {}
|
| 164 |
+
missing = []
|
| 165 |
+
|
| 166 |
+
for mix_idx, mix in enumerate(mixture_entries):
|
| 167 |
+
for e in mix:
|
| 168 |
+
assigned_path = e.get("outs", {}).get(algo)
|
| 169 |
+
if assigned_path is None:
|
| 170 |
+
missing.append((e["mixture"], e["id"]))
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
wav, _ = librosa.load(str(assigned_path), sr=SR)
|
| 174 |
+
all_outs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
|
| 175 |
+
|
| 176 |
+
if missing:
|
| 177 |
+
msg = f"[{algo}] missing outputs for {len(missing)} speaker(s)"
|
| 178 |
+
if on_missing == "error":
|
| 179 |
+
raise FileNotFoundError(msg)
|
| 180 |
+
else:
|
| 181 |
+
if verbose:
|
| 182 |
+
warnings.warn(msg + " Skipping those speakers.")
|
| 183 |
+
|
| 184 |
+
if not all_outs:
|
| 185 |
+
if verbose:
|
| 186 |
+
warnings.warn(f"[{algo}] No outputs provided. Skipping algorithm.")
|
| 187 |
+
continue
|
| 188 |
+
|
| 189 |
+
ps_ts = {m: {s: [] for s in ordered_speakers} for m in models}
|
| 190 |
+
pm_ts = {m: {s: [] for s in ordered_speakers} for m in models}
|
| 191 |
+
ps_bias_ts = {m: {s: [] for s in ordered_speakers} for m in models}
|
| 192 |
+
ps_prob_ts = {m: {s: [] for s in ordered_speakers} for m in models}
|
| 193 |
+
pm_bias_ts = {m: {s: [] for s in ordered_speakers} for m in models}
|
| 194 |
+
pm_prob_ts = {m: {s: [] for s in ordered_speakers} for m in models}
|
| 195 |
+
|
| 196 |
+
for model_idx, mname in enumerate(models):
|
| 197 |
+
if verbose:
|
| 198 |
+
print(f" Processing Model {model_idx + 1}/{len(models)}: {mname}")
|
| 199 |
+
|
| 200 |
+
for metric_type in ["PS", "PM"]:
|
| 201 |
+
clear_gpu_memory()
|
| 202 |
+
gc.collect()
|
| 203 |
+
|
| 204 |
+
model_wrapper, layer_eff = load_model(mname, layer, max_gpus)
|
| 205 |
+
get_gpu_memory_info(verbose)
|
| 206 |
+
|
| 207 |
+
embs_by_mix = {}
|
| 208 |
+
labels_by_mix = {}
|
| 209 |
+
|
| 210 |
+
for k, mix in enumerate(mixture_entries):
|
| 211 |
+
speakers_this_mix = [e for e in mix if e["id"] in all_outs]
|
| 212 |
+
if not speakers_this_mix:
|
| 213 |
+
continue
|
| 214 |
+
|
| 215 |
+
if verbose:
|
| 216 |
+
print(
|
| 217 |
+
f"Processing mixture {k + 1}/{len(mixture_entries)} for {metric_type}"
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
all_signals_mix = []
|
| 221 |
+
all_masks_mix = []
|
| 222 |
+
all_labels_mix = []
|
| 223 |
+
|
| 224 |
+
for e in speakers_this_mix:
|
| 225 |
+
s = e["id"]
|
| 226 |
+
|
| 227 |
+
if metric_type == "PS":
|
| 228 |
+
dists = [
|
| 229 |
+
loudness_normalize(d)
|
| 230 |
+
for d in apply_distortions(all_refs[s].numpy(), "all")
|
| 231 |
+
]
|
| 232 |
+
else:
|
| 233 |
+
dists = [
|
| 234 |
+
loudness_normalize(d)
|
| 235 |
+
for d in apply_adv_distortions(
|
| 236 |
+
all_refs[s].numpy(), "all"
|
| 237 |
+
)
|
| 238 |
+
]
|
| 239 |
+
|
| 240 |
+
sigs = [all_refs[s].numpy(), all_outs[s].numpy()] + dists
|
| 241 |
+
lbls = ["ref", "out"] + [f"d{i}" for i in range(len(dists))]
|
| 242 |
+
|
| 243 |
+
masks = [voiced_mask_mix[k]] * len(sigs)
|
| 244 |
+
all_signals_mix.extend(sigs)
|
| 245 |
+
all_masks_mix.extend(masks)
|
| 246 |
+
all_labels_mix.extend([f"{s}-{l}" for l in lbls])
|
| 247 |
+
|
| 248 |
+
try:
|
| 249 |
+
# Process in smaller batches
|
| 250 |
+
batch_size = min(2, BATCH_SIZE)
|
| 251 |
+
embeddings_list = []
|
| 252 |
+
|
| 253 |
+
for i in range(0, len(all_signals_mix), batch_size):
|
| 254 |
+
batch_sigs = all_signals_mix[i:i + batch_size]
|
| 255 |
+
batch_masks = all_masks_mix[i:i + batch_size]
|
| 256 |
+
|
| 257 |
+
batch_embs = embed_batch(
|
| 258 |
+
batch_sigs,
|
| 259 |
+
batch_masks,
|
| 260 |
+
model_wrapper,
|
| 261 |
+
layer_eff,
|
| 262 |
+
use_mlm=False,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if batch_embs.numel() > 0:
|
| 266 |
+
embeddings_list.append(batch_embs.cpu())
|
| 267 |
+
|
| 268 |
+
torch.cuda.empty_cache()
|
| 269 |
+
|
| 270 |
+
if embeddings_list:
|
| 271 |
+
embeddings = torch.cat(embeddings_list, dim=0)
|
| 272 |
+
embs_by_mix[k] = embeddings
|
| 273 |
+
labels_by_mix[k] = all_labels_mix
|
| 274 |
+
|
| 275 |
+
except Exception as ex:
|
| 276 |
+
if verbose:
|
| 277 |
+
print(f" ERROR processing mixture {k + 1}: {ex}")
|
| 278 |
+
continue
|
| 279 |
+
finally:
|
| 280 |
+
# Always clean up after processing a mixture
|
| 281 |
+
del all_signals_mix, all_masks_mix
|
| 282 |
+
if 'embeddings_list' in locals():
|
| 283 |
+
del embeddings_list
|
| 284 |
+
clear_gpu_memory()
|
| 285 |
+
gc.collect()
|
| 286 |
+
|
| 287 |
+
if verbose:
|
| 288 |
+
print(f" Computing {metric_type} scores for {mname}...")
|
| 289 |
+
|
| 290 |
+
# Process mixtures with their stored embeddings and labels
|
| 291 |
+
with ThreadPoolExecutor(
|
| 292 |
+
max_workers=min(2, ngpu if ngpu > 0 else 1)
|
| 293 |
+
) as executor:
|
| 294 |
+
for k in range(len(mixture_entries)):
|
| 295 |
+
if k not in embs_by_mix:
|
| 296 |
+
continue
|
| 297 |
+
|
| 298 |
+
E, L, D = embs_by_mix[k].shape
|
| 299 |
+
if L == 0:
|
| 300 |
+
if verbose:
|
| 301 |
+
print(f" WARNING: mixture {k + 1} produced 0 frames after masking; skipping.")
|
| 302 |
+
continue
|
| 303 |
+
|
| 304 |
+
# Get the labels for this mixture
|
| 305 |
+
labels_for_mix = labels_by_mix[k]
|
| 306 |
+
|
| 307 |
+
def process_frame(f, embeddings_mix, labels_mix):
|
| 308 |
+
try:
|
| 309 |
+
frame_emb = embeddings_mix[:, f, :].detach().cpu().numpy()
|
| 310 |
+
|
| 311 |
+
if add_ci:
|
| 312 |
+
coords_d, coords_c, eigvals, k_sub_gauss = (
|
| 313 |
+
gpu_distributor.execute_on_gpu(
|
| 314 |
+
diffusion_map_torch,
|
| 315 |
+
frame_emb,
|
| 316 |
+
labels_mix,
|
| 317 |
+
alpha=alpha,
|
| 318 |
+
eig_solver="full",
|
| 319 |
+
return_eigs=True,
|
| 320 |
+
return_complement=True,
|
| 321 |
+
return_cval=add_ci,
|
| 322 |
+
)
|
| 323 |
+
)
|
| 324 |
+
else:
|
| 325 |
+
coords_d = gpu_distributor.execute_on_gpu(
|
| 326 |
+
diffusion_map_torch,
|
| 327 |
+
frame_emb,
|
| 328 |
+
labels_mix,
|
| 329 |
+
alpha=alpha,
|
| 330 |
+
eig_solver="full",
|
| 331 |
+
return_eigs=False,
|
| 332 |
+
return_complement=False,
|
| 333 |
+
return_cval=False,
|
| 334 |
+
)
|
| 335 |
+
coords_c = None
|
| 336 |
+
eigvals = None
|
| 337 |
+
k_sub_gauss = 1
|
| 338 |
+
|
| 339 |
+
if metric_type == "PS":
|
| 340 |
+
score = compute_ps(
|
| 341 |
+
coords_d, labels_mix, max_gpus
|
| 342 |
+
)
|
| 343 |
+
bias = prob = None
|
| 344 |
+
if add_ci:
|
| 345 |
+
bias, prob = ps_ci_components_full(
|
| 346 |
+
coords_d,
|
| 347 |
+
coords_c,
|
| 348 |
+
eigvals,
|
| 349 |
+
labels_mix,
|
| 350 |
+
delta=DEFAULT_DELTA_CI,
|
| 351 |
+
)
|
| 352 |
+
return f, "PS", score, bias, prob
|
| 353 |
+
else:
|
| 354 |
+
score = compute_pm(
|
| 355 |
+
coords_d, labels_mix, "gamma", max_gpus
|
| 356 |
+
)
|
| 357 |
+
bias = prob = None
|
| 358 |
+
if add_ci:
|
| 359 |
+
bias, prob = pm_ci_components_full(
|
| 360 |
+
coords_d,
|
| 361 |
+
coords_c,
|
| 362 |
+
eigvals,
|
| 363 |
+
labels_mix,
|
| 364 |
+
delta=DEFAULT_DELTA_CI,
|
| 365 |
+
K=k_sub_gauss,
|
| 366 |
+
)
|
| 367 |
+
return f, "PM", score, bias, prob
|
| 368 |
+
|
| 369 |
+
except Exception as ex:
|
| 370 |
+
if verbose:
|
| 371 |
+
print(f" ERROR frame {f + 1}: {ex}")
|
| 372 |
+
return None
|
| 373 |
+
|
| 374 |
+
futures = [
|
| 375 |
+
executor.submit(process_frame, f, embs_by_mix[k], labels_for_mix)
|
| 376 |
+
for f in range(L)
|
| 377 |
+
]
|
| 378 |
+
for fut in futures:
|
| 379 |
+
result = fut.result()
|
| 380 |
+
if result is None:
|
| 381 |
+
continue
|
| 382 |
+
|
| 383 |
+
f, metric, score, bias, prob = result
|
| 384 |
+
|
| 385 |
+
if metric == "PS":
|
| 386 |
+
for sp in score:
|
| 387 |
+
ps_ts[mname][sp].append(score[sp])
|
| 388 |
+
if add_ci and bias is not None:
|
| 389 |
+
ps_bias_ts[mname][sp].append(bias[sp])
|
| 390 |
+
ps_prob_ts[mname][sp].append(prob[sp])
|
| 391 |
+
else:
|
| 392 |
+
for sp in score:
|
| 393 |
+
pm_ts[mname][sp].append(score[sp])
|
| 394 |
+
if add_ci and bias is not None:
|
| 395 |
+
pm_bias_ts[mname][sp].append(bias[sp])
|
| 396 |
+
pm_prob_ts[mname][sp].append(prob[sp])
|
| 397 |
+
|
| 398 |
+
# Clean up after processing all mixtures for this metric
|
| 399 |
+
del embs_by_mix, labels_by_mix
|
| 400 |
+
clear_gpu_memory()
|
| 401 |
+
gc.collect()
|
| 402 |
+
|
| 403 |
+
del model_wrapper
|
| 404 |
+
clear_gpu_memory()
|
| 405 |
+
gc.collect()
|
| 406 |
+
|
| 407 |
+
if verbose:
|
| 408 |
+
print(f" Saving results for {algo}...")
|
| 409 |
+
|
| 410 |
+
for m in models:
|
| 411 |
+
|
| 412 |
+
def _pad(vec, n):
|
| 413 |
+
return vec + [np.nan] * (n - len(vec))
|
| 414 |
+
|
| 415 |
+
max_len = 0
|
| 416 |
+
for s in ordered_speakers:
|
| 417 |
+
max_len = max(max_len, len(ps_ts[m][s]), len(pm_ts[m][s]))
|
| 418 |
+
|
| 419 |
+
pd.DataFrame(
|
| 420 |
+
{s: _pad(ps_ts[m][s], max_len) for s in ordered_speakers}
|
| 421 |
+
).to_csv(os.path.join(algo_dir, f"ps_scores_{m}.csv"), index=False)
|
| 422 |
+
|
| 423 |
+
pd.DataFrame(
|
| 424 |
+
{s: _pad(pm_ts[m][s], max_len) for s in ordered_speakers}
|
| 425 |
+
).to_csv(os.path.join(algo_dir, f"pm_scores_{m}.csv"), index=False)
|
| 426 |
+
|
| 427 |
+
if add_ci:
|
| 428 |
+
ci_cols = {}
|
| 429 |
+
for s in ordered_speakers:
|
| 430 |
+
ci_cols[f"{s}_ps_bias"] = _pad(ps_bias_ts[m][s], max_len)
|
| 431 |
+
ci_cols[f"{s}_ps_prob"] = _pad(ps_prob_ts[m][s], max_len)
|
| 432 |
+
ci_cols[f"{s}_pm_bias"] = _pad(pm_bias_ts[m][s], max_len)
|
| 433 |
+
ci_cols[f"{s}_pm_prob"] = _pad(pm_prob_ts[m][s], max_len)
|
| 434 |
+
pd.DataFrame(ci_cols).to_csv(
|
| 435 |
+
os.path.join(algo_dir, f"ci_{m}.csv"), index=False
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
del all_outs
|
| 439 |
+
clear_gpu_memory()
|
| 440 |
+
gc.collect()
|
| 441 |
+
|
| 442 |
+
print(f"\nEXPERIMENT COMPLETED")
|
| 443 |
+
print(f"Results saved to: {exp_root}")
|
| 444 |
+
|
| 445 |
+
del all_refs, voiced_mask_mix
|
| 446 |
+
|
| 447 |
+
# Import and call the cleanup function
|
| 448 |
+
from models import cleanup_all_models
|
| 449 |
+
cleanup_all_models()
|
| 450 |
+
|
| 451 |
+
clear_gpu_memory()
|
| 452 |
+
get_gpu_memory_info(verbose)
|
| 453 |
+
gc.collect()
|
| 454 |
+
|
| 455 |
+
return exp_root
|
hf_readme.md
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: MAPSS Multi Source Audio Perceptual Separation Scores
|
| 3 |
+
emoji: 🎵
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.0.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# MAPSS: Multi-source Audio Perceptual Separation Scores
|
| 14 |
+
|
| 15 |
+
Evaluate audio source separation quality using Perceptual Similarity (PS) and Perceptual Matching (PM) metrics.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- **Perceptual Similarity (PS)**: Measures how similar separated outputs are to reference sources in perceptual embedding space
|
| 20 |
+
- **Perceptual Matching (PM)**: Evaluates robustness against a comprehensive set of audio distortions
|
| 21 |
+
- **Multiple embedding models**: Support for WavLM, Wav2Vec2, HuBERT, AST, and more
|
| 22 |
+
- **Automatic output-to-reference matching**: Uses correlation-based Hungarian algorithm
|
| 23 |
+
- **GPU-optimized processing**: Efficient batch processing with memory management
|
| 24 |
+
- **Diffusion maps**: Advanced dimensionality reduction for perceptual space analysis
|
| 25 |
+
|
| 26 |
+
## Input Format
|
| 27 |
+
|
| 28 |
+
Upload a ZIP file containing:
|
| 29 |
+
```
|
| 30 |
+
your_mixture.zip
|
| 31 |
+
├── references/ # Original clean sources
|
| 32 |
+
│ ├── speaker1.wav
|
| 33 |
+
│ ├── speaker2.wav
|
| 34 |
+
│ └── ...
|
| 35 |
+
└── outputs/ # Separated outputs from your algorithm
|
| 36 |
+
├── separated1.wav
|
| 37 |
+
├── separated2.wav
|
| 38 |
+
└── ...
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### Audio Requirements
|
| 42 |
+
- Format: WAV files
|
| 43 |
+
- Sample rate: Any (automatically resampled to 16kHz)
|
| 44 |
+
- Channels: Mono or stereo (converted to mono)
|
| 45 |
+
- Number of files: Equal number of references and outputs
|
| 46 |
+
|
| 47 |
+
## Output Format
|
| 48 |
+
|
| 49 |
+
The tool generates a ZIP file containing:
|
| 50 |
+
- `ps_scores_{model}.csv`: PS scores for each speaker/source (0-1, higher is better)
|
| 51 |
+
- `pm_scores_{model}.csv`: PM scores for each speaker/source (0-1, higher is better)
|
| 52 |
+
- `params.json`: Experiment parameters used
|
| 53 |
+
- `manifest_canonical.json`: File mapping and processing details
|
| 54 |
+
|
| 55 |
+
### Score Interpretation
|
| 56 |
+
- **PS Score**: Perceptual Similarity
|
| 57 |
+
- 1.0 = Perfect separation (output identical to reference)
|
| 58 |
+
- 0.5 = Moderate separation quality
|
| 59 |
+
- 0.0 = Poor separation (output closer to other sources)
|
| 60 |
+
|
| 61 |
+
- **PM Score**: Perceptual Matching (robustness)
|
| 62 |
+
- 1.0 = Highly robust to distortions
|
| 63 |
+
- 0.5 = Moderate robustness
|
| 64 |
+
- 0.0 = Not robust (easily confused with distorted versions)
|
| 65 |
+
|
| 66 |
+
## Available Models
|
| 67 |
+
|
| 68 |
+
| Model | Description | Default Layer | Use Case |
|
| 69 |
+
|-------|-------------|---------------|----------|
|
| 70 |
+
| `raw` | Raw waveform features | N/A | Baseline comparison |
|
| 71 |
+
| `wavlm` | WavLM Large | 24 | Best overall performance |
|
| 72 |
+
| `wav2vec2` | Wav2Vec2 Large | 24 | Strong performance |
|
| 73 |
+
| `hubert` | HuBERT Large | 24 | Good for speech |
|
| 74 |
+
| `wavlm_base` | WavLM Base | 12 | Faster, good quality |
|
| 75 |
+
| `wav2vec2_base` | Wav2Vec2 Base | 12 | Faster processing |
|
| 76 |
+
| `hubert_base` | HuBERT Base | 12 | Faster for speech |
|
| 77 |
+
| `wav2vec2_xlsr` | Wav2Vec2 XLSR-53 | 24 | Multilingual |
|
| 78 |
+
| `ast` | Audio Spectrogram Transformer | 12 | General audio |
|
| 79 |
+
|
| 80 |
+
## Parameters
|
| 81 |
+
|
| 82 |
+
- **Model**: Select the embedding model for feature extraction
|
| 83 |
+
- **Layer**: Which transformer layer to use (auto-selected by default)
|
| 84 |
+
- **Alpha**: Diffusion maps parameter (0.0-1.0, default: 1.0)
|
| 85 |
+
- 0.0 = No normalization
|
| 86 |
+
- 1.0 = Full normalization (recommended)
|
| 87 |
+
|
| 88 |
+
## How It Works
|
| 89 |
+
|
| 90 |
+
1. **Feature Extraction**: Audio signals are processed through pre-trained self-supervised models to extract perceptual embeddings
|
| 91 |
+
2. **Voice Activity Detection**: Automatic detection of voiced segments using energy-based masking
|
| 92 |
+
3. **Diffusion Maps**: Embeddings are projected using diffusion maps for robust dimensionality reduction
|
| 93 |
+
4. **PS Computation**: Measures Mahalanobis distance between separated outputs and references vs other sources
|
| 94 |
+
5. **PM Computation**: Evaluates against comprehensive distortions including:
|
| 95 |
+
- Noise (white, pink, brown at various SNRs)
|
| 96 |
+
- Filtering (lowpass, highpass, notch, comb)
|
| 97 |
+
- Effects (reverb, echo, tremolo, vibrato)
|
| 98 |
+
- Distortions (clipping, pitch shift, time stretch)
|
| 99 |
+
6. **Scoring**: Frame-level scores are computed and aggregated
|
| 100 |
+
|
| 101 |
+
## Technical Details
|
| 102 |
+
|
| 103 |
+
- **Loudness normalization**: ITU-R BS.1770 standard (-23 LUFS)
|
| 104 |
+
- **Frame-based processing**: 20ms windows with 20ms hop
|
| 105 |
+
- **Correlation-based assignment**: Hungarian algorithm for optimal matching
|
| 106 |
+
- **Memory optimization**: Batch processing with automatic GPU memory management
|
| 107 |
+
- **Robust statistics**: Covariance regularization and outlier handling
|
| 108 |
+
|
| 109 |
+
## Citation
|
| 110 |
+
|
| 111 |
+
If you use MAPSS in your research, please cite:
|
| 112 |
+
|
| 113 |
+
```bibtex
|
| 114 |
+
@article{mapss2024,
|
| 115 |
+
title={MAPSS: Multi-source Audio Perceptual Separation Scores},
|
| 116 |
+
author={Your Name},
|
| 117 |
+
journal={arXiv preprint},
|
| 118 |
+
year={2024}
|
| 119 |
+
}
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## Limitations
|
| 123 |
+
|
| 124 |
+
- Processing time scales with audio length and model size
|
| 125 |
+
- Memory requirements depend on number of sources and audio length
|
| 126 |
+
- Currently optimized for speech separation (music separation support in development)
|
| 127 |
+
- Maximum recommended sources: 10 per mixture
|
| 128 |
+
|
| 129 |
+
## License
|
| 130 |
+
|
| 131 |
+
Code: MIT License
|
| 132 |
+
Paper: CC-BY-4.0
|
| 133 |
+
|
| 134 |
+
## Support
|
| 135 |
+
|
| 136 |
+
For issues, questions, or contributions, please visit the [GitHub repository](https://github.com/yourusername/mapss).
|
hf_requirements.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies
|
| 2 |
+
gradio>=4.0.0
|
| 3 |
+
torch>=2.0.0
|
| 4 |
+
torchaudio>=2.0.0
|
| 5 |
+
transformers>=4.35.0
|
| 6 |
+
accelerate>=0.24.0
|
| 7 |
+
|
| 8 |
+
# Audio processing
|
| 9 |
+
librosa>=0.10.0
|
| 10 |
+
soundfile>=0.12.0
|
| 11 |
+
pyloudnorm>=0.1.0
|
| 12 |
+
scipy>=1.11.0
|
| 13 |
+
numpy>=1.24.0
|
| 14 |
+
|
| 15 |
+
# Data handling
|
| 16 |
+
pandas>=2.0.0
|
| 17 |
+
|
| 18 |
+
# Model specific
|
| 19 |
+
safetensors>=0.4.0
|
| 20 |
+
sentencepiece>=0.1.99 # For some tokenizers
|
| 21 |
+
|
| 22 |
+
# Optional optimizations
|
| 23 |
+
triton>=2.1.0 # For faster attention if available
|
| 24 |
+
|
| 25 |
+
# Memory management
|
| 26 |
+
psutil>=5.9.0
|
init.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from engine import compute_mapss_measures
|
| 2 |
+
|
| 3 |
+
__version__ = "1.0.0"
|
| 4 |
+
__all__ = ["compute_mapss_measures"]
|
main.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from engine import compute_mapss_measures
|
| 4 |
+
from argshield import _parse_args, _read_manifest, _validate_and_resolve, _validate_gpus
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
args = _parse_args()
|
| 8 |
+
|
| 9 |
+
manifest = _read_manifest(Path(args.manifest))
|
| 10 |
+
layer_final, alpha_final = _validate_and_resolve(args.model, args.layer, args.alpha)
|
| 11 |
+
max_gpus_final = _validate_gpus(args.max_gpus)
|
| 12 |
+
|
| 13 |
+
results_dir = compute_mapss_measures(
|
| 14 |
+
models=[args.model],
|
| 15 |
+
mixtures=manifest,
|
| 16 |
+
verbose=args.verbose,
|
| 17 |
+
max_gpus=max_gpus_final,
|
| 18 |
+
layer=layer_final,
|
| 19 |
+
alpha=alpha_final,
|
| 20 |
+
)
|
| 21 |
+
print(f"Results saved to: {results_dir}")
|
| 22 |
+
|
| 23 |
+
if __name__ == "__main__":
|
| 24 |
+
main()
|
metrics.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from scipy.special import gammaincc
|
| 7 |
+
from scipy.stats import gamma
|
| 8 |
+
|
| 9 |
+
from config import COV_TOL, DEFAULT_DELTA_CI
|
| 10 |
+
from utils import get_gpu_count, mahalanobis_torch, safe_cov_torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def pm_tail_gamma(d_out_sq, sq_dists):
|
| 14 |
+
"""PM tail gamma exactly as original."""
|
| 15 |
+
mu = sq_dists.mean().item()
|
| 16 |
+
var = sq_dists.var(unbiased=True).item()
|
| 17 |
+
if var == 0.0:
|
| 18 |
+
return 1.0
|
| 19 |
+
k = (mu**2) / var
|
| 20 |
+
theta = var / mu
|
| 21 |
+
return float(1.0 - gamma.cdf(d_out_sq, a=k, scale=theta))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def pm_tail_rank(d_out_sq, sq_dists):
|
| 25 |
+
"""PM tail rank exactly as original."""
|
| 26 |
+
rank = int((sq_dists < d_out_sq).sum().item())
|
| 27 |
+
n = sq_dists.numel()
|
| 28 |
+
return 1.0 - (rank + 0.5) / (n + 1.0)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def diffusion_map_torch(
|
| 32 |
+
X_np,
|
| 33 |
+
labels_by_mix,
|
| 34 |
+
*,
|
| 35 |
+
cutoff=0.99,
|
| 36 |
+
tol=1e-3,
|
| 37 |
+
diffusion_time=1,
|
| 38 |
+
alpha=0.0,
|
| 39 |
+
eig_solver="lobpcg",
|
| 40 |
+
k=None,
|
| 41 |
+
device=None,
|
| 42 |
+
return_eigs=False,
|
| 43 |
+
return_complement=False,
|
| 44 |
+
return_cval=False,
|
| 45 |
+
):
|
| 46 |
+
"""Diffusion map computation exactly as original."""
|
| 47 |
+
device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 48 |
+
X = torch.as_tensor(X_np, dtype=torch.float32, device=device)
|
| 49 |
+
N = X.shape[0]
|
| 50 |
+
|
| 51 |
+
if device != "cpu" and torch.cuda.is_available():
|
| 52 |
+
stream = torch.cuda.Stream(device=device)
|
| 53 |
+
ctx_dev = torch.cuda.device(device)
|
| 54 |
+
ctx_stream = torch.cuda.stream(stream)
|
| 55 |
+
else:
|
| 56 |
+
from contextlib import nullcontext
|
| 57 |
+
|
| 58 |
+
stream = None
|
| 59 |
+
ctx_dev = nullcontext()
|
| 60 |
+
ctx_stream = nullcontext()
|
| 61 |
+
|
| 62 |
+
with ctx_dev:
|
| 63 |
+
with ctx_stream:
|
| 64 |
+
if N > 1000:
|
| 65 |
+
chunk = min(500, N)
|
| 66 |
+
D2 = torch.zeros(N, N, device=device)
|
| 67 |
+
for i in range(0, N, chunk):
|
| 68 |
+
ei = min(i + chunk, N)
|
| 69 |
+
for j in range(0, N, chunk):
|
| 70 |
+
ej = min(j + chunk, N)
|
| 71 |
+
D2[i:ei, j:ej] = torch.cdist(X[i:ei], X[j:ej]).pow_(2)
|
| 72 |
+
else:
|
| 73 |
+
D2 = torch.cdist(X, X).pow_(2)
|
| 74 |
+
|
| 75 |
+
i, j = torch.triu_indices(
|
| 76 |
+
N, N, offset=1, device=None if device == "cpu" else device
|
| 77 |
+
)
|
| 78 |
+
eps = torch.median(D2[i, j])
|
| 79 |
+
K = torch.exp(-D2 / (2 * eps))
|
| 80 |
+
d = K.sum(dim=1)
|
| 81 |
+
|
| 82 |
+
if alpha != 0.0:
|
| 83 |
+
d_alpha_inv = d.pow(-alpha)
|
| 84 |
+
K *= d_alpha_inv[:, None] * d_alpha_inv[None, :]
|
| 85 |
+
d = K.sum(dim=1)
|
| 86 |
+
|
| 87 |
+
D_half_inv = torch.diag(torch.rsqrt(d))
|
| 88 |
+
K_sym = D_half_inv @ K @ D_half_inv
|
| 89 |
+
|
| 90 |
+
if eig_solver == "lobpcg":
|
| 91 |
+
m = k if k is not None else min(N - 1, 50)
|
| 92 |
+
init = torch.randn(N, m, device=device)
|
| 93 |
+
vals, vecs = torch.lobpcg(
|
| 94 |
+
K_sym, k=m, X=init, niter=200, tol=tol, largest=True
|
| 95 |
+
)
|
| 96 |
+
elif eig_solver == "full":
|
| 97 |
+
vals, vecs = torch.linalg.eigh(K_sym)
|
| 98 |
+
vals, vecs = vals.flip(0), vecs.flip(1)
|
| 99 |
+
if k is not None:
|
| 100 |
+
vecs = vecs[:, : k + 1]
|
| 101 |
+
vals = vals[: k + 1]
|
| 102 |
+
else:
|
| 103 |
+
raise ValueError(f"Unknown eig_solver '{eig_solver}'")
|
| 104 |
+
|
| 105 |
+
psi = vecs[:, 1:]
|
| 106 |
+
lam = vals[1:]
|
| 107 |
+
cum = torch.cumsum(lam, dim=0)
|
| 108 |
+
L = int((cum / cum[-1] < cutoff).sum().item()) + 1
|
| 109 |
+
lam_pow = lam.pow(diffusion_time)
|
| 110 |
+
psi_all = psi * lam_pow
|
| 111 |
+
Psi = psi_all[:, :L]
|
| 112 |
+
Psi_rest = psi_all[:, L:]
|
| 113 |
+
|
| 114 |
+
if return_cval:
|
| 115 |
+
indices_with_out = [
|
| 116 |
+
ii for ii, name in enumerate(labels_by_mix) if "out" in name
|
| 117 |
+
]
|
| 118 |
+
valid_idx = torch.tensor(
|
| 119 |
+
[ii for ii in range(N) if ii not in indices_with_out], device=device
|
| 120 |
+
)
|
| 121 |
+
pi_min = d[valid_idx].min() / d[valid_idx].sum()
|
| 122 |
+
c_val = lam_pow[0] * pi_min.rsqrt() / math.log(2.0)
|
| 123 |
+
|
| 124 |
+
if stream is not None:
|
| 125 |
+
stream.synchronize()
|
| 126 |
+
|
| 127 |
+
if return_complement and return_eigs and return_cval:
|
| 128 |
+
return (
|
| 129 |
+
Psi.cpu().numpy(),
|
| 130 |
+
Psi_rest.cpu().numpy(),
|
| 131 |
+
lam.cpu().numpy(),
|
| 132 |
+
float(c_val),
|
| 133 |
+
)
|
| 134 |
+
if return_complement and return_eigs:
|
| 135 |
+
return Psi.cpu().numpy(), Psi_rest.cpu().numpy(), lam.cpu().numpy()
|
| 136 |
+
if return_complement:
|
| 137 |
+
return Psi.cpu().numpy(), Psi_rest.cpu().numpy()
|
| 138 |
+
if return_eigs:
|
| 139 |
+
return Psi.cpu().numpy(), lam.cpu().numpy()
|
| 140 |
+
return Psi.cpu().numpy()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def compute_ps(coords, labels, max_gpus=None):
|
| 144 |
+
ngpu = get_gpu_count(max_gpus)
|
| 145 |
+
|
| 146 |
+
if ngpu == 0:
|
| 147 |
+
coords_t = torch.tensor(coords)
|
| 148 |
+
spks_here = sorted({l.split("-")[0] for l in labels})
|
| 149 |
+
out = {}
|
| 150 |
+
for s in spks_here:
|
| 151 |
+
idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
|
| 152 |
+
out_i = labels.index(f"{s}-out")
|
| 153 |
+
ref_is = [i for i in idxs if i != out_i]
|
| 154 |
+
mu = coords_t[ref_is].mean(0)
|
| 155 |
+
cov = safe_cov_torch(coords_t[ref_is])
|
| 156 |
+
inv = torch.linalg.inv(cov)
|
| 157 |
+
A = mahalanobis_torch(coords_t[out_i], mu, inv)
|
| 158 |
+
B_list = []
|
| 159 |
+
for o in spks_here:
|
| 160 |
+
if o == s:
|
| 161 |
+
continue
|
| 162 |
+
o_idxs = [
|
| 163 |
+
i
|
| 164 |
+
for i, l in enumerate(labels)
|
| 165 |
+
if l.startswith(o) and not l.endswith("-out")
|
| 166 |
+
]
|
| 167 |
+
mu_o = coords_t[o_idxs].mean(0)
|
| 168 |
+
inv_o = torch.linalg.inv(safe_cov_torch(coords_t[o_idxs]))
|
| 169 |
+
B_list.append(mahalanobis_torch(coords_t[out_i], mu_o, inv_o))
|
| 170 |
+
B_min = torch.min(torch.stack(B_list)) if B_list else torch.tensor(0.0)
|
| 171 |
+
out[s] = (1 - A / (A + B_min + 1e-6)).item()
|
| 172 |
+
return out
|
| 173 |
+
|
| 174 |
+
# GPU version
|
| 175 |
+
device = min(ngpu - 1, 1) # Use second GPU if available
|
| 176 |
+
device_str = f"cuda:{device}"
|
| 177 |
+
coords_t = torch.tensor(coords, device=device_str)
|
| 178 |
+
spks_here = sorted({l.split("-")[0] for l in labels})
|
| 179 |
+
out = {}
|
| 180 |
+
|
| 181 |
+
stream = torch.cuda.Stream(device=device_str)
|
| 182 |
+
with torch.cuda.device(device):
|
| 183 |
+
with torch.cuda.stream(stream):
|
| 184 |
+
for s in spks_here:
|
| 185 |
+
idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
|
| 186 |
+
out_i = labels.index(f"{s}-out")
|
| 187 |
+
ref_is = [i for i in idxs if i != out_i]
|
| 188 |
+
mu = coords_t[ref_is].mean(0)
|
| 189 |
+
cov = safe_cov_torch(coords_t[ref_is])
|
| 190 |
+
inv = torch.linalg.inv(cov)
|
| 191 |
+
A = mahalanobis_torch(coords_t[out_i], mu, inv)
|
| 192 |
+
B_list = []
|
| 193 |
+
for o in spks_here:
|
| 194 |
+
if o == s:
|
| 195 |
+
continue
|
| 196 |
+
o_idxs = [
|
| 197 |
+
i
|
| 198 |
+
for i, l in enumerate(labels)
|
| 199 |
+
if l.startswith(o) and not l.endswith("-out")
|
| 200 |
+
]
|
| 201 |
+
mu_o = coords_t[o_idxs].mean(0)
|
| 202 |
+
inv_o = torch.linalg.inv(safe_cov_torch(coords_t[o_idxs]))
|
| 203 |
+
B_list.append(mahalanobis_torch(coords_t[out_i], mu_o, inv_o))
|
| 204 |
+
B_min = (
|
| 205 |
+
torch.min(torch.stack(B_list))
|
| 206 |
+
if B_list
|
| 207 |
+
else torch.tensor(0.0, device=device_str)
|
| 208 |
+
)
|
| 209 |
+
out[s] = (1 - A / (A + B_min + 1e-6)).item()
|
| 210 |
+
stream.synchronize()
|
| 211 |
+
return out
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def compute_pm(coords, labels, pm_method, max_gpus=None):
|
| 215 |
+
ngpu = get_gpu_count(max_gpus)
|
| 216 |
+
|
| 217 |
+
if ngpu == 0:
|
| 218 |
+
coords_t = torch.tensor(coords)
|
| 219 |
+
spks_here = sorted({l.split("-")[0] for l in labels})
|
| 220 |
+
out = {}
|
| 221 |
+
for s in spks_here:
|
| 222 |
+
idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
|
| 223 |
+
ref_i = labels.index(f"{s}-ref")
|
| 224 |
+
out_i = labels.index(f"{s}-out")
|
| 225 |
+
d_idx = [i for i in idxs if i not in {ref_i, out_i}]
|
| 226 |
+
if len(d_idx) < 2:
|
| 227 |
+
out[s] = 0.0
|
| 228 |
+
continue
|
| 229 |
+
ref_v = coords_t[ref_i]
|
| 230 |
+
dist = coords_t[d_idx] - ref_v
|
| 231 |
+
N, D = dist.shape
|
| 232 |
+
cov = dist.T @ dist / (N - 1)
|
| 233 |
+
if torch.linalg.matrix_rank(cov) < D:
|
| 234 |
+
cov += torch.eye(D) * COV_TOL
|
| 235 |
+
inv = torch.linalg.inv(cov)
|
| 236 |
+
sq_dists = torch.stack(
|
| 237 |
+
[mahalanobis_torch(coords_t[i], ref_v, inv) ** 2 for i in d_idx]
|
| 238 |
+
)
|
| 239 |
+
d_out_sq = float(mahalanobis_torch(coords_t[out_i], ref_v, inv) ** 2)
|
| 240 |
+
pm_score = (
|
| 241 |
+
pm_tail_rank(d_out_sq, sq_dists)
|
| 242 |
+
if pm_method == "rank"
|
| 243 |
+
else pm_tail_gamma(d_out_sq, sq_dists)
|
| 244 |
+
)
|
| 245 |
+
out[s] = float(np.clip(pm_score, 0.0, 1.0))
|
| 246 |
+
return out
|
| 247 |
+
|
| 248 |
+
# GPU version
|
| 249 |
+
device = min(ngpu - 1, 1)
|
| 250 |
+
device_str = f"cuda:{device}"
|
| 251 |
+
coords_t = torch.tensor(coords, device=device_str)
|
| 252 |
+
spks_here = sorted({l.split("-")[0] for l in labels})
|
| 253 |
+
out = {}
|
| 254 |
+
|
| 255 |
+
stream = torch.cuda.Stream(device=device_str)
|
| 256 |
+
with torch.cuda.device(device):
|
| 257 |
+
with torch.cuda.stream(stream):
|
| 258 |
+
for s in spks_here:
|
| 259 |
+
idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
|
| 260 |
+
ref_i = labels.index(f"{s}-ref")
|
| 261 |
+
out_i = labels.index(f"{s}-out")
|
| 262 |
+
d_idx = [i for i in idxs if i not in {ref_i, out_i}]
|
| 263 |
+
if len(d_idx) < 2:
|
| 264 |
+
out[s] = 0.0
|
| 265 |
+
continue
|
| 266 |
+
ref_v = coords_t[ref_i]
|
| 267 |
+
dist = coords_t[d_idx] - ref_v
|
| 268 |
+
N, D = dist.shape
|
| 269 |
+
cov = dist.T @ dist / (N - 1)
|
| 270 |
+
if torch.linalg.matrix_rank(cov) < D:
|
| 271 |
+
cov += torch.eye(D, device=device_str) * COV_TOL
|
| 272 |
+
inv = torch.linalg.inv(cov)
|
| 273 |
+
sq_dists = torch.stack(
|
| 274 |
+
[mahalanobis_torch(coords_t[i], ref_v, inv) ** 2 for i in d_idx]
|
| 275 |
+
)
|
| 276 |
+
d_out_sq = float(mahalanobis_torch(coords_t[out_i], ref_v, inv) ** 2)
|
| 277 |
+
pm_score = (
|
| 278 |
+
pm_tail_rank(d_out_sq, sq_dists)
|
| 279 |
+
if pm_method == "rank"
|
| 280 |
+
else pm_tail_gamma(d_out_sq, sq_dists)
|
| 281 |
+
)
|
| 282 |
+
out[s] = float(np.clip(pm_score, 0.0, 1.0))
|
| 283 |
+
stream.synchronize()
|
| 284 |
+
return out
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def pm_ci_components_full(
|
| 288 |
+
coords_d, coords_rest, eigvals, labels, *, delta=0.05, K=1.0, C1=1.0, C2=1.0
|
| 289 |
+
):
|
| 290 |
+
"""PM CI components exactly as original - complete implementation."""
|
| 291 |
+
_EPS = 1e-12
|
| 292 |
+
|
| 293 |
+
def _safe_x(a, theta):
|
| 294 |
+
return a / max(theta, _EPS)
|
| 295 |
+
|
| 296 |
+
D = coords_d.shape[1]
|
| 297 |
+
m = coords_rest.shape[1]
|
| 298 |
+
if m == 0:
|
| 299 |
+
z = {s: 0.0 for s in {l.split("-")[0] for l in labels}}
|
| 300 |
+
return z.copy(), z.copy()
|
| 301 |
+
|
| 302 |
+
X_d = torch.tensor(
|
| 303 |
+
coords_d, device="cuda:0" if torch.cuda.is_available() else "cpu"
|
| 304 |
+
)
|
| 305 |
+
X_c = torch.tensor(
|
| 306 |
+
coords_rest, device="cuda:0" if torch.cuda.is_available() else "cpu"
|
| 307 |
+
)
|
| 308 |
+
spk_ids = sorted({l.split("-")[0] for l in labels})
|
| 309 |
+
bias_ci = {}
|
| 310 |
+
prob_ci = {}
|
| 311 |
+
|
| 312 |
+
for s in spk_ids:
|
| 313 |
+
idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
|
| 314 |
+
ref_i = labels.index(f"{s}-ref")
|
| 315 |
+
out_i = labels.index(f"{s}-out")
|
| 316 |
+
dist_is = [i for i in idxs if i not in {ref_i, out_i}]
|
| 317 |
+
n_p = len(dist_is)
|
| 318 |
+
|
| 319 |
+
if n_p < 2:
|
| 320 |
+
bias_ci[s] = 0.0
|
| 321 |
+
prob_ci[s] = 0.0
|
| 322 |
+
continue
|
| 323 |
+
|
| 324 |
+
ref_d = X_d[ref_i]
|
| 325 |
+
ref_c = X_c[ref_i]
|
| 326 |
+
D_mat = X_d[dist_is] - ref_d
|
| 327 |
+
C_mat = X_c[dist_is] - ref_c
|
| 328 |
+
Sigma_d = safe_cov_torch(D_mat)
|
| 329 |
+
Sigma_c = safe_cov_torch(C_mat)
|
| 330 |
+
C_dc = D_mat.T @ C_mat / (n_p - 1)
|
| 331 |
+
inv_Sigma_d = torch.linalg.inv(Sigma_d)
|
| 332 |
+
|
| 333 |
+
S_i = (
|
| 334 |
+
Sigma_c
|
| 335 |
+
- C_dc.T @ inv_Sigma_d @ C_dc
|
| 336 |
+
+ torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
|
| 337 |
+
)
|
| 338 |
+
S_inv = torch.linalg.inv(S_i)
|
| 339 |
+
|
| 340 |
+
diff_out_d = X_d[out_i] - ref_d
|
| 341 |
+
diff_out_c = X_c[out_i] - ref_c
|
| 342 |
+
r_out = diff_out_c - C_dc.T @ inv_Sigma_d @ diff_out_d
|
| 343 |
+
delta_Gi_a = float(r_out @ S_inv @ r_out)
|
| 344 |
+
|
| 345 |
+
r_list = []
|
| 346 |
+
for p in dist_is:
|
| 347 |
+
d_p = X_d[p] - ref_d
|
| 348 |
+
c_p = X_c[p] - ref_c
|
| 349 |
+
r_p = c_p - C_dc.T @ inv_Sigma_d @ d_p
|
| 350 |
+
r_list.append(r_p)
|
| 351 |
+
R_p = torch.stack(r_list, dim=0)
|
| 352 |
+
delta_Gi_p = torch.sum(R_p @ S_inv * R_p, dim=1)
|
| 353 |
+
delta_Gi_mu_max = float(delta_Gi_p.max())
|
| 354 |
+
|
| 355 |
+
mah_sq = torch.stack(
|
| 356 |
+
[(X_d[i] - ref_d) @ inv_Sigma_d @ (X_d[i] - ref_d) for i in dist_is]
|
| 357 |
+
)
|
| 358 |
+
mu_g = float(mah_sq.mean())
|
| 359 |
+
sigma2_g = float(mah_sq.var(unbiased=True) + 1e-12)
|
| 360 |
+
sigma_g = math.sqrt(sigma2_g)
|
| 361 |
+
|
| 362 |
+
full_sq = mah_sq + delta_Gi_p
|
| 363 |
+
mu_full = float(full_sq.mean())
|
| 364 |
+
sigma2_full = float(full_sq.var(unbiased=True) + 1e-12)
|
| 365 |
+
|
| 366 |
+
if sigma2_g == 0.0:
|
| 367 |
+
delta_Gi_k = delta_Gi_theta = 0.0
|
| 368 |
+
else:
|
| 369 |
+
factor = delta_Gi_mu_max * n_p / (n_p - 1)
|
| 370 |
+
delta_Gi_k = 1.0 * factor * (mu_full + mu_g) / sigma2_g
|
| 371 |
+
delta_Gi_theta = 1.0 * factor * (sigma2_full + sigma2_g) / (mu_g**2 + 1e-9)
|
| 372 |
+
|
| 373 |
+
k_d = (mu_g**2) / max(sigma2_g, 1e-12)
|
| 374 |
+
theta_d = sigma2_g / max(mu_g, 1e-12)
|
| 375 |
+
a_d = float(diff_out_d @ inv_Sigma_d @ diff_out_d)
|
| 376 |
+
|
| 377 |
+
pm_center = gammaincc(k_d, _safe_x(a_d, theta_d))
|
| 378 |
+
|
| 379 |
+
corner_vals = []
|
| 380 |
+
for s_k in (-1, 1):
|
| 381 |
+
for s_theta in (-1, 1):
|
| 382 |
+
for s_a in (-1, 1):
|
| 383 |
+
k_c = max(k_d + s_k * delta_Gi_k, 1e-6)
|
| 384 |
+
theta_c = max(theta_d + s_theta * delta_Gi_theta, 1e-6)
|
| 385 |
+
a_c = max(a_d + s_a * delta_Gi_a, 1e-8)
|
| 386 |
+
corner_vals.append(gammaincc(k_c, _safe_x(a_c, theta_c)))
|
| 387 |
+
|
| 388 |
+
bias_ci[s] = max(abs(v - pm_center) for v in corner_vals)
|
| 389 |
+
|
| 390 |
+
# Probabilistic half-width
|
| 391 |
+
R_sq = float(mah_sq.max()) + 1e-12
|
| 392 |
+
log_term = math.log(6.0 / delta)
|
| 393 |
+
eps_mu = math.sqrt(2 * sigma2_g * log_term / n_p) + 3 * R_sq * log_term / n_p
|
| 394 |
+
eps_sigma = (
|
| 395 |
+
math.sqrt(2 * R_sq**2 * log_term / n_p) + 3 * R_sq**2 * log_term / n_p
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
g1_x = 2.0 * mu_g / (sigma2_g + 1e-9)
|
| 399 |
+
g1_y = -2.0 * mu_g**2 / (sigma_g**3 + 1e-9)
|
| 400 |
+
g2_x = -sigma2_g / (mu_g**2 + 1e-9)
|
| 401 |
+
g2_y = 2.0 * sigma_g / (mu_g + 1e-9)
|
| 402 |
+
|
| 403 |
+
delta_k = min(abs(g1_x) * eps_mu + abs(g1_y) * eps_sigma, 0.5 * k_d)
|
| 404 |
+
delta_theta = min(abs(g2_x) * eps_mu + abs(g2_y) * eps_sigma, 0.5 * theta_d)
|
| 405 |
+
delta_a = min(R_sq * math.sqrt(2 * log_term / n_p), 0.5 * a_d + 1e-12)
|
| 406 |
+
|
| 407 |
+
pm_corners = []
|
| 408 |
+
for s_k in (-1, 1):
|
| 409 |
+
for s_theta in (-1, 1):
|
| 410 |
+
for s_a in (-1, 1):
|
| 411 |
+
k_c = k_d + s_k * delta_k
|
| 412 |
+
theta_c = theta_d + s_theta * delta_theta
|
| 413 |
+
a_c = max(a_d + s_a * delta_a, 1e-8)
|
| 414 |
+
pm_corners.append(gammaincc(k_c, _safe_x(a_c, theta_c)))
|
| 415 |
+
|
| 416 |
+
prob_ci[s] = max(abs(pm - pm_center) for pm in pm_corners)
|
| 417 |
+
|
| 418 |
+
return bias_ci, prob_ci
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def ps_ci_components_full(coords_d, coords_rest, eigvals, labels, *, delta=0.05):
|
| 422 |
+
"""PS CI components exactly as original - complete implementation."""
|
| 423 |
+
|
| 424 |
+
def _mean_dev(lam_max, delta, n_eff):
|
| 425 |
+
return math.sqrt(2 * lam_max * math.log(2 / delta) / n_eff)
|
| 426 |
+
|
| 427 |
+
def _rel_cov_dev(lam_max, trace, delta, n_eff, C=1.0):
|
| 428 |
+
r = trace / lam_max
|
| 429 |
+
abs_dev = (
|
| 430 |
+
C * lam_max * (math.sqrt(r / n_eff) + (r + math.log(2 / delta)) / n_eff)
|
| 431 |
+
)
|
| 432 |
+
return abs_dev / lam_max
|
| 433 |
+
|
| 434 |
+
def _maha_eps_m(a_hat, lam_min, lam_max, mean_dev, rel_cov_dev):
|
| 435 |
+
term1 = 2 * math.sqrt(a_hat) * mean_dev * math.sqrt(lam_max / lam_min)
|
| 436 |
+
term2 = a_hat * rel_cov_dev
|
| 437 |
+
return term1 + term2
|
| 438 |
+
|
| 439 |
+
D = coords_d.shape[1]
|
| 440 |
+
m = coords_rest.shape[1]
|
| 441 |
+
if m == 0:
|
| 442 |
+
z = {s: 0.0 for s in set(l.split("-")[0] for l in labels)}
|
| 443 |
+
return z.copy(), z.copy()
|
| 444 |
+
|
| 445 |
+
X_d = torch.tensor(
|
| 446 |
+
coords_d, device="cuda:0" if torch.cuda.is_available() else "cpu"
|
| 447 |
+
)
|
| 448 |
+
X_c = torch.tensor(
|
| 449 |
+
coords_rest, device="cuda:0" if torch.cuda.is_available() else "cpu"
|
| 450 |
+
)
|
| 451 |
+
spk_ids = sorted({l.split("-")[0] for l in labels})
|
| 452 |
+
bias = {}
|
| 453 |
+
prob = {}
|
| 454 |
+
|
| 455 |
+
for s in spk_ids:
|
| 456 |
+
idxs = [i for i, l in enumerate(labels) if l.startswith(s)]
|
| 457 |
+
out_i = labels.index(f"{s}-out")
|
| 458 |
+
ref_is = [i for i in idxs if i != out_i]
|
| 459 |
+
|
| 460 |
+
mu_d = X_d[ref_is].mean(0)
|
| 461 |
+
mu_c = X_c[ref_is].mean(0)
|
| 462 |
+
Sigma_d = safe_cov_torch(X_d[ref_is])
|
| 463 |
+
Sigma_c = safe_cov_torch(X_c[ref_is])
|
| 464 |
+
C_dc = (X_d[ref_is] - mu_d).T @ (X_c[ref_is] - mu_c) / (len(ref_is) - 1)
|
| 465 |
+
inv_Sd = torch.linalg.inv(Sigma_d)
|
| 466 |
+
|
| 467 |
+
lam_min = torch.linalg.eigvalsh(Sigma_d).min().clamp_min(1e-9).item()
|
| 468 |
+
lam_max = torch.linalg.eigvalsh(Sigma_d).max()
|
| 469 |
+
trace = torch.trace(Sigma_d).item()
|
| 470 |
+
|
| 471 |
+
diff_d = X_d[out_i] - mu_d
|
| 472 |
+
diff_c = X_c[out_i] - mu_c
|
| 473 |
+
A_d = float(mahalanobis_torch(X_d[out_i], mu_d, inv_Sd))
|
| 474 |
+
|
| 475 |
+
r_i = diff_c - C_dc.T @ inv_Sd @ diff_d
|
| 476 |
+
S_i = (
|
| 477 |
+
Sigma_c
|
| 478 |
+
- C_dc.T @ inv_Sd @ C_dc
|
| 479 |
+
+ torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
|
| 480 |
+
)
|
| 481 |
+
term_i = math.sqrt(float(r_i @ torch.linalg.solve(S_i, r_i)))
|
| 482 |
+
|
| 483 |
+
B_d, term_j = float("inf"), 0.0
|
| 484 |
+
Sig_o = None
|
| 485 |
+
for o in spk_ids:
|
| 486 |
+
if o == s:
|
| 487 |
+
continue
|
| 488 |
+
o_idxs = [
|
| 489 |
+
i
|
| 490 |
+
for i, l in enumerate(labels)
|
| 491 |
+
if l.startswith(o) and not l.endswith("-out")
|
| 492 |
+
]
|
| 493 |
+
muo_d = X_d[o_idxs].mean(0)
|
| 494 |
+
muo_c = X_c[o_idxs].mean(0)
|
| 495 |
+
Sig_o_tmp = safe_cov_torch(X_d[o_idxs])
|
| 496 |
+
inv_So = torch.linalg.inv(Sig_o_tmp)
|
| 497 |
+
this_B = float(mahalanobis_torch(X_d[out_i], muo_d, inv_So))
|
| 498 |
+
|
| 499 |
+
if this_B < B_d:
|
| 500 |
+
B_d = this_B
|
| 501 |
+
Sig_o = Sig_o_tmp
|
| 502 |
+
diff_do = X_d[out_i] - muo_d
|
| 503 |
+
diff_co = X_c[out_i] - muo_c
|
| 504 |
+
C_oc = (
|
| 505 |
+
(X_d[o_idxs] - muo_d).T @ (X_c[o_idxs] - muo_c) / (len(o_idxs) - 1)
|
| 506 |
+
)
|
| 507 |
+
r_j = diff_co - C_oc.T @ inv_So @ diff_do
|
| 508 |
+
S_j = (
|
| 509 |
+
safe_cov_torch(X_c[o_idxs])
|
| 510 |
+
- C_oc.T @ inv_So @ C_oc
|
| 511 |
+
+ torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
|
| 512 |
+
)
|
| 513 |
+
term_j = math.sqrt(float(r_j @ torch.linalg.solve(S_j, r_j)))
|
| 514 |
+
|
| 515 |
+
denom = A_d + B_d
|
| 516 |
+
bias[s] = (B_d * term_i + A_d * term_j) / (denom**2)
|
| 517 |
+
|
| 518 |
+
if Sig_o is not None:
|
| 519 |
+
lam_min_o = torch.linalg.eigvalsh(Sig_o).min().clamp_min(1e-9).item()
|
| 520 |
+
lam_max_o = torch.linalg.eigvalsh(Sig_o).max().item()
|
| 521 |
+
trace_o = torch.trace(Sig_o).item()
|
| 522 |
+
|
| 523 |
+
n_eff = max(int(0.7 * len(ref_is)), 3)
|
| 524 |
+
RIDGE = 0.05
|
| 525 |
+
lam_min_eff = max(lam_min, RIDGE * lam_max.item())
|
| 526 |
+
lam_min_o_eff = max(lam_min_o, RIDGE * lam_max_o)
|
| 527 |
+
|
| 528 |
+
eps_i_sg = _maha_eps_m(
|
| 529 |
+
A_d,
|
| 530 |
+
lam_min_eff,
|
| 531 |
+
lam_max.item(),
|
| 532 |
+
_mean_dev(lam_max.item(), delta / 2, n_eff),
|
| 533 |
+
_rel_cov_dev(lam_max.item(), trace, delta / 2, n_eff),
|
| 534 |
+
)
|
| 535 |
+
eps_j_sg = _maha_eps_m(
|
| 536 |
+
B_d,
|
| 537 |
+
lam_min_o_eff,
|
| 538 |
+
lam_max_o,
|
| 539 |
+
_mean_dev(lam_max_o, delta / 2, n_eff),
|
| 540 |
+
_rel_cov_dev(lam_max_o, trace_o, delta / 2, n_eff),
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
grad_l2 = math.hypot(A_d, B_d) / (A_d + B_d) ** 2
|
| 544 |
+
ps_radius = grad_l2 * math.hypot(eps_i_sg, eps_j_sg)
|
| 545 |
+
prob[s] = min(1.0, ps_radius)
|
| 546 |
+
else:
|
| 547 |
+
prob[s] = 0.0
|
| 548 |
+
|
| 549 |
+
return bias, prob
|
models.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import queue
|
| 2 |
+
import threading
|
| 3 |
+
import gc
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from transformers import (
|
| 8 |
+
HubertModel,
|
| 9 |
+
Wav2Vec2FeatureExtractor,
|
| 10 |
+
Wav2Vec2Model,
|
| 11 |
+
WavLMModel,
|
| 12 |
+
ASTModel,
|
| 13 |
+
AutoFeatureExtractor,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
from config import BATCH_SIZE, ENERGY_HOP_MS, ENERGY_WIN_MS, SR
|
| 17 |
+
from utils import get_gpu_count
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class BalancedDualGPUModel:
|
| 21 |
+
|
| 22 |
+
def __init__(self, model_name, layer, max_gpus=None):
|
| 23 |
+
self.layer = layer
|
| 24 |
+
self.models = []
|
| 25 |
+
self.extractors = []
|
| 26 |
+
self.devices = []
|
| 27 |
+
ngpu = get_gpu_count(max_gpus)
|
| 28 |
+
|
| 29 |
+
for gpu_id in range(min(ngpu, 2)):
|
| 30 |
+
device = f"cuda:{gpu_id}"
|
| 31 |
+
self.devices.append(device)
|
| 32 |
+
ckpt, cls, _ = get_model_config(layer)[model_name]
|
| 33 |
+
if cls is ASTModel:
|
| 34 |
+
extractor = AutoFeatureExtractor.from_pretrained(ckpt)
|
| 35 |
+
else:
|
| 36 |
+
extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
|
| 37 |
+
|
| 38 |
+
attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
|
| 39 |
+
model = cls.from_pretrained(
|
| 40 |
+
ckpt,
|
| 41 |
+
output_hidden_states=True,
|
| 42 |
+
use_safetensors=True,
|
| 43 |
+
torch_dtype=torch.float16,
|
| 44 |
+
low_cpu_mem_usage=True,
|
| 45 |
+
attn_implementation=attn_impl
|
| 46 |
+
)
|
| 47 |
+
model.eval()
|
| 48 |
+
model = model.to(device)
|
| 49 |
+
|
| 50 |
+
for param in model.parameters():
|
| 51 |
+
param.requires_grad = False
|
| 52 |
+
|
| 53 |
+
self.extractors.append(extractor)
|
| 54 |
+
self.models.append(model)
|
| 55 |
+
|
| 56 |
+
self.gpu_queues = [queue.Queue() for _ in range(len(self.devices))]
|
| 57 |
+
self.result_queue = queue.Queue()
|
| 58 |
+
self.workers = []
|
| 59 |
+
for i in range(len(self.devices)):
|
| 60 |
+
worker = threading.Thread(target=self._gpu_worker, args=(i,))
|
| 61 |
+
worker.daemon = True
|
| 62 |
+
worker.start()
|
| 63 |
+
self.workers.append(worker)
|
| 64 |
+
|
| 65 |
+
def _gpu_worker(self, gpu_id):
|
| 66 |
+
device = self.devices[gpu_id]
|
| 67 |
+
model = self.models[gpu_id]
|
| 68 |
+
extractor = self.extractors[gpu_id]
|
| 69 |
+
while True:
|
| 70 |
+
task = self.gpu_queues[gpu_id].get()
|
| 71 |
+
if task is None:
|
| 72 |
+
break
|
| 73 |
+
signals, masks, use_mlm, task_id = task
|
| 74 |
+
try:
|
| 75 |
+
inputs = extractor(
|
| 76 |
+
signals, sampling_rate=SR, return_tensors="pt", padding=True
|
| 77 |
+
)
|
| 78 |
+
input_values = inputs.input_values.to(device, non_blocking=True)
|
| 79 |
+
|
| 80 |
+
torch.cuda.empty_cache()
|
| 81 |
+
|
| 82 |
+
orig_mode = model.training
|
| 83 |
+
model.train() if use_mlm else model.eval()
|
| 84 |
+
with torch.no_grad():
|
| 85 |
+
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
|
| 86 |
+
hs = model(
|
| 87 |
+
input_values, output_hidden_states=True
|
| 88 |
+
).hidden_states[self.layer]
|
| 89 |
+
model.train(orig_mode)
|
| 90 |
+
|
| 91 |
+
B, T, D = hs.shape
|
| 92 |
+
keep = []
|
| 93 |
+
for b in range(B):
|
| 94 |
+
mask_b = masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
|
| 95 |
+
mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
|
| 96 |
+
keep.append(hs[b][mask_t].cpu())
|
| 97 |
+
|
| 98 |
+
# Aggressive cleanup
|
| 99 |
+
del hs, input_values, inputs
|
| 100 |
+
torch.cuda.empty_cache()
|
| 101 |
+
|
| 102 |
+
if keep:
|
| 103 |
+
L_max = max(x.shape[0] for x in keep)
|
| 104 |
+
keep_padded = [
|
| 105 |
+
F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in keep
|
| 106 |
+
]
|
| 107 |
+
result = torch.stack(keep_padded, dim=0)
|
| 108 |
+
else:
|
| 109 |
+
result = torch.empty(0, 0, 0)
|
| 110 |
+
self.result_queue.put((task_id, result))
|
| 111 |
+
except Exception as e:
|
| 112 |
+
self.result_queue.put((task_id, e))
|
| 113 |
+
finally:
|
| 114 |
+
# Always clear cache after processing
|
| 115 |
+
torch.cuda.empty_cache()
|
| 116 |
+
|
| 117 |
+
def process_batch(self, signals, masks, use_mlm=False):
|
| 118 |
+
if not signals:
|
| 119 |
+
return torch.empty(0, 0, 0)
|
| 120 |
+
batch_size = len(signals)
|
| 121 |
+
split = (batch_size + len(self.devices) - 1) // len(self.devices)
|
| 122 |
+
results = {}
|
| 123 |
+
task_id = 0
|
| 124 |
+
for i in range(0, batch_size, split):
|
| 125 |
+
end = min(i + split, batch_size)
|
| 126 |
+
gpu_id = (i // split) % len(self.devices)
|
| 127 |
+
self.gpu_queues[gpu_id].put(
|
| 128 |
+
(signals[i:end], masks[i:end], use_mlm, task_id)
|
| 129 |
+
)
|
| 130 |
+
task_id += 1
|
| 131 |
+
for _ in range(task_id):
|
| 132 |
+
tid, result = self.result_queue.get()
|
| 133 |
+
if isinstance(result, Exception):
|
| 134 |
+
raise result
|
| 135 |
+
results[tid] = result
|
| 136 |
+
parts = [results[i] for i in range(task_id) if results[i].numel() > 0]
|
| 137 |
+
return torch.cat(parts, dim=0) if parts else torch.empty(0, 0, 0)
|
| 138 |
+
|
| 139 |
+
def cleanup(self):
|
| 140 |
+
"""Explicit cleanup method"""
|
| 141 |
+
for q in self.gpu_queues:
|
| 142 |
+
q.put(None)
|
| 143 |
+
for w in self.workers:
|
| 144 |
+
w.join(timeout=5.0)
|
| 145 |
+
for model in self.models:
|
| 146 |
+
del model
|
| 147 |
+
for extractor in self.extractors:
|
| 148 |
+
del extractor
|
| 149 |
+
self.models.clear()
|
| 150 |
+
self.extractors.clear()
|
| 151 |
+
torch.cuda.empty_cache()
|
| 152 |
+
gc.collect()
|
| 153 |
+
|
| 154 |
+
def __del__(self):
|
| 155 |
+
self.cleanup()
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# NO CACHE - we need to clean up models properly between runs
|
| 159 |
+
def get_model_config(layer):
|
| 160 |
+
return {
|
| 161 |
+
"raw": (None, None, None),
|
| 162 |
+
"wavlm": ("microsoft/wavlm-large", WavLMModel, layer),
|
| 163 |
+
"wav2vec2": ("facebook/wav2vec2-large-lv60", Wav2Vec2Model, layer),
|
| 164 |
+
"hubert": ("facebook/hubert-large-ll60k", HubertModel, layer),
|
| 165 |
+
"wavlm_base": ("microsoft/wavlm-base", WavLMModel, layer),
|
| 166 |
+
"wav2vec2_base": ("facebook/wav2vec2-base", Wav2Vec2Model, layer),
|
| 167 |
+
"hubert_base": ("facebook/hubert-base-ls960", HubertModel, layer),
|
| 168 |
+
"wav2vec2_xlsr": ("facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model, layer),
|
| 169 |
+
"ast": ("MIT/ast-finetuned-audioset-10-10-0.4593", ASTModel, layer),
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# Store loaded models globally to properly manage them
|
| 174 |
+
_loaded_models = {}
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def load_model(name, layer, max_gpus=None):
|
| 178 |
+
global _loaded_models
|
| 179 |
+
|
| 180 |
+
# Clean up any previously loaded models first
|
| 181 |
+
if _loaded_models:
|
| 182 |
+
for key, model_data in _loaded_models.items():
|
| 183 |
+
if isinstance(model_data, tuple) and len(model_data) == 2:
|
| 184 |
+
if isinstance(model_data[0], BalancedDualGPUModel):
|
| 185 |
+
model_data[0].cleanup()
|
| 186 |
+
elif isinstance(model_data[0], tuple):
|
| 187 |
+
# Single GPU model
|
| 188 |
+
_, model = model_data[0]
|
| 189 |
+
del model
|
| 190 |
+
_loaded_models.clear()
|
| 191 |
+
torch.cuda.empty_cache()
|
| 192 |
+
gc.collect()
|
| 193 |
+
|
| 194 |
+
if name.lower() in {"raw", "waveform"}:
|
| 195 |
+
return "raw", layer
|
| 196 |
+
|
| 197 |
+
ngpu = get_gpu_count(max_gpus)
|
| 198 |
+
if ngpu > 1:
|
| 199 |
+
model = BalancedDualGPUModel(name, layer, max_gpus)
|
| 200 |
+
_loaded_models[name] = (model, layer)
|
| 201 |
+
return model, layer
|
| 202 |
+
else:
|
| 203 |
+
ckpt, cls, layer_eff = get_model_config(layer)[name]
|
| 204 |
+
if cls is ASTModel:
|
| 205 |
+
extractor = AutoFeatureExtractor.from_pretrained(ckpt)
|
| 206 |
+
else:
|
| 207 |
+
extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
|
| 208 |
+
|
| 209 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 210 |
+
attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
|
| 211 |
+
model = cls.from_pretrained(
|
| 212 |
+
ckpt,
|
| 213 |
+
output_hidden_states=True,
|
| 214 |
+
use_safetensors=True,
|
| 215 |
+
torch_dtype=torch.float16,
|
| 216 |
+
low_cpu_mem_usage=True,
|
| 217 |
+
attn_implementation=attn_impl
|
| 218 |
+
)
|
| 219 |
+
model.eval()
|
| 220 |
+
model = model.to(device)
|
| 221 |
+
|
| 222 |
+
for param in model.parameters():
|
| 223 |
+
param.requires_grad = False
|
| 224 |
+
|
| 225 |
+
model_tuple = ((extractor, model), layer_eff)
|
| 226 |
+
_loaded_models[name] = model_tuple
|
| 227 |
+
return (extractor, model), layer_eff
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def cleanup_all_models():
|
| 231 |
+
"""Call this at the end of each experiment to ensure complete cleanup"""
|
| 232 |
+
global _loaded_models
|
| 233 |
+
if _loaded_models:
|
| 234 |
+
for key, model_data in _loaded_models.items():
|
| 235 |
+
if isinstance(model_data, tuple) and len(model_data) == 2:
|
| 236 |
+
if isinstance(model_data[0], BalancedDualGPUModel):
|
| 237 |
+
model_data[0].cleanup()
|
| 238 |
+
elif isinstance(model_data[0], tuple):
|
| 239 |
+
# Single GPU model
|
| 240 |
+
_, model = model_data[0]
|
| 241 |
+
del model
|
| 242 |
+
_loaded_models.clear()
|
| 243 |
+
torch.cuda.empty_cache()
|
| 244 |
+
gc.collect()
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def embed_batch_raw(signals, masks_audio):
|
| 248 |
+
win = int(ENERGY_WIN_MS * SR / 1000)
|
| 249 |
+
hop = int(ENERGY_HOP_MS * SR / 1000)
|
| 250 |
+
reps, L_max = [], 0
|
| 251 |
+
for sig_np, mask_np in zip(signals, masks_audio):
|
| 252 |
+
x = torch.as_tensor(sig_np[:-1], dtype=torch.float32)
|
| 253 |
+
frames = x.unfold(0, win, hop)
|
| 254 |
+
mask = torch.as_tensor(mask_np[: len(frames)], dtype=torch.bool)
|
| 255 |
+
keep = frames[mask] if mask.any() else frames[:1]
|
| 256 |
+
reps.append(keep)
|
| 257 |
+
L_max = max(L_max, keep.size(0))
|
| 258 |
+
reps = [F.pad(r, (0, 0, 0, L_max - r.size(0))) for r in reps]
|
| 259 |
+
return torch.stack(reps, dim=0)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def embed_batch_single_gpu(
|
| 263 |
+
signals, masks_audio, extractor, model, layer, use_mlm=False
|
| 264 |
+
):
|
| 265 |
+
if not signals:
|
| 266 |
+
return torch.empty(0, 0, 0)
|
| 267 |
+
device = next(model.parameters()).device
|
| 268 |
+
|
| 269 |
+
max_batch = 2
|
| 270 |
+
all_keeps = []
|
| 271 |
+
|
| 272 |
+
for i in range(0, len(signals), max_batch):
|
| 273 |
+
batch_signals = signals[i:i + max_batch]
|
| 274 |
+
batch_masks = masks_audio[i:i + max_batch]
|
| 275 |
+
|
| 276 |
+
inputs = extractor(batch_signals, sampling_rate=SR, return_tensors="pt", padding=True)
|
| 277 |
+
input_values = inputs.input_values.to(device, non_blocking=True)
|
| 278 |
+
|
| 279 |
+
orig_mode = model.training
|
| 280 |
+
model.train() if use_mlm else model.eval()
|
| 281 |
+
|
| 282 |
+
with torch.no_grad():
|
| 283 |
+
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
|
| 284 |
+
hs = model(input_values, output_hidden_states=True).hidden_states[layer]
|
| 285 |
+
model.train(orig_mode)
|
| 286 |
+
|
| 287 |
+
B, T, D = hs.shape
|
| 288 |
+
for b in range(B):
|
| 289 |
+
mask_b = batch_masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
|
| 290 |
+
mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
|
| 291 |
+
all_keeps.append(hs[b][mask_t].cpu())
|
| 292 |
+
|
| 293 |
+
# Aggressive cleanup
|
| 294 |
+
del hs, input_values, inputs
|
| 295 |
+
torch.cuda.empty_cache()
|
| 296 |
+
|
| 297 |
+
if all_keeps:
|
| 298 |
+
L_max = max(x.shape[0] for x in all_keeps)
|
| 299 |
+
keep_padded = [F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in all_keeps]
|
| 300 |
+
result = torch.stack(keep_padded, dim=0)
|
| 301 |
+
# Clean up intermediate lists
|
| 302 |
+
del all_keeps, keep_padded
|
| 303 |
+
return result
|
| 304 |
+
else:
|
| 305 |
+
return torch.empty(0, 0, 0)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
|
| 309 |
+
if model_wrapper == "raw":
|
| 310 |
+
return embed_batch_raw(signals, masks_audio)
|
| 311 |
+
if isinstance(model_wrapper, BalancedDualGPUModel):
|
| 312 |
+
all_embeddings = []
|
| 313 |
+
batch_size = min(BATCH_SIZE, 2)
|
| 314 |
+
for i in range(0, len(signals), batch_size):
|
| 315 |
+
batch_emb = model_wrapper.process_batch(
|
| 316 |
+
signals[i: i + batch_size], masks_audio[i: i + batch_size], use_mlm
|
| 317 |
+
)
|
| 318 |
+
if batch_emb.numel() > 0:
|
| 319 |
+
all_embeddings.append(batch_emb)
|
| 320 |
+
# Clear cache after each batch
|
| 321 |
+
torch.cuda.empty_cache()
|
| 322 |
+
|
| 323 |
+
if all_embeddings:
|
| 324 |
+
result = torch.cat(all_embeddings, dim=0)
|
| 325 |
+
del all_embeddings
|
| 326 |
+
return result
|
| 327 |
+
else:
|
| 328 |
+
return torch.empty(0, 0, 0)
|
| 329 |
+
else:
|
| 330 |
+
extractor, model = model_wrapper
|
| 331 |
+
return embed_batch_single_gpu(
|
| 332 |
+
signals, masks_audio, extractor, model, layer, use_mlm=use_mlm
|
| 333 |
+
)
|
utils.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import threading
|
| 3 |
+
import warnings
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
try:
|
| 10 |
+
from scipy.optimize import linear_sum_assignment as _lsa
|
| 11 |
+
except Exception:
|
| 12 |
+
_lsa = None
|
| 13 |
+
|
| 14 |
+
warnings.filterwarnings("ignore", message="Some weights of Wav2Vec2Model")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_gpu_count(max_gpus=None):
|
| 18 |
+
ngpu = torch.cuda.device_count()
|
| 19 |
+
if max_gpus is not None:
|
| 20 |
+
ngpu = min(ngpu, max_gpus)
|
| 21 |
+
return ngpu
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def clear_gpu_memory():
|
| 25 |
+
"""Enhanced GPU memory clearing"""
|
| 26 |
+
if torch.cuda.is_available():
|
| 27 |
+
for i in range(torch.cuda.device_count()):
|
| 28 |
+
with torch.cuda.device(i):
|
| 29 |
+
torch.cuda.empty_cache()
|
| 30 |
+
torch.cuda.synchronize()
|
| 31 |
+
gc.collect()
|
| 32 |
+
torch.cuda.empty_cache()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_gpu_memory_info(verbose=False):
|
| 36 |
+
if not verbose:
|
| 37 |
+
return
|
| 38 |
+
for i in range(torch.cuda.device_count()):
|
| 39 |
+
try:
|
| 40 |
+
free_b, total_b = torch.cuda.mem_get_info(i) # type: ignore[attr-defined]
|
| 41 |
+
free_gb = free_b / 1024**3
|
| 42 |
+
total_gb = total_b / 1024**3
|
| 43 |
+
except Exception:
|
| 44 |
+
total_gb = torch.cuda.get_device_properties(i).total_memory / 1024**3
|
| 45 |
+
free_gb = total_gb - (torch.cuda.memory_reserved(i) / 1024**3)
|
| 46 |
+
mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
|
| 47 |
+
print(f"GPU {i}: {mem_allocated:.2f}GB allocated, {free_gb:.2f}GB free / {total_gb:.2f}GB total")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def write_wav_16bit(path, x, sr=16000):
|
| 51 |
+
path = Path(path)
|
| 52 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 53 |
+
try:
|
| 54 |
+
import soundfile as sf
|
| 55 |
+
|
| 56 |
+
sf.write(str(path), x.astype(np.float32), sr)
|
| 57 |
+
except Exception:
|
| 58 |
+
from scipy.io.wavfile import write
|
| 59 |
+
|
| 60 |
+
write(str(path), sr, (np.clip(x, -1, 1) * 32767).astype(np.int16))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def safe_corr_np(a, b):
|
| 64 |
+
L = min(len(a), len(b))
|
| 65 |
+
if L <= 1:
|
| 66 |
+
return 0.0
|
| 67 |
+
a = a[:L].astype(np.float64)
|
| 68 |
+
b = b[:L].astype(np.float64)
|
| 69 |
+
a -= a.mean()
|
| 70 |
+
b -= b.mean()
|
| 71 |
+
da = a.std()
|
| 72 |
+
db = b.std()
|
| 73 |
+
if da <= 1e-12 or db <= 1e-12:
|
| 74 |
+
return 0.0
|
| 75 |
+
r = float((a * b).mean() / (da * db))
|
| 76 |
+
return max(-1.0, min(1.0, r))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def hungarian(cost):
|
| 80 |
+
try:
|
| 81 |
+
if _lsa is not None:
|
| 82 |
+
return _lsa(cost)
|
| 83 |
+
raise RuntimeError("scipy.optimize.linear_sum_assignment unavailable")
|
| 84 |
+
except Exception:
|
| 85 |
+
used = set()
|
| 86 |
+
rows, cols = [], []
|
| 87 |
+
for i in range(cost.shape[0]):
|
| 88 |
+
j = int(
|
| 89 |
+
np.argmin(
|
| 90 |
+
[
|
| 91 |
+
cost[i, k] if k not in used else 1e12
|
| 92 |
+
for k in range(cost.shape[1])
|
| 93 |
+
]
|
| 94 |
+
)
|
| 95 |
+
)
|
| 96 |
+
used.add(j)
|
| 97 |
+
rows.append(i)
|
| 98 |
+
cols.append(j)
|
| 99 |
+
return np.asarray(rows), np.asarray(cols)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class GPUWorkDistributor:
|
| 103 |
+
|
| 104 |
+
def __init__(self, max_gpus=None):
|
| 105 |
+
ngpu = get_gpu_count(max_gpus)
|
| 106 |
+
self.gpu_locks = [threading.Lock() for _ in range(max(1, min(ngpu, 2)))]
|
| 107 |
+
self.gpu_load = [0 for _ in range(max(1, min(ngpu, 2)))]
|
| 108 |
+
self.ngpu = ngpu
|
| 109 |
+
|
| 110 |
+
def get_least_loaded_gpu(self):
|
| 111 |
+
return int(np.argmin(self.gpu_load))
|
| 112 |
+
|
| 113 |
+
def execute_on_gpu(self, func, *args, **kwargs):
|
| 114 |
+
if self.ngpu == 0:
|
| 115 |
+
kwargs.pop("device", None)
|
| 116 |
+
return func(*args, **kwargs)
|
| 117 |
+
gid = self.get_least_loaded_gpu()
|
| 118 |
+
with self.gpu_locks[gid]:
|
| 119 |
+
self.gpu_load[gid] += 1
|
| 120 |
+
try:
|
| 121 |
+
with torch.cuda.device(gid):
|
| 122 |
+
kwargs["device"] = f"cuda:{gid}"
|
| 123 |
+
result = func(*args, **kwargs)
|
| 124 |
+
# Clear cache after execution
|
| 125 |
+
torch.cuda.empty_cache()
|
| 126 |
+
return result
|
| 127 |
+
finally:
|
| 128 |
+
self.gpu_load[gid] -= 1
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@dataclass
|
| 132 |
+
class Mixture:
|
| 133 |
+
|
| 134 |
+
mixture_id: str
|
| 135 |
+
refs: list[Path]
|
| 136 |
+
systems: dict[str, list[Path]]
|
| 137 |
+
speaker_ids: list[str]
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def canonicalize_mixtures(mixtures, systems=None):
|
| 141 |
+
canon = []
|
| 142 |
+
if not mixtures:
|
| 143 |
+
return canon
|
| 144 |
+
|
| 145 |
+
def as_paths(seq):
|
| 146 |
+
return [p if isinstance(p, Path) else Path(str(p)) for p in seq]
|
| 147 |
+
|
| 148 |
+
def speaker_id_from_ref(ref_path, idx, mixture_id):
|
| 149 |
+
stem = (ref_path.stem or "").strip()
|
| 150 |
+
if not stem:
|
| 151 |
+
stem = f"spk{idx:02d}"
|
| 152 |
+
return f"{mixture_id}__{stem}"
|
| 153 |
+
|
| 154 |
+
if isinstance(mixtures[0], dict):
|
| 155 |
+
for m in mixtures:
|
| 156 |
+
mid = str(m.get("mixture_id") or m.get("id") or "").strip()
|
| 157 |
+
if not mid:
|
| 158 |
+
raise ValueError("Each mixture must include 'mixture_id'.")
|
| 159 |
+
refs = as_paths(m.get("references", []))
|
| 160 |
+
if not refs:
|
| 161 |
+
raise ValueError(f"Mixture {mid}: 'references' must be non-empty.")
|
| 162 |
+
sysmap = {}
|
| 163 |
+
if isinstance(m.get("systems"), dict):
|
| 164 |
+
for algo, outs in m["systems"].items():
|
| 165 |
+
sysmap[str(algo)] = as_paths(outs)
|
| 166 |
+
spk_ids = [speaker_id_from_ref(r, i, mid) for i, r in enumerate(refs)]
|
| 167 |
+
canon.append(Mixture(mid, refs, sysmap, spk_ids))
|
| 168 |
+
return canon
|
| 169 |
+
|
| 170 |
+
if isinstance(mixtures[0], list):
|
| 171 |
+
for i, group in enumerate(mixtures):
|
| 172 |
+
mid = f"mix_{i:03d}"
|
| 173 |
+
refs, spk_ids = [], []
|
| 174 |
+
for d in group:
|
| 175 |
+
if not isinstance(d, dict) or "ref" not in d or "id" not in d:
|
| 176 |
+
raise ValueError(
|
| 177 |
+
"Legacy mixtures expect dicts with 'id' and 'ref'."
|
| 178 |
+
)
|
| 179 |
+
rp = Path(d["ref"])
|
| 180 |
+
refs.append(rp)
|
| 181 |
+
spk_ids.append(f"{mid}__{str(d['id']).strip()}")
|
| 182 |
+
sysmap = {}
|
| 183 |
+
if systems:
|
| 184 |
+
for algo, per_mix in systems.items():
|
| 185 |
+
if mid in per_mix:
|
| 186 |
+
sysmap[algo] = as_paths(per_mix[mid])
|
| 187 |
+
canon.append(Mixture(mid, refs, sysmap, spk_ids))
|
| 188 |
+
return canon
|
| 189 |
+
|
| 190 |
+
raise ValueError("Unsupported 'mixtures' format.")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def random_misalign(sig, sr, max_ms, mode="single", rng=None):
|
| 194 |
+
import random
|
| 195 |
+
|
| 196 |
+
if rng is None:
|
| 197 |
+
rng = random
|
| 198 |
+
max_samples = int(sr * max_ms / 1000)
|
| 199 |
+
if max_samples == 0:
|
| 200 |
+
return sig
|
| 201 |
+
shift = (
|
| 202 |
+
rng.randint(-max_samples, max_samples) if mode == "range" else int(max_samples)
|
| 203 |
+
)
|
| 204 |
+
if shift == 0:
|
| 205 |
+
return sig
|
| 206 |
+
if isinstance(sig, torch.Tensor):
|
| 207 |
+
z = torch.zeros(abs(shift), dtype=sig.dtype, device=sig.device)
|
| 208 |
+
return (
|
| 209 |
+
torch.cat([z, sig[:-shift]]) if shift > 0 else torch.cat([sig[-shift:], z])
|
| 210 |
+
)
|
| 211 |
+
else:
|
| 212 |
+
z = np.zeros(abs(shift), dtype=sig.dtype)
|
| 213 |
+
return (
|
| 214 |
+
np.concatenate([z, sig[:-shift]])
|
| 215 |
+
if shift > 0
|
| 216 |
+
else np.concatenate([sig[-shift:], z])
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def safe_cov_torch(X):
|
| 221 |
+
Xc = X - X.mean(dim=0, keepdim=True)
|
| 222 |
+
cov = Xc.T @ Xc / (Xc.shape[0] - 1)
|
| 223 |
+
if torch.linalg.matrix_rank(cov) < cov.shape[0]:
|
| 224 |
+
cov += torch.eye(cov.shape[0], device=cov.device) * 1e-6
|
| 225 |
+
return cov
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def mahalanobis_torch(x, mu, inv):
|
| 229 |
+
diff = x - mu
|
| 230 |
+
diff_T = diff.transpose(-1, -2) if diff.ndim >= 2 else diff
|
| 231 |
+
return torch.sqrt(diff @ inv @ diff_T + 1e-6)
|