| |
| """Utilities for strict log-mel statistic prototype routing.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import math |
| import os |
| import random |
| import sys |
| from pathlib import Path |
| from typing import Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple |
|
|
| import numpy as np |
| import torch |
|
|
|
|
# Default statistic layout used when a caller or prototype does not name one.
FEATURE_TYPE = "mean_std"
# Older feature names (from earlier prototype files) and the current name they map to.
LEGACY_FEATURE_ALIASES = {
    "logmel_mean_std_128": "mean_std",
    "mean_std_128": "mean_std",
}
# Every statistic layout understood by ``extract_logmel_stat``.
FEATURE_TYPES = (
    "mean_only",
    "std_only",
    "mean_std",
    "mean_std_delta",
    "mean_std_band_energy",
    "mean_std_extended",
)
# How a domain prototype's centre is formed.
PROTOTYPE_MODES = ("global", "class_balanced", "class_conditional")
# How a domain prototype's covariance is estimated.
COVARIANCE_MODES = ("full_cov_shrinkage", "diag_cov")
# Routing strategies accepted by ``predict_with_selector``.
SELECTOR_MODES = (
    "entropy",
    "mahalanobis_proto",
    "hard_mahalanobis",
    "gaussian_nll_proto",
    "mahalanobis_zscore_proto",
    "soft_mahalanobis_proto",
    "hybrid_maha_entropy",
    "maha_margin_fallback",
)

# Class vocabulary; a label's position in this list is its class index.
CLASS_LABELS = [
    "alarm",
    "baby_cry",
    "dog_bark",
    "engine",
    "fire",
    "footsteps",
    "knocking",
    "telephone_ringing",
    "piano",
    "speech",
]
LABEL_TO_INDEX = {name: position for position, name in enumerate(CLASS_LABELS)}
INDEX_TO_LABEL = dict(enumerate(CLASS_LABELS))
# Domain name <-> task-head index.
DOMAIN_TO_TASK = {"D1": 0, "D2": 1, "D3": 2}
TASK_TO_DOMAIN = {task_index: name for name, task_index in DOMAIN_TO_TASK.items()}
|
|
|
|
def repo_root() -> Path:
    """Return the absolute repository root.

    The ``TASK7_REPO`` environment variable takes precedence; otherwise the
    grandparent directory of this module file is used.
    """
    default_root = Path(__file__).resolve().parents[1]
    configured = os.environ.get("TASK7_REPO", default_root)
    return Path(configured).resolve()
|
|
|
|
def setup_repo_imports(root: Optional[Path] = None) -> Path:
    """Prepend ``<root>/baseline`` and ``<root>/utils`` to ``sys.path``.

    Entries already on the path are left in place, so repeated calls are harmless.
    Because ``utils`` is inserted last, it ends up ahead of ``baseline`` when
    both were missing.

    Returns:
        The resolved root (``repo_root()`` when *root* is not supplied).
    """
    resolved = (root or repo_root()).resolve()
    candidates = [str(resolved / name) for name in ("baseline", "utils")]
    for entry in candidates:
        if entry in sys.path:
            continue
        sys.path.insert(0, entry)
    return resolved
|
|
|
|
def set_reproducible_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and ask cuDNN for deterministic kernels."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Benchmark-mode autotuning may pick different kernels run to run.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
|
|
|
def canonical_feature_type(feature_type: Optional[str]) -> str:
    """Resolve *feature_type* to one of ``FEATURE_TYPES``.

    ``None`` or an empty string selects ``FEATURE_TYPE``. Surrounding whitespace is
    ignored and legacy aliases are translated.

    Raises:
        ValueError: if the (aliased) name is not a known feature type.
    """
    cleaned = (feature_type if feature_type else FEATURE_TYPE).strip()
    resolved = LEGACY_FEATURE_ALIASES.get(cleaned, cleaned)
    if resolved in FEATURE_TYPES:
        return resolved
    raise ValueError(f"Unknown feature_type {feature_type!r}; expected one of {FEATURE_TYPES}")
|
|
|
|
| def resolve_audio_path(filename: str, data_root: str | Path) -> Path: |
| path = Path(str(filename)) |
| if path.is_absolute(): |
| return path |
| return Path(data_root) / path |
|
|
|
|
def pad_sequence(waveform: np.ndarray, min_len: int) -> np.ndarray:
    """Right-pad a 1-D waveform with zeros to at least *min_len* samples.

    Waveforms that are already long enough are returned unchanged (the same object).
    """
    missing = min_len - waveform.shape[0]
    if missing <= 0:
        return waveform
    tail = np.zeros(missing, dtype=waveform.dtype)
    return np.concatenate((waveform, tail))
|
|
|
|
def load_audio_waveform(filename: str, data_root: str | Path, sample_rate: int = 32000, clip_samples: int = 128000) -> np.ndarray:
    """Decode an audio file to a mono float32 waveform at *sample_rate*.

    Relative *filename*s are resolved under *data_root*. Short signals are
    zero-padded to *clip_samples*; longer ones are returned uncut. librosa is
    imported lazily so the rest of the module works without it installed.
    """
    import librosa

    audio_path = resolve_audio_path(filename, data_root)
    samples = librosa.core.load(str(audio_path), sr=sample_rate, mono=True)[0]
    padded = pad_sequence(samples, clip_samples)
    return padded.astype(np.float32, copy=False)
|
|
|
|
def collate_waveforms(batch: Sequence[Tuple[np.ndarray, str]]) -> Tuple[torch.Tensor, List[str]]:
    """Stack ``(waveform, filename)`` pairs into one zero-padded 2-D tensor.

    Every 1-D waveform is right-padded to the longest one in the batch.

    Returns:
        ``(tensor of shape (batch, max_len), filenames in batch order)``.
    """
    waveforms, filenames = zip(*batch)
    width = max(w.shape[0] for w in waveforms)
    rows = [
        w if w.shape[0] >= width else np.concatenate((w, np.zeros(width - w.shape[0], dtype=w.dtype)))
        for w in waveforms
    ]
    return torch.from_numpy(np.stack(rows, axis=0)), list(filenames)
|
|
|
|
def make_tta_crops(waveform: torch.Tensor, tta_crops: int, crop_samples: int) -> torch.Tensor:
    """Return 1, 3, or 4 independent crops from one waveform tensor.

    * ``1``: the whole flattened waveform, shape ``(1, n)``.
    * ``3``: front, middle and back windows of *crop_samples* each. A waveform
      shorter than one window is zero-padded first. When the (padded) waveform
      is exactly one window long, all three crops are that same window.
    * ``4``: the three windows plus the full waveform. The windows are
      right-padded with zeros to the full length so they can be stacked.

    Raises:
        ValueError: if *tta_crops* is not 1, 3 or 4.
    """
    if tta_crops not in (1, 3, 4):
        raise ValueError("--tta_crops must be one of 1, 3, 4")
    signal = waveform.detach().float().view(-1)
    if tta_crops == 1:
        return signal.unsqueeze(0)

    def _zero_pad(segment: torch.Tensor, length: int) -> torch.Tensor:
        # Right-pad with zeros of the same dtype/device; longer inputs pass through.
        short = length - segment.numel()
        if short <= 0:
            return segment
        return torch.cat([segment, segment.new_zeros(short)])

    signal = _zero_pad(signal, crop_samples)
    slack = signal.numel() - crop_samples
    if slack <= 0:
        window = signal[:crop_samples]
        crops = [window, window, window]
    else:
        mid = slack // 2
        crops = [
            signal[:crop_samples],
            signal[mid:mid + crop_samples],
            signal[slack:slack + crop_samples],
        ]
    if tta_crops == 4:
        crops.append(signal)
    longest = max(crop.numel() for crop in crops)
    return torch.stack([_zero_pad(crop, longest) for crop in crops], dim=0)
|
|
|
|
| def _safe_std(x: torch.Tensor, dim: int) -> torch.Tensor: |
| return torch.std(x, dim=dim, unbiased=x.shape[dim] > 1) |
|
|
|
|
def extract_logmel_tensor(model: torch.nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Run the model's spectrogram and log-mel front end and drop the channel axis.

    Calls ``model.spectrogram_extractor`` and then ``model.logmel_extractor``, and
    squeezes dim 1 of the result.

    NOTE(review): the helpers in this module treat the result as
    (batch, frames, mel_bins); confirm against the model definition.
    """
    spectrogram = model.spectrogram_extractor(inputs)
    logmel = model.logmel_extractor(spectrogram)
    return logmel.squeeze(1)
|
|
|
|
| def _band_ratios(logmel: torch.Tensor) -> torch.Tensor: |
| band_slices = (slice(0, 21), slice(21, 43), slice(43, 64)) |
| band_log_energy = [] |
| for band_slice in band_slices: |
| band_log_energy.append(torch.logsumexp(logmel[:, :, band_slice], dim=(1, 2))) |
| stacked = torch.stack(band_log_energy, dim=1) |
| return torch.softmax(stacked, dim=1) |
|
|
|
|
| def _spectral_summary(logmel: torch.Tensor) -> torch.Tensor: |
| eps = 1e-8 |
| bins = torch.linspace(0.0, 1.0, logmel.shape[-1], device=logmel.device, dtype=logmel.dtype) |
| weights = torch.softmax(logmel, dim=-1) |
| centroid = torch.sum(weights * bins.view(1, 1, -1), dim=-1) |
| bandwidth = torch.sqrt(torch.sum(weights * (bins.view(1, 1, -1) - centroid.unsqueeze(-1)).pow(2), dim=-1) + eps) |
| cdf = torch.cumsum(weights, dim=-1) |
| rolloff_idx = torch.argmax((cdf >= 0.85).to(torch.int64), dim=-1).float() |
| rolloff = rolloff_idx / max(logmel.shape[-1] - 1, 1) |
| shifted = logmel - torch.amax(logmel, dim=-1, keepdim=True) |
| linear = torch.exp(torch.clamp(shifted, min=-80.0, max=20.0)) + eps |
| flatness = torch.exp(torch.mean(torch.log(linear), dim=-1)) / (torch.mean(linear, dim=-1) + eps) |
| values = [centroid, bandwidth, rolloff, flatness] |
| summary = [] |
| for value in values: |
| summary.append(torch.mean(value, dim=1)) |
| summary.append(_safe_std(value, dim=1)) |
| return torch.stack(summary, dim=1) |
|
|
|
|
def extract_logmel_stat(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    feature_type: Optional[str] = None,
) -> torch.Tensor:
    """Compute the per-clip log-mel statistic vector used for prototype routing.

    Statistics are taken over dim 1 of the log-mel tensor (the frame axis as
    indexed elsewhere in this module). The blocks are concatenated on dim 1:

    * ``mean_only``:            per-bin mean
    * ``std_only``:             per-bin std
    * ``mean_std``:             mean, std
    * ``mean_std_delta``:       mean, std, mean(delta), std(delta)
    * ``mean_std_band_energy``: mean, std, 3 band-energy ratios
    * ``mean_std_extended``:    mean, std, delta stats, band ratios, 8 spectral summaries
    """
    kind = canonical_feature_type(feature_type)
    frames = extract_logmel_tensor(model, inputs)
    mean = frames.mean(dim=1)
    std = _safe_std(frames, dim=1)
    simple = {"mean_only": mean, "std_only": std}
    if kind in simple:
        return simple[kind]

    pieces = [mean, std]
    if kind in ("mean_std_delta", "mean_std_extended"):
        # First-order frame difference; a single frame yields an all-zero delta.
        if frames.shape[1] > 1:
            delta = frames[:, 1:, :] - frames[:, :-1, :]
        else:
            delta = torch.zeros_like(frames)
        pieces += [delta.mean(dim=1), _safe_std(delta, dim=1)]
    if kind in ("mean_std_band_energy", "mean_std_extended"):
        pieces.append(_band_ratios(frames))
    if kind == "mean_std_extended":
        pieces.append(_spectral_summary(frames))
    return torch.cat(pieces, dim=1)
|
|
|
|
| def _logdet_psd(cov: torch.Tensor, eps: float) -> torch.Tensor: |
| sign, logabsdet = torch.linalg.slogdet(cov) |
| if bool(sign > 0): |
| return logabsdet |
| eigvals = torch.linalg.eigvalsh(cov) |
| return torch.sum(torch.log(torch.clamp(eigvals, min=eps))) |
|
|
|
|
def _fit_covariance(
    stats: torch.Tensor,
    center: torch.Tensor,
    covariance_mode: str,
    shrinkage: float,
    eps: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Estimate a regularised covariance of *stats* around *center*.

    Let ``S`` be the sample covariance (the identity when there is a single
    row) and ``d = max(diag(S), eps)``.

    * ``diag_cov``:           ``diag(d + eps)``
    * ``full_cov_shrinkage``: ``(1 - shrinkage) * S + shrinkage * diag(d) + eps * I``

    Every helper matrix is created on the same device as the data, so the fit
    also works for GPU tensors.

    Returns:
        ``(cov, cov_inv, logdet)``. ``cov_inv`` is a pseudo-inverse and
        ``logdet`` comes from ``_logdet_psd``.

    Raises:
        ValueError: for an unknown *covariance_mode*.
    """
    if covariance_mode not in COVARIANCE_MODES:
        raise ValueError(f"Unknown covariance_mode {covariance_mode!r}; expected one of {COVARIANCE_MODES}")
    n_samples = stats.shape[0]
    n_dims = stats.shape[1]
    centered = stats - center.unsqueeze(0)
    if n_samples > 1:
        cov = centered.t().matmul(centered) / (n_samples - 1)
    else:
        # A single sample carries no spread information; fall back to unit variances.
        cov = torch.eye(n_dims, dtype=stats.dtype, device=stats.device)
    diag_values = torch.clamp(torch.diag(cov), min=eps)
    if covariance_mode == "diag_cov":
        cov = torch.diag(diag_values + eps)
    else:
        ridge = eps * torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)
        cov = (1.0 - shrinkage) * cov + shrinkage * torch.diag(diag_values) + ridge
    cov_inv = torch.linalg.pinv(cov)
    logdet = _logdet_psd(cov, eps)
    return cov, cov_inv, logdet
|
|
|
|
| def _mahalanobis(stats: torch.Tensor, mean: torch.Tensor, cov_inv: torch.Tensor) -> torch.Tensor: |
| diff = stats - mean.unsqueeze(0) |
| return torch.sum((diff @ cov_inv) * diff, dim=1) |
|
|
|
|
| def _distance_stats(distances: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| mean = torch.mean(distances) |
| std = torch.std(distances, unbiased=distances.numel() > 1) |
| return mean, torch.clamp(std, min=1e-6) |
|
|
|
|
| def _extract_labels_from_batch(batch) -> Optional[torch.Tensor]: |
| if len(batch) < 2: |
| return None |
| target = batch[1] |
| if not torch.is_tensor(target): |
| target = torch.as_tensor(target) |
| if target.ndim > 1: |
| return torch.argmax(target.float(), dim=-1).long() |
| return target.long() |
|
|
|
|
| def _class_ids(labels: Optional[torch.Tensor]) -> List[int]: |
| if labels is None: |
| return [] |
| return sorted(int(value) for value in torch.unique(labels.detach().cpu()).tolist()) |
|
|
|
|
def _validated_labels(labels: Optional[torch.Tensor], n_rows: int) -> Optional[torch.Tensor]:
    """Flatten *labels* to a CPU ``long`` vector and check it lines up with the stats rows."""
    if labels is None:
        return None
    flat = labels.detach().cpu().long().reshape(-1)
    if flat.numel() != n_rows:
        raise ValueError(f"Expected {n_rows} labels (one per stats row), got {flat.numel()}")
    if flat.numel() and (int(flat.min()) < 0 or int(flat.max()) >= len(CLASS_LABELS)):
        # Negative ids would silently index from the end of the per-class tables.
        raise ValueError(
            f"Class labels must lie in [0, {len(CLASS_LABELS)}); got range [{int(flat.min())}, {int(flat.max())}]"
        )
    return flat


def _class_tables(
    stats: torch.Tensor,
    class_rows: Dict[int, torch.Tensor],
    class_means: Dict[int, torch.Tensor],
    covariance_mode: str,
    shrinkage: float,
    eps: float,
) -> Dict[str, torch.Tensor]:
    """Fit one Gaussian per labelled class and pack the results into tables indexed by class id."""
    num_classes = len(CLASS_LABELS)
    dim = stats.shape[1]
    counts = torch.zeros(num_classes, dtype=torch.long)
    means = torch.zeros(num_classes, dim, dtype=torch.float32)
    covs = torch.zeros(num_classes, dim, dim, dtype=torch.float32)
    cov_invs = torch.zeros_like(covs)
    logdets = torch.zeros(num_classes, dtype=torch.float32)
    dist_means = torch.zeros(num_classes, dtype=torch.float32)
    dist_stds = torch.ones(num_classes, dtype=torch.float32)
    available = torch.zeros(num_classes, dtype=torch.bool)
    for class_id, class_mean in class_means.items():
        class_stats = stats[class_rows[class_id]]
        class_cov, class_cov_inv, class_logdet = _fit_covariance(
            class_stats,
            class_mean,
            covariance_mode,
            shrinkage,
            eps,
        )
        class_dist_mean, class_dist_std = _distance_stats(_mahalanobis(class_stats, class_mean, class_cov_inv))
        counts[class_id] = class_stats.shape[0]
        means[class_id] = class_mean
        covs[class_id] = class_cov
        cov_invs[class_id] = class_cov_inv
        logdets[class_id] = class_logdet
        dist_means[class_id] = class_dist_mean
        dist_stds[class_id] = class_dist_std
        available[class_id] = True
    return {
        "class_counts": counts,
        "class_means": means,
        "class_covs": covs,
        "class_cov_invs": cov_invs,
        "class_logdets": logdets,
        "class_dist_means": dist_means,
        "class_dist_stds": dist_stds,
        "class_available": available,
    }


def make_prototype_entry(
    stats: torch.Tensor,
    domain: str,
    source_split: str,
    shrinkage: float = 0.15,
    diag_weight: Optional[float] = None,
    eps: float = 1e-3,
    labels: Optional[torch.Tensor] = None,
    feature_type: str = FEATURE_TYPE,
    prototype_mode: str = "global",
    covariance_mode: str = "full_cov_shrinkage",
) -> Dict[str, object]:
    """Build one domain's Gaussian prototype from its statistic vectors.

    Args:
        stats: ``(n, d)`` statistic vectors, one per clip.
        domain: domain name; must be a key of ``DOMAIN_TO_TASK``.
        source_split: free-form tag stored in the entry.
        shrinkage: weight of the diagonal target in full-covariance shrinkage.
        diag_weight: legacy switch. A value ``>= 1.0`` forces
            ``covariance_mode="diag_cov"``; smaller values are ignored.
        eps: variance floor / ridge used by ``_fit_covariance``.
        labels: optional per-row class ids (any shape with ``n`` elements).
            They enable the per-class tables and the ``class_balanced`` centre.
        feature_type: statistic layout recorded in the entry (aliases are canonicalised).
        prototype_mode: ``global`` (plain mean), ``class_balanced`` (mean of
            class means when labels exist) or ``class_conditional`` (global
            centre plus per-class tables used at routing time).
        covariance_mode: ``full_cov_shrinkage`` or ``diag_cov``.

    Returns:
        A dict holding the global mean, covariance, inverse and log-determinant,
        plus the distance calibration statistics. When labels are given it also
        holds ``class_*`` tables indexed by class id; rows for absent classes
        stay zero and are flagged false in ``class_available``.

    Raises:
        ValueError: on non-2-D *stats*, an unknown mode, a label count that does
            not match the row count, or class ids outside ``[0, len(CLASS_LABELS))``.
    """
    if stats.ndim != 2:
        raise ValueError(f"Expected stats with shape (n, d), got {tuple(stats.shape)}")
    if prototype_mode not in PROTOTYPE_MODES:
        raise ValueError(f"Unknown prototype_mode {prototype_mode!r}; expected one of {PROTOTYPE_MODES}")
    if diag_weight is not None and diag_weight >= 1.0:
        covariance_mode = "diag_cov"
    stats = stats.detach().cpu().float()
    feature_type = canonical_feature_type(feature_type)
    labels_cpu = _validated_labels(labels, stats.shape[0])

    raw_mean = torch.mean(stats, dim=0)
    raw_std = _safe_std(stats, dim=0)
    class_rows: Dict[int, torch.Tensor] = {}
    if labels_cpu is not None:
        class_rows = {class_id: labels_cpu == class_id for class_id in _class_ids(labels_cpu)}
    class_means = {class_id: torch.mean(stats[mask], dim=0) for class_id, mask in class_rows.items()}

    if prototype_mode == "class_balanced" and class_means:
        # Every class contributes equally, however many clips it has.
        mean = torch.stack([class_means[key] for key in sorted(class_means)], dim=0).mean(dim=0)
    else:
        mean = raw_mean

    cov, cov_inv, logdet = _fit_covariance(stats, mean, covariance_mode, shrinkage, eps)
    dist_mean, dist_std = _distance_stats(_mahalanobis(stats, mean, cov_inv))

    entry: Dict[str, object] = {
        "domain": domain,
        "task_index": DOMAIN_TO_TASK[domain],
        "feature_type": feature_type,
        "prototype_mode": prototype_mode,
        "covariance_mode": covariance_mode,
        "mean": mean,
        "raw_mean": raw_mean,
        "std": raw_std,
        "cov": cov,
        "cov_inv": cov_inv,
        "logdet": logdet,
        "dist_mean": dist_mean,
        "dist_std": dist_std,
        "n": int(stats.shape[0]),
        "dim": int(stats.shape[1]),
        "source_split": source_split,
        "cov_shrinkage": float(shrinkage),
        "cov_diag_weight": float(1.0 if covariance_mode == "diag_cov" else shrinkage),
        "cov_eps": float(eps),
    }
    if class_means:
        entry.update(_class_tables(stats, class_rows, class_means, covariance_mode, shrinkage, eps))
    return entry
|
|
|
|
def compute_domain_prototype(
    model: torch.nn.Module,
    loader: Iterable,
    domain: str,
    device: str | torch.device,
    source_split: str,
    shrinkage: float = 0.15,
    diag_weight: Optional[float] = None,
    eps: float = 1e-3,
    feature_type: str = FEATURE_TYPE,
    prototype_mode: str = "global",
    covariance_mode: str = "full_cov_shrinkage",
) -> Dict[str, object]:
    """Extract log-mel statistics for every batch of *loader* and fit a domain prototype.

    The model is put into eval mode (and left there). ``batch[0]`` is taken as
    the audio; labels come from ``_extract_labels_from_batch``.

    Labels are forwarded only when every batch supplied them, because a partial
    label set would not line up with the stacked statistics. In that case a
    ``global`` prototype is fitted without labels, while the label-dependent
    modes raise.

    Raises:
        RuntimeError: if the loader yields no batches.
        ValueError: if some batches lack labels and *prototype_mode* is not ``global``.
    """
    model.eval()
    stats: List[torch.Tensor] = []
    labels: List[torch.Tensor] = []
    unlabeled_batches = 0
    with torch.no_grad():
        for batch in loader:
            batch_labels = _extract_labels_from_batch(batch)
            audio = batch[0].float().to(device)
            stats.append(extract_logmel_stat(model, audio, feature_type=feature_type).detach().cpu())
            if batch_labels is None:
                unlabeled_batches += 1
            else:
                labels.append(batch_labels.detach().cpu())
    if not stats:
        raise RuntimeError(f"No samples available for {domain} prototype")

    label_tensor: Optional[torch.Tensor] = None
    if labels and unlabeled_batches:
        if prototype_mode != "global":
            raise ValueError(
                f"{domain}: {unlabeled_batches} batch(es) carried no labels, but "
                f"prototype_mode={prototype_mode!r} needs a label for every sample"
            )
        # Partial labels cannot be aligned with the stats rows; fit a label-free global prototype.
    elif labels:
        label_tensor = torch.cat(labels, dim=0)
    return make_prototype_entry(
        torch.cat(stats, dim=0),
        domain=domain,
        source_split=source_split,
        shrinkage=shrinkage,
        diag_weight=diag_weight,
        eps=eps,
        labels=label_tensor,
        feature_type=feature_type,
        prototype_mode=prototype_mode,
        covariance_mode=covariance_mode,
    )
|
|
|
|
def normalize_domain_key(key: object) -> Optional[str]:
    """Map a prototype-mapping key to ``"D1"``/``"D2"``/``"D3"``, or ``None``.

    Accepted forms:

    * a domain name in any case (``"d2"``);
    * a task index as a decimal string (``"1"``);
    * a task index as an int (``1``).

    Anything else yields ``None``.
    """
    if isinstance(key, int):
        return TASK_TO_DOMAIN.get(key)
    if not isinstance(key, str):
        return None
    upper = key.upper()
    if upper in DOMAIN_TO_TASK:
        return upper
    return TASK_TO_DOMAIN.get(int(upper)) if upper.isdigit() else None
|
|
|
|
# Entry fields that are stored as tensors.
_PROTOTYPE_TENSOR_KEYS = (
    "mean",
    "raw_mean",
    "std",
    "cov",
    "cov_inv",
    "logdet",
    "dist_mean",
    "dist_std",
    "class_counts",
    "class_means",
    "class_covs",
    "class_cov_invs",
    "class_logdets",
    "class_dist_means",
    "class_dist_stds",
    "class_available",
)


def _normalize_prototype_entry(domain: str, value: Mapping) -> Dict[str, object]:
    """Copy one raw prototype mapping, fill in defaults and coerce tensor fields."""
    entry = dict(value)
    entry["domain"] = domain
    entry["task_index"] = int(entry.get("task_index", DOMAIN_TO_TASK[domain]))
    entry["feature_type"] = canonical_feature_type(str(entry.get("feature_type", FEATURE_TYPE)))
    entry["prototype_mode"] = entry.get("prototype_mode", "global")
    entry["covariance_mode"] = entry.get("covariance_mode", "full_cov_shrinkage")
    entry["n"] = int(entry.get("n", entry.get("count", -1)))
    for key in _PROTOTYPE_TENSOR_KEYS:
        if key not in entry:
            continue
        tensor = entry[key] if torch.is_tensor(entry[key]) else torch.as_tensor(entry[key])
        if key == "class_counts":
            # Counts stay integral (as make_prototype_entry builds them) across save/load.
            tensor = tensor.long()
        elif key != "class_available":
            tensor = tensor.float()
        entry[key] = tensor
    if "mean" not in entry or "cov_inv" not in entry:
        raise ValueError(f"Prototype for {domain} must contain mean and cov_inv")
    if "logdet" not in entry and "cov" in entry:
        entry["logdet"] = _logdet_psd(entry["cov"].float(), float(entry.get("cov_eps", 1e-3)))
    if "dist_mean" not in entry:
        entry["dist_mean"] = torch.tensor(0.0)
    if "dist_std" not in entry:
        entry["dist_std"] = torch.tensor(1.0)
    entry["dim"] = int(entry.get("dim", entry["mean"].numel()))
    return entry


def normalize_prototypes(obj: Mapping) -> Dict[str, Dict[str, object]]:
    """Turn a saved prototype payload (or a bare domain mapping) into ``{domain: entry}``.

    A top-level ``"domains"`` mapping, as written by ``save_prototypes``, is
    unwrapped first. Keys are normalised with ``normalize_domain_key``.
    Metadata keys, unrecognised keys and non-mapping values are skipped. If two
    keys normalise to the same domain, the later one wins.

    Every entry has ``mean`` and ``cov_inv``. Tensor fields are converted to
    float, except ``class_counts`` (kept ``long``) and ``class_available``
    (left as given). ``logdet``, ``dist_mean``, ``dist_std`` and ``dim`` are
    filled in when missing.

    Raises:
        ValueError: if an entry lacks ``mean`` or ``cov_inv``.
    """
    if "domains" in obj and isinstance(obj["domains"], Mapping):
        obj = obj["domains"]
    out: Dict[str, Dict[str, object]] = {}
    for key, value in obj.items():
        if key in {"metadata", "meta", "domains"}:
            continue
        domain = normalize_domain_key(key)
        if domain is None or not isinstance(value, Mapping):
            continue
        out[domain] = _normalize_prototype_entry(domain, value)
    return out
|
|
|
|
def load_prototypes(path: str | Path, map_location: str | torch.device = "cpu") -> Dict[str, Dict[str, object]]:
    """Load a prototype file and return its normalised ``{domain: entry}`` mapping.

    NOTE(review): ``torch.load`` can unpickle arbitrary objects, depending on the
    installed torch version's ``weights_only`` default. Only load trusted files.

    Raises:
        TypeError: if the file does not contain a mapping.
    """
    payload = torch.load(str(path), map_location=map_location)
    if isinstance(payload, Mapping):
        return normalize_prototypes(payload)
    raise TypeError(f"Prototype file must contain a mapping, got {type(payload)!r}")
|
|
|
|
def save_prototypes(
    prototypes: Mapping[str, Mapping[str, object]],
    path: str | Path,
    metadata: Optional[Mapping[str, object]] = None,
) -> None:
    """Normalise *prototypes* and save them as ``{"meta": ..., "domains": ...}``.

    ``meta`` records the first domain's feature type plus fixed front-end
    settings (32 kHz, 64 mels). Keys in *metadata* override or extend it.
    Missing parent directories are created.
    """
    normalized = normalize_prototypes(prototypes)
    if normalized:
        first_entry = next(iter(normalized.values()))
        feature_type = first_entry.get("feature_type", FEATURE_TYPE)
    else:
        feature_type = FEATURE_TYPE
    meta: Dict[str, object] = {
        "feature_type": feature_type,
        "sample_rate": 32000,
        "n_mels": 64,
        "note": "D1 prototype is not built unless a D1 train split is available.",
    }
    if metadata:
        meta.update(dict(metadata))
    payload = {"meta": meta, "domains": dict(normalized)}
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, str(target))
|
|
|
|
def prototype_feature_type(prototypes: Mapping[str, Mapping[str, object]], fallback: str = FEATURE_TYPE) -> str:
    """Canonical feature type of the first prototype entry, or of *fallback* when there are none."""
    entries = iter(prototypes.values())
    try:
        first = next(entries)
    except StopIteration:
        return canonical_feature_type(fallback)
    return canonical_feature_type(str(first.get("feature_type", fallback)))
|
|
|
|
def candidate_domain_list(
    prototypes: Mapping[str, Mapping[str, object]],
    candidate_domains: Optional[Sequence[str]] = None,
) -> List[str]:
    """Return the domains that can be scored against *prototypes*.

    With explicit *candidate_domains*, keep those that have a prototype, in the
    order given. Otherwise return whichever of D1, D2, D3 have a prototype, in
    that order.
    """
    if candidate_domains is None:
        return [name for name in ("D1", "D2", "D3") if name in prototypes]
    return [name for name in candidate_domains if name in prototypes]
|
|
|
|
| def _entry_distance( |
| stats: torch.Tensor, |
| entry: Mapping[str, object], |
| class_ids: Optional[torch.Tensor] = None, |
| class_conditional: bool = False, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| if class_conditional and class_ids is not None and "class_means" in entry and "class_cov_invs" in entry: |
| class_ids = class_ids.to(stats.device).long() |
| class_means = entry["class_means"].to(stats.device) |
| class_cov_invs = entry["class_cov_invs"].to(stats.device) |
| class_logdets = entry.get("class_logdets", torch.zeros(class_means.shape[0])).to(stats.device) |
| class_dist_means = entry.get("class_dist_means", torch.zeros(class_means.shape[0])).to(stats.device) |
| class_dist_stds = torch.clamp(entry.get("class_dist_stds", torch.ones(class_means.shape[0])).to(stats.device), min=1e-6) |
| means = class_means[class_ids] |
| cov_invs = class_cov_invs[class_ids] |
| diff = stats - means |
| distances = torch.einsum("bi,bij,bj->b", diff, cov_invs, diff) |
| return distances, class_logdets[class_ids], class_dist_means[class_ids], class_dist_stds[class_ids] |
|
|
| mean = entry["mean"].to(stats.device) |
| cov_inv = entry["cov_inv"].to(stats.device) |
| distances = _mahalanobis(stats, mean, cov_inv) |
| logdet = entry.get("logdet", torch.tensor(0.0)).to(stats.device).expand_as(distances) |
| dist_mean = entry.get("dist_mean", torch.tensor(0.0)).to(stats.device).expand_as(distances) |
| dist_std = torch.clamp(entry.get("dist_std", torch.tensor(1.0)).to(stats.device), min=1e-6).expand_as(distances) |
| return distances, logdet, dist_mean, dist_std |
|
|
|
|
def class_conditional_distance_matrix(stats: torch.Tensor, entry: Mapping[str, object]) -> torch.Tensor:
    """Squared Mahalanobis distance from every row of *stats* to every class Gaussian.

    Returns:
        A ``(batch, num_classes)`` matrix. A class with no fitted Gaussian in
        this entry (``class_available`` false, or ``class_counts == 0`` when the
        flag is absent) gets the distance to the domain's global Gaussian. Its
        all-zero tables would otherwise give 0, the best possible score in joint
        class/domain selection.

    Raises:
        ValueError: if the entry carries no class tables.
    """
    if "class_means" not in entry or "class_cov_invs" not in entry:
        raise ValueError("class_conditional prototype requires class_means and class_cov_invs")
    device = stats.device
    class_means = entry["class_means"].to(device)
    class_cov_invs = entry["class_cov_invs"].to(device)
    diff = stats.unsqueeze(1) - class_means.unsqueeze(0)
    distances = torch.einsum("bci,cij,bcj->bc", diff, class_cov_invs, diff)

    if "class_available" in entry:
        available = torch.as_tensor(entry["class_available"]).to(device).bool()
    elif "class_counts" in entry:
        available = torch.as_tensor(entry["class_counts"]).to(device) > 0
    else:
        return distances
    if bool(available.all()):
        return distances
    offsets = stats - entry["mean"].to(device).unsqueeze(0)
    global_distances = torch.sum((offsets @ entry["cov_inv"].to(device)) * offsets, dim=1)
    return torch.where(available.unsqueeze(0), distances, global_distances.unsqueeze(1))
|
|
|
|
def prototype_scores(
    stats: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    candidate_domains: Optional[Sequence[str]] = None,
    score_type: str = "mahalanobis",
    class_ids: Optional[torch.Tensor] = None,
    class_conditional: bool = False,
) -> Tuple[torch.Tensor, List[str]]:
    """Score every row of *stats* against each candidate domain prototype (lower is closer).

    score_type:
        * ``"mahalanobis"``:  squared Mahalanobis distance;
        * ``"gaussian_nll"``: ``0.5 * distance + 0.5 * logdet``;
        * ``"zscore"``:       distance standardised by the prototype's
          training-distance mean and std.

    Returns:
        ``(scores of shape (batch, n_domains), domain names in column order)``.

    Raises:
        RuntimeError: if no candidate domain has a prototype.
        ValueError: for an unknown *score_type*.
    """
    domains = candidate_domain_list(prototypes, candidate_domains)
    if not domains:
        raise RuntimeError("No candidate domain prototypes available")
    columns = []
    for name in domains:
        dist, logdet, mu, sigma = _entry_distance(
            stats,
            prototypes[name],
            class_ids=class_ids,
            class_conditional=class_conditional,
        )
        if score_type == "zscore":
            columns.append((dist - mu) / sigma)
        elif score_type == "gaussian_nll":
            columns.append(0.5 * dist + 0.5 * logdet)
        elif score_type == "mahalanobis":
            columns.append(dist)
        else:
            raise ValueError(f"Unknown prototype score_type: {score_type}")
    return torch.stack(columns, dim=1), domains
|
|
|
|
def prototype_distances(
    stats: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    candidate_domains: Optional[Sequence[str]] = None,
) -> Tuple[torch.Tensor, List[str]]:
    """Squared Mahalanobis distances to each candidate domain's global prototype.

    Equivalent to ``prototype_scores(..., score_type="mahalanobis")``.
    """
    scores, domains = prototype_scores(stats, prototypes, candidate_domains, score_type="mahalanobis")
    return scores, domains
|
|
|
|
def select_tasks_by_mahalanobis(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    candidate_domains: Optional[Sequence[str]] = None,
) -> torch.Tensor:
    """Route each input to the task whose domain prototype is nearest in Mahalanobis distance."""
    stats = extract_logmel_stat(model, inputs, feature_type=prototype_feature_type(prototypes))
    distances, domains = prototype_distances(stats, prototypes, candidate_domains)
    column_tasks = torch.tensor([DOMAIN_TO_TASK[name] for name in domains], device=inputs.device, dtype=torch.long)
    return column_tasks[distances.argmin(dim=1)]
|
|
|
|
def task_probabilities(model: torch.nn.Module, inputs: torch.Tensor, tasks: Sequence[int]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run every task head in *tasks* on *inputs*.

    Each head's output is softmaxed over dim 1, and the heads are stacked on a
    new dim 1.

    Returns:
        ``(probs, entropy, logits)``. ``probs`` and ``logits`` are stacked
        per-task outputs; ``entropy`` (in nats) is summed over the last axis.

    The entropy is computed from ``log_softmax``. A probability that underflows
    to exactly 0 then contributes ``0 * finite = 0``. A ``log(p + tiny)`` guard
    does not work in float32: ``sys.float_info.min`` rounds to 0 there, which
    gives ``0 * log(0) = NaN`` for saturated heads.
    """
    logits = []
    probs = []
    log_probs = []
    for task_id in tasks:
        task_logits = model(inputs, int(task_id))
        logits.append(task_logits)
        probs.append(torch.softmax(task_logits, dim=1))
        log_probs.append(torch.log_softmax(task_logits, dim=1))
    logits_tensor = torch.stack(logits, dim=1)
    probs_tensor = torch.stack(probs, dim=1)
    log_probs_tensor = torch.stack(log_probs, dim=1)
    entropy = -torch.sum(probs_tensor * log_probs_tensor, dim=-1)
    return probs_tensor, entropy, logits_tensor
|
|
|
|
def select_tasks_by_entropy(model: torch.nn.Module, inputs: torch.Tensor, tasks: Sequence[int]) -> torch.Tensor:
    """Pick, for each input, the task head whose prediction has the lowest entropy."""
    entropy = task_probabilities(model, inputs, tasks)[1]
    candidates = torch.tensor(list(tasks), device=inputs.device, dtype=torch.long)
    return candidates[entropy.argmin(dim=1)]
|
|
|
|
| def _gather_by_task(probs: torch.Tensor, tasks: Sequence[int], selected_tasks: torch.Tensor) -> torch.Tensor: |
| task_tensor = torch.tensor(list(tasks), device=selected_tasks.device, dtype=torch.long) |
| out = [] |
| for row_idx, task_id in enumerate(selected_tasks.view(-1)): |
| match = torch.nonzero(task_tensor == task_id, as_tuple=False) |
| if match.numel() == 0: |
| raise ValueError(f"Selected task {int(task_id)} is not in available task list {tasks}") |
| out.append(probs[row_idx, int(match[0, 0])]) |
| return torch.stack(out, dim=0) |
|
|
|
|
def _task_ids_for_domains(domains: Sequence[str], device: torch.device) -> torch.Tensor:
    """Task-head indices for *domains*, as a ``long`` tensor on *device*."""
    ids = [DOMAIN_TO_TASK[name] for name in domains]
    return torch.tensor(ids, dtype=torch.long, device=device)
|
|
|
|
| def _mean_candidate_class(probs: torch.Tensor) -> torch.Tensor: |
| return torch.argmax(torch.mean(probs, dim=1), dim=1) |
|
|
|
|
| def _apply_d1_fallback( |
| selected_tasks: torch.Tensor, |
| selected_probs: torch.Tensor, |
| all_tasks: Sequence[int], |
| all_probs: torch.Tensor, |
| all_entropy: torch.Tensor, |
| z_scores: torch.Tensor, |
| d1_fallback: str, |
| d1_z_threshold: float, |
| d1_entropy_margin: float, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| fallback_mask = torch.zeros_like(selected_tasks, dtype=torch.bool) |
| if d1_fallback == "none" or 0 not in all_tasks: |
| return selected_tasks, selected_probs, fallback_mask |
| task_tensor = torch.tensor(list(all_tasks), device=selected_tasks.device, dtype=torch.long) |
| d1_col = int(torch.nonzero(task_tensor == 0, as_tuple=False)[0, 0]) |
| d1_probs = all_probs[:, d1_col, :] |
| d1_entropy = all_entropy[:, d1_col] |
| other_cols = [idx for idx, task in enumerate(all_tasks) if task != 0] |
| min_other_entropy = torch.min(all_entropy[:, other_cols], dim=1).values |
| if d1_fallback == "entropy": |
| entropy_tasks = task_tensor[torch.argmin(all_entropy, dim=1)] |
| fallback_mask = entropy_tasks == 0 |
| elif d1_fallback == "conservative": |
| far_from_known = torch.min(z_scores, dim=1).values > d1_z_threshold |
| d1_confident = d1_entropy + d1_entropy_margin < min_other_entropy |
| fallback_mask = far_from_known & d1_confident |
| else: |
| raise ValueError(f"Unknown D1 fallback: {d1_fallback}") |
| selected_tasks = torch.where(fallback_mask, torch.zeros_like(selected_tasks), selected_tasks) |
| selected_probs = torch.where(fallback_mask.unsqueeze(1), d1_probs, selected_probs) |
| return selected_tasks, selected_probs, fallback_mask |
|
|
|
|
def _predict_by_entropy_only(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    selector: str,
    return_details: bool,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, object]] | torch.Tensor:
    """Entropy routing over all three task heads (tasks 0, 1, 2)."""
    all_tasks = [0, 1, 2]
    all_probs, all_entropy, _ = task_probabilities(model, inputs, all_tasks)
    selected_tasks = torch.tensor(all_tasks, device=inputs.device, dtype=torch.long)[torch.argmin(all_entropy, dim=1)]
    selected_probs = _gather_by_task(all_probs, all_tasks, selected_tasks)
    if not return_details:
        return selected_probs
    details = {
        "selector": selector,
        "selected_tasks": selected_tasks.detach().cpu(),
        "selected_domains": [TASK_TO_DOMAIN[int(task)] for task in selected_tasks.detach().cpu().tolist()],
        "entropy": all_entropy.detach().cpu(),
        "tasks": all_tasks,
    }
    return selected_probs, selected_tasks, details


def _joint_class_selection(
    stats: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    domains: Sequence[str],
    logits: torch.Tensor,
    task_ids: torch.Tensor,
    class_distance_lambda: float,
    batch_size: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the (domain, class) pair maximising ``logit - lambda * class distance``; return (tasks, one-hot probs)."""
    joint_scores = torch.stack(
        [
            logits[:, col_idx, :] - class_distance_lambda * class_conditional_distance_matrix(stats, prototypes[domain])
            for col_idx, domain in enumerate(domains)
        ],
        dim=1,
    )
    n_classes = len(CLASS_LABELS)
    flat = torch.argmax(joint_scores.reshape(joint_scores.shape[0], -1), dim=1)
    domain_indices = flat // n_classes
    pred_indices = flat % n_classes
    selected_probs = torch.zeros(batch_size, n_classes, device=device)
    selected_probs.scatter_(1, pred_indices.unsqueeze(1), 1.0)
    return task_ids[domain_indices], selected_probs


def _apply_score_selector(
    selector: str,
    scores: torch.Tensor,
    raw_scores: torch.Tensor,
    z_scores: torch.Tensor,
    probs: torch.Tensor,
    entropy: torch.Tensor,
    task_ids: torch.Tensor,
    rows: torch.Tensor,
    tau: float,
    alpha: float,
    beta: float,
    margin_threshold: float,
    margin_fallback: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Turn per-domain scores into ``(selected_tasks, selected_probs, reported_scores)`` for one selector."""
    if selector == "soft_mahalanobis_proto":
        weights = torch.softmax(-float(tau) * raw_scores, dim=1)
        return task_ids[torch.argmax(weights, dim=1)], torch.sum(weights.unsqueeze(-1) * probs, dim=1), scores
    if selector == "hybrid_maha_entropy":
        hybrid_scores = float(alpha) * z_scores + float(beta) * entropy
        indices = torch.argmin(hybrid_scores, dim=1)
        return task_ids[indices], probs[rows, indices], hybrid_scores
    if selector == "maha_margin_fallback":
        sorted_scores, sorted_indices = torch.sort(raw_scores, dim=1)
        best = sorted_indices[:, 0]
        # With a single candidate the margin is infinite, i.e. always confident.
        second = sorted_scores[:, 1] if raw_scores.shape[1] > 1 else sorted_scores[:, 0] + float("inf")
        confident = (second - sorted_scores[:, 0]) >= float(margin_threshold)
        maha_probs = probs[rows, best]
        maha_tasks = task_ids[best]
        if margin_fallback == "entropy":
            fallback_idx = torch.argmin(entropy, dim=1)
            fallback_probs = probs[rows, fallback_idx]
            fallback_tasks = task_ids[fallback_idx]
        else:
            weights = torch.softmax(-float(tau) * raw_scores, dim=1)
            fallback_probs = torch.sum(weights.unsqueeze(-1) * probs, dim=1)
            fallback_tasks = task_ids[torch.argmax(weights, dim=1)]
        return (
            torch.where(confident, maha_tasks, fallback_tasks),
            torch.where(confident.unsqueeze(1), maha_probs, fallback_probs),
            raw_scores,
        )
    # mahalanobis_proto / gaussian_nll_proto / mahalanobis_zscore_proto: nearest by their score.
    indices = torch.argmin(scores, dim=1)
    return task_ids[indices], probs[rows, indices], scores


def predict_with_selector(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    selector: str = "mahalanobis_proto",
    candidate_domains: Optional[Sequence[str]] = None,
    d1_fallback: str = "none",
    tau: float = 1.0,
    alpha: float = 1.0,
    beta: float = 1.0,
    margin_threshold: float = 0.0,
    margin_fallback: str = "entropy",
    d1_z_threshold: float = 2.0,
    d1_entropy_margin: float = 0.05,
    feature_type: Optional[str] = None,
    class_conditional_strategy: str = "top1",
    class_distance_lambda: float = 1.0,
    return_details: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, object]] | torch.Tensor:
    """Route each input to a task head and return that head's class probabilities.

    Selectors (``hard_mahalanobis`` is an alias of ``mahalanobis_proto``):

    * ``entropy``: head with the lowest predictive entropy among tasks 0-2.
    * ``mahalanobis_proto``: nearest prototype by squared Mahalanobis distance.
    * ``gaussian_nll_proto``: nearest by ``0.5 * (distance + logdet)``.
    * ``mahalanobis_zscore_proto``: nearest by calibrated z-scored distance.
    * ``soft_mahalanobis_proto``: mixture of heads weighted by ``softmax(-tau * distance)``.
    * ``hybrid_maha_entropy``: argmin of ``alpha * zscore + beta * entropy``.
    * ``maha_margin_fallback``: nearest prototype when the best/second-best gap
      is ``>= margin_threshold``; otherwise the *margin_fallback* rule
      (``"entropy"`` or ``"soft"``).

    When a candidate prototype is class-conditional, the ``"top1"`` strategy
    measures distances to the class Gaussian of the class predicted by the
    averaged candidate heads. The ``"joint"`` strategy instead picks the
    (domain, class) pair maximising
    ``logit - class_distance_lambda * class distance``. That path returns
    one-hot probabilities and does not use the selector-specific scoring; its
    reported ``scores`` are the raw Mahalanobis distances.

    With no candidate prototypes the entropy selector is used.
    *d1_fallback* (``"none"``, ``"entropy"``, ``"conservative"``) may then
    re-route rows to the D1 head.

    Returns:
        The probabilities, or ``(probs, selected_tasks, details)`` when
        *return_details* is true.

    Raises:
        ValueError: for unknown selector, strategy or fallback names.
    """
    selector = "mahalanobis_proto" if selector == "hard_mahalanobis" else selector
    if selector not in SELECTOR_MODES:
        raise ValueError(f"Unknown selector {selector!r}; expected one of {SELECTOR_MODES}")
    if class_conditional_strategy not in ("top1", "joint"):
        raise ValueError("--class_conditional_strategy must be top1 or joint")
    if margin_fallback not in ("entropy", "soft"):
        raise ValueError("--margin_fallback must be entropy or soft")

    if selector == "entropy":
        return _predict_by_entropy_only(model, inputs, selector, return_details)

    domains = candidate_domain_list(prototypes, candidate_domains)
    if not domains:
        # Nothing to compare against: plain entropy routing, without D1 fallback.
        return _predict_by_entropy_only(model, inputs, "entropy", return_details)
    tasks = [DOMAIN_TO_TASK[domain] for domain in domains]
    device = inputs.device
    task_ids = _task_ids_for_domains(domains, device)
    feature_type = canonical_feature_type(feature_type or prototype_feature_type(prototypes))
    stats = extract_logmel_stat(model, inputs, feature_type=feature_type)
    probs, entropy, logits = task_probabilities(model, inputs, tasks)
    class_conditional = any(
        prototypes[domain].get("prototype_mode") == "class_conditional" and "class_means" in prototypes[domain]
        for domain in domains
    )

    if class_conditional and class_conditional_strategy == "joint":
        selected_tasks, selected_probs = _joint_class_selection(
            stats,
            prototypes,
            domains,
            logits,
            task_ids,
            class_distance_lambda,
            batch_size=inputs.shape[0],
            device=device,
        )
        raw_scores, _ = prototype_scores(stats, prototypes, domains, score_type="mahalanobis")
        z_scores, _ = prototype_scores(stats, prototypes, domains, score_type="zscore")
        # No selector-specific score exists on this path; report the raw distances.
        scores = raw_scores
    else:
        class_ids = _mean_candidate_class(probs) if class_conditional else None
        score_kwargs = {"class_ids": class_ids, "class_conditional": class_conditional}
        selector_score_type = {
            "gaussian_nll_proto": "gaussian_nll",
            "mahalanobis_zscore_proto": "zscore",
        }.get(selector, "mahalanobis")
        scores, _ = prototype_scores(stats, prototypes, domains, score_type=selector_score_type, **score_kwargs)
        raw_scores, _ = prototype_scores(stats, prototypes, domains, score_type="mahalanobis", **score_kwargs)
        z_scores, _ = prototype_scores(stats, prototypes, domains, score_type="zscore", **score_kwargs)
        selected_tasks, selected_probs, scores = _apply_score_selector(
            selector,
            scores,
            raw_scores,
            z_scores,
            probs,
            entropy,
            task_ids,
            rows=torch.arange(inputs.shape[0], device=device),
            tau=tau,
            alpha=alpha,
            beta=beta,
            margin_threshold=margin_threshold,
            margin_fallback=margin_fallback,
        )

    all_entropy = None
    if d1_fallback == "none":
        d1_mask = torch.zeros_like(selected_tasks, dtype=torch.bool)
    else:
        all_tasks = [0, 1, 2]
        all_probs, all_entropy, _ = task_probabilities(model, inputs, all_tasks)
        selected_tasks, selected_probs, d1_mask = _apply_d1_fallback(
            selected_tasks,
            selected_probs,
            all_tasks,
            all_probs,
            all_entropy,
            z_scores,
            d1_fallback=d1_fallback,
            d1_z_threshold=d1_z_threshold,
            d1_entropy_margin=d1_entropy_margin,
        )
    if not return_details:
        return selected_probs

    details: Dict[str, object] = {
        "selector": selector,
        "domains": domains,
        "tasks": tasks,
        "scores": scores.detach().cpu(),
        "raw_mahalanobis": raw_scores.detach().cpu(),
        "zscore": z_scores.detach().cpu(),
        "entropy": entropy.detach().cpu(),
        "selected_tasks": selected_tasks.detach().cpu(),
        "selected_domains": [TASK_TO_DOMAIN[int(task)] for task in selected_tasks.detach().cpu().tolist()],
        "d1_fallback_mask": d1_mask.detach().cpu(),
    }
    if all_entropy is not None:
        details["all_entropy"] = all_entropy.detach().cpu()
    return selected_probs, selected_tasks, details
|
|
|
|
def select_tasks_with_d1_fallback(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    d1_fallback: str = "entropy",
) -> torch.Tensor:
    """Task ids from hard Mahalanobis routing, then the given D1 fallback rule."""
    outcome = predict_with_selector(
        model,
        inputs,
        prototypes,
        selector="mahalanobis_proto",
        d1_fallback=d1_fallback,
        return_details=True,
    )
    return outcome[1]
|
|
|
|
def predict_with_selected_tasks(model: torch.nn.Module, inputs: torch.Tensor, task_ids: torch.Tensor) -> torch.Tensor:
    """Run every input row through its own task head and concatenate the outputs.

    Rows are evaluated one at a time as batches of one, via ``model(row, task_id)``.
    """
    per_row = [
        model(inputs[index:index + 1], int(task))
        for index, task in enumerate(task_ids.view(-1).tolist())
    ]
    return torch.cat(per_row, dim=0)
|
|
|
|
def manifest_json(data: Mapping[str, object]) -> str:
    """Serialise *data* as indented, key-sorted JSON that keeps non-ASCII characters as-is."""
    return json.dumps(data, ensure_ascii=False, sort_keys=True, indent=2)
|
|