#!/usr/bin/env python3
"""Utilities for strict log-mel statistic prototype routing."""

from __future__ import annotations

import json
import math
import os
import random
import sys
from pathlib import Path
from typing import Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple

import numpy as np
import torch

# Feature type used whenever a caller does not name one.
FEATURE_TYPE = "mean_std"
# Older names that are mapped onto a current feature type.
LEGACY_FEATURE_ALIASES = {
    "logmel_mean_std_128": "mean_std",
    "mean_std_128": "mean_std",
}
FEATURE_TYPES = (
    "mean_only",
    "std_only",
    "mean_std",
    "mean_std_delta",
    "mean_std_band_energy",
    "mean_std_extended",
)
PROTOTYPE_MODES = ("global", "class_balanced", "class_conditional")
COVARIANCE_MODES = ("full_cov_shrinkage", "diag_cov")
SELECTOR_MODES = (
    "entropy",
    "mahalanobis_proto",
    "hard_mahalanobis",
    "gaussian_nll_proto",
    "mahalanobis_zscore_proto",
    "soft_mahalanobis_proto",
    "hybrid_maha_entropy",
    "maha_margin_fallback",
)

CLASS_LABELS = [
    "alarm",
    "baby_cry",
    "dog_bark",
    "engine",
    "fire",
    "footsteps",
    "knocking",
    "telephone_ringing",
    "piano",
    "speech",
]
LABEL_TO_INDEX = {label: index for index, label in enumerate(CLASS_LABELS)}
INDEX_TO_LABEL = {index: label for label, index in LABEL_TO_INDEX.items()}
DOMAIN_TO_TASK = {"D1": 0, "D2": 1, "D3": 2}
TASK_TO_DOMAIN = {task: domain for domain, task in DOMAIN_TO_TASK.items()}


def repo_root() -> Path:
    """Return the repository root.

    The ``TASK7_REPO`` environment variable wins when set; otherwise the
    parent of the directory containing this file is used.
    """
    return Path(os.environ.get("TASK7_REPO", Path(__file__).resolve().parents[1])).resolve()


def setup_repo_imports(root: Optional[Path] = None) -> Path:
    """Put ``<root>/baseline`` and ``<root>/utils`` at the front of ``sys.path``.

    Directories already present on ``sys.path`` are left where they are.
    Returns the resolved root that was used.
    """
    root = (root or repo_root()).resolve()
    for subdir in ("baseline", "utils"):
        path = str(root / subdir)
        if path not in sys.path:
            sys.path.insert(0, path)
    return root


def set_reproducible_seed(seed: int) -> None:
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def canonical_feature_type(feature_type: Optional[str]) -> str:
    """Resolve a feature type name (or legacy alias) to a member of FEATURE_TYPES.

    ``None`` or an empty string selects ``FEATURE_TYPE``.

    Raises:
        ValueError: if the resolved name is not a known feature type.
    """
    value = (feature_type or FEATURE_TYPE).strip()
    value = LEGACY_FEATURE_ALIASES.get(value, value)
    if value not in FEATURE_TYPES:
        raise ValueError(f"Unknown feature_type {feature_type!r}; expected one of {FEATURE_TYPES}")
    return value


def resolve_audio_path(filename: str, data_root: str | Path) -> Path:
    """Return ``filename`` unchanged if absolute, else joined onto ``data_root``."""
    path = Path(str(filename))
    if path.is_absolute():
        return path
    return Path(data_root) / path


def pad_sequence(waveform: np.ndarray, min_len: int) -> np.ndarray:
    """Zero-pad ``waveform`` at the end so its first axis has at least ``min_len`` items.

    Longer inputs are returned unchanged (no truncation).
    """
    if waveform.shape[0] < min_len:
        return np.concatenate((waveform, np.zeros(min_len - waveform.shape[0], dtype=waveform.dtype)))
    return waveform


def load_audio_waveform(
    filename: str,
    data_root: str | Path,
    sample_rate: int = 32000,
    clip_samples: int = 128000,
) -> np.ndarray:
    """Load a mono waveform with librosa and pad it to ``clip_samples``.

    The result is float32. Clips longer than ``clip_samples`` are not cut.
    """
    # Imported lazily so the module can be used without librosa installed.
    import librosa

    path = resolve_audio_path(filename, data_root)
    waveform, _ = librosa.core.load(str(path), sr=sample_rate, mono=True)
    return pad_sequence(waveform, clip_samples).astype(np.float32, copy=False)


def collate_waveforms(batch: Sequence[Tuple[np.ndarray, str]]) -> Tuple[torch.Tensor, List[str]]:
    """Collate ``(waveform, filename)`` pairs into a padded tensor and a name list.

    Every waveform is zero-padded at the end to the longest one in the batch.
    """
    waveforms, filenames = zip(*batch)
    max_len = max(wave.shape[0] for wave in waveforms)
    padded = []
    for wave in waveforms:
        if wave.shape[0] < max_len:
            wave = np.concatenate((wave, np.zeros(max_len - wave.shape[0], dtype=wave.dtype)))
        padded.append(wave)
    return torch.from_numpy(np.stack(padded, axis=0)), list(filenames)


def make_tta_crops(waveform: torch.Tensor, tta_crops: int, crop_samples: int) -> torch.Tensor:
    """Return 1, 3, or 4 crops stacked along a new leading axis.

    * ``1``: the whole flattened waveform, uncropped.
    * ``3``: front, middle and back windows of ``crop_samples`` samples.
    * ``4``: those three windows plus the whole waveform; shorter rows are
      zero-padded to the longest one so they can be stacked.

    Raises:
        ValueError: if ``tta_crops`` is not 1, 3 or 4.
    """
    if tta_crops not in (1, 3, 4):
        raise ValueError("--tta_crops must be one of 1, 3, 4")
    wave = waveform.detach().float().view(-1)
    if tta_crops == 1:
        return wave.unsqueeze(0)
    if wave.numel() < crop_samples:
        pad = torch.zeros(crop_samples - wave.numel(), dtype=wave.dtype, device=wave.device)
        wave = torch.cat([wave, pad])
    if wave.numel() <= crop_samples:
        # Only one window fits, so all three crops are the same.
        front = wave[:crop_samples]
        middle = front
        back = front
    else:
        last = wave.numel() - crop_samples
        front = wave[:crop_samples]
        middle_start = last // 2
        middle = wave[middle_start:middle_start + crop_samples]
        back = wave[last:last + crop_samples]
    crops = [front, middle, back]
    if tta_crops == 4:
        crops.append(wave)
    max_len = max(crop.numel() for crop in crops)
    padded = []
    for crop in crops:
        if crop.numel() < max_len:
            pad = torch.zeros(max_len - crop.numel(), dtype=crop.dtype, device=crop.device)
            crop = torch.cat([crop, pad])
        padded.append(crop)
    return torch.stack(padded, dim=0)


def _safe_std(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Standard deviation along ``dim``; biased when that axis has one element."""
    return torch.std(x, dim=dim, unbiased=x.shape[dim] > 1)


def extract_logmel_tensor(model: torch.nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Run the model's spectrogram and log-mel extractors and drop axis 1."""
    x = model.spectrogram_extractor(inputs)
    return model.logmel_extractor(x).squeeze(1)


def _band_ratios(logmel: torch.Tensor) -> torch.Tensor:
    """Softmax over the log-sum-exp energy of three fixed bands of the last axis.

    The bands are bins 0-20, 21-42 and 43-63 of a rank-3 input, pooled over
    axes 1 and 2. Returns one row of three ratios per batch item.
    """
    band_slices = (slice(0, 21), slice(21, 43), slice(43, 64))
    band_log_energy = [
        torch.logsumexp(logmel[:, :, band_slice], dim=(1, 2))
        for band_slice in band_slices
    ]
    stacked = torch.stack(band_log_energy, dim=1)
    return torch.softmax(stacked, dim=1)


def _spectral_summary(logmel: torch.Tensor) -> torch.Tensor:
    """Mean and std (over axis 1) of centroid, bandwidth, rolloff and flatness.

    The four descriptors are computed per position along the last axis, using
    a softmax over that axis as bin weights and bin positions scaled to
    [0, 1]. Returns eight columns per batch item, ordered
    ``[centroid mean, centroid std, bandwidth mean, ...]``.
    """
    eps = 1e-8
    bins = torch.linspace(0.0, 1.0, logmel.shape[-1], device=logmel.device, dtype=logmel.dtype)
    weights = torch.softmax(logmel, dim=-1)
    centroid = torch.sum(weights * bins.view(1, 1, -1), dim=-1)
    bandwidth = torch.sqrt(
        torch.sum(weights * (bins.view(1, 1, -1) - centroid.unsqueeze(-1)).pow(2), dim=-1) + eps
    )
    cdf = torch.cumsum(weights, dim=-1)
    # argmax over a 0/1 tensor returns the first bin whose cumulative weight reaches 0.85.
    rolloff_idx = torch.argmax((cdf >= 0.85).to(torch.int64), dim=-1).float()
    rolloff = rolloff_idx / max(logmel.shape[-1] - 1, 1)
    # Shift by the per-frame maximum before exponentiating to keep values bounded.
    shifted = logmel - torch.amax(logmel, dim=-1, keepdim=True)
    linear = torch.exp(torch.clamp(shifted, min=-80.0, max=20.0)) + eps
    flatness = torch.exp(torch.mean(torch.log(linear), dim=-1)) / (torch.mean(linear, dim=-1) + eps)
    summary = []
    for value in (centroid, bandwidth, rolloff, flatness):
        summary.append(torch.mean(value, dim=1))
        summary.append(_safe_std(value, dim=1))
    return torch.stack(summary, dim=1)


def extract_logmel_stat(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    feature_type: Optional[str] = None,
) -> torch.Tensor:
    """Compute the per-clip statistic vector selected by ``feature_type``.

    Statistics are taken over axis 1 of the log-mel tensor (presumably the
    time axis -- confirm against the model's extractors). Depending on the
    feature type the mean and std are optionally followed by first-difference
    statistics, band ratios and the spectral summary, concatenated on axis 1.

    Raises:
        ValueError: if ``feature_type`` is unknown.
    """
    feature_type = canonical_feature_type(feature_type)
    logmel = extract_logmel_tensor(model, inputs)
    mean = torch.mean(logmel, dim=1)
    std = _safe_std(logmel, dim=1)
    if feature_type == "mean_only":
        return mean
    if feature_type == "std_only":
        return std
    parts = [mean, std]
    if feature_type in ("mean_std_delta", "mean_std_extended"):
        if logmel.shape[1] > 1:
            delta = logmel[:, 1:, :] - logmel[:, :-1, :]
        else:
            # A single frame has no first difference; use zeros of the same shape.
            delta = torch.zeros_like(logmel)
        parts.extend([torch.mean(delta, dim=1), _safe_std(delta, dim=1)])
    if feature_type in ("mean_std_band_energy", "mean_std_extended"):
        parts.append(_band_ratios(logmel))
    if feature_type == "mean_std_extended":
        parts.append(_spectral_summary(logmel))
    return torch.cat(parts, dim=1)


def _logdet_psd(cov: torch.Tensor, eps: float) -> torch.Tensor:
    """Log-determinant of ``cov``, falling back to clamped eigenvalues.

    When ``slogdet`` does not report a positive sign, the eigenvalues are
    clamped to at least ``eps`` before their logs are summed.
    """
    sign, logabsdet = torch.linalg.slogdet(cov)
    if bool(sign > 0):
        return logabsdet
    eigvals = torch.linalg.eigvalsh(cov)
    return torch.sum(torch.log(torch.clamp(eigvals, min=eps)))


def _fit_covariance(
    stats: torch.Tensor,
    center: torch.Tensor,
    covariance_mode: str,
    shrinkage: float,
    eps: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fit a regularised covariance around ``center``.

    Returns ``(cov, cov_inv, logdet)``. ``diag_cov`` keeps only the clamped
    diagonal; ``full_cov_shrinkage`` blends the sample covariance with its
    diagonal by ``shrinkage`` and adds ``eps`` on the diagonal. A single
    sample yields an identity matrix before regularisation.

    Raises:
        ValueError: if ``covariance_mode`` is unknown.
    """
    if covariance_mode not in COVARIANCE_MODES:
        raise ValueError(f"Unknown covariance_mode {covariance_mode!r}; expected one of {COVARIANCE_MODES}")
    centered = stats - center.unsqueeze(0)
    if stats.shape[0] > 1:
        cov = centered.t().matmul(centered) / (stats.shape[0] - 1)
    else:
        # Built on the input's device so the later arithmetic cannot mix devices.
        cov = torch.eye(stats.shape[1], dtype=stats.dtype, device=stats.device)
    diag_values = torch.clamp(torch.diag(cov), min=eps)
    if covariance_mode == "diag_cov":
        cov = torch.diag(diag_values + eps)
    else:
        diag = torch.diag(diag_values)
        identity = torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)
        cov = (1.0 - shrinkage) * cov + shrinkage * diag + eps * identity
    cov_inv = torch.linalg.pinv(cov)
    logdet = _logdet_psd(cov, eps)
    return cov, cov_inv, logdet


def _mahalanobis(stats: torch.Tensor, mean: torch.Tensor, cov_inv: torch.Tensor) -> torch.Tensor:
    """Squared Mahalanobis distance of every row of ``stats`` to ``mean``."""
    diff = stats - mean.unsqueeze(0)
    return torch.sum((diff @ cov_inv) * diff, dim=1)


def _distance_stats(distances: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the mean and std of ``distances``, with the std clamped to >= 1e-6."""
    mean = torch.mean(distances)
    std = torch.std(distances, unbiased=distances.numel() > 1)
    return mean, torch.clamp(std, min=1e-6)


def _extract_labels_from_batch(batch) -> Optional[torch.Tensor]:
    """Return integer class ids from ``batch[1]``, or ``None`` if the batch has no target.

    Targets with more than one dimension are reduced with argmax over the
    last axis (e.g. one-hot or score vectors).
    """
    if len(batch) < 2:
        return None
    target = batch[1]
    if not torch.is_tensor(target):
        target = torch.as_tensor(target)
    if target.ndim > 1:
        return torch.argmax(target.float(), dim=-1).long()
    return target.long()


def _class_ids(labels: Optional[torch.Tensor]) -> List[int]:
    """Return the sorted distinct label values, or an empty list for ``None``."""
    if labels is None:
        return []
    return sorted(int(value) for value in torch.unique(labels.detach().cpu()).tolist())


def make_prototype_entry(
    stats: torch.Tensor,
    domain: str,
    source_split: str,
    shrinkage: float = 0.15,
    diag_weight: Optional[float] = None,
    eps: float = 1e-3,
    labels: Optional[torch.Tensor] = None,
    feature_type: str = FEATURE_TYPE,
    prototype_mode: str = "global",
    covariance_mode: str = "full_cov_shrinkage",
) -> Dict[str, object]:
    """Build the prototype dictionary for one domain from ``(n, d)`` statistics.

    The entry always holds the global mean, covariance, inverse, log-det and
    the mean/std of the training distances. When ``labels`` is given it also
    holds the same quantities per class, in tensors with one slot per entry
    of ``CLASS_LABELS``, plus a ``class_available`` mask.

    Args:
        stats: feature matrix, one row per sample.
        domain: key of ``DOMAIN_TO_TASK``.
        source_split: stored verbatim in the entry.
        shrinkage: diagonal shrinkage weight for ``full_cov_shrinkage``.
        diag_weight: when ``>= 1.0`` forces ``diag_cov``; otherwise ignored.
        eps: diagonal floor / ridge passed to the covariance fit.
        labels: optional 1-D class ids aligned with the rows of ``stats``.
        feature_type: feature type name or legacy alias.
        prototype_mode: ``class_balanced`` centres on the mean of class
            means; the other modes centre on the sample mean.
        covariance_mode: member of ``COVARIANCE_MODES``.

    Raises:
        ValueError: on a non-2-D ``stats``, unknown mode or feature type,
            labels that do not match ``stats`` row-for-row, or a class id
            outside ``range(len(CLASS_LABELS))``.
        KeyError: if ``domain`` is not in ``DOMAIN_TO_TASK``.
    """
    if stats.ndim != 2:
        raise ValueError(f"Expected stats with shape (n, d), got {tuple(stats.shape)}")
    if prototype_mode not in PROTOTYPE_MODES:
        raise ValueError(f"Unknown prototype_mode {prototype_mode!r}; expected one of {PROTOTYPE_MODES}")
    if diag_weight is not None:
        covariance_mode = "diag_cov" if diag_weight >= 1.0 else covariance_mode
    stats = stats.detach().cpu().float()
    feature_type = canonical_feature_type(feature_type)
    labels_cpu = labels.detach().cpu().long() if labels is not None else None
    if labels_cpu is not None and (labels_cpu.ndim != 1 or labels_cpu.shape[0] != stats.shape[0]):
        raise ValueError(
            f"Expected labels with shape ({stats.shape[0]},), got {tuple(labels_cpu.shape)}"
        )
    raw_mean = torch.mean(stats, dim=0)
    raw_std = _safe_std(stats, dim=0)

    class_means: Dict[int, torch.Tensor] = {}
    class_counts: Dict[int, int] = {}
    if labels_cpu is not None:
        for class_id in _class_ids(labels_cpu):
            # A negative id would silently wrap to the last class slot below.
            if not 0 <= class_id < len(CLASS_LABELS):
                raise ValueError(
                    f"Class id {class_id} is outside the valid range 0..{len(CLASS_LABELS) - 1}"
                )
            mask = labels_cpu == class_id
            class_means[class_id] = torch.mean(stats[mask], dim=0)
            class_counts[class_id] = int(mask.sum())

    if prototype_mode == "class_balanced" and class_means:
        mean = torch.stack([class_means[key] for key in sorted(class_means)], dim=0).mean(dim=0)
    else:
        mean = raw_mean
    cov, cov_inv, logdet = _fit_covariance(stats, mean, covariance_mode, shrinkage, eps)
    distances = _mahalanobis(stats, mean, cov_inv)
    dist_mean, dist_std = _distance_stats(distances)
    entry: Dict[str, object] = {
        "domain": domain,
        "task_index": DOMAIN_TO_TASK[domain],
        "feature_type": feature_type,
        "prototype_mode": prototype_mode,
        "covariance_mode": covariance_mode,
        "mean": mean,
        "raw_mean": raw_mean,
        "std": raw_std,
        "cov": cov,
        "cov_inv": cov_inv,
        "logdet": logdet,
        "dist_mean": dist_mean,
        "dist_std": dist_std,
        "n": int(stats.shape[0]),
        "dim": int(stats.shape[1]),
        "source_split": source_split,
        "cov_shrinkage": float(shrinkage),
        "cov_diag_weight": float(1.0 if covariance_mode == "diag_cov" else shrinkage),
        "cov_eps": float(eps),
    }
    if class_means:
        num_classes = len(CLASS_LABELS)
        dim = stats.shape[1]
        counts = torch.zeros(num_classes, dtype=torch.long)
        means = torch.zeros(num_classes, dim, dtype=torch.float32)
        covs = torch.zeros(num_classes, dim, dim, dtype=torch.float32)
        cov_invs = torch.zeros_like(covs)
        logdets = torch.zeros(num_classes, dtype=torch.float32)
        dist_means = torch.zeros(num_classes, dtype=torch.float32)
        dist_stds = torch.ones(num_classes, dtype=torch.float32)
        available = torch.zeros(num_classes, dtype=torch.bool)
        for class_id, class_mean in class_means.items():
            class_stats = stats[labels_cpu == class_id]
            class_cov, class_cov_inv, class_logdet = _fit_covariance(
                class_stats,
                class_mean,
                covariance_mode,
                shrinkage,
                eps,
            )
            class_dist = _mahalanobis(class_stats, class_mean, class_cov_inv)
            class_dist_mean, class_dist_std = _distance_stats(class_dist)
            counts[class_id] = class_counts[class_id]
            means[class_id] = class_mean
            covs[class_id] = class_cov
            cov_invs[class_id] = class_cov_inv
            logdets[class_id] = class_logdet
            dist_means[class_id] = class_dist_mean
            dist_stds[class_id] = class_dist_std
            available[class_id] = True
        entry.update({
            "class_counts": counts,
            "class_means": means,
            "class_covs": covs,
            "class_cov_invs": cov_invs,
            "class_logdets": logdets,
            "class_dist_means": dist_means,
            "class_dist_stds": dist_stds,
            "class_available": available,
        })
    return entry


def compute_domain_prototype(
    model: torch.nn.Module,
    loader: Iterable,
    domain: str,
    device: str | torch.device,
    source_split: str,
    shrinkage: float = 0.15,
    diag_weight: Optional[float] = None,
    eps: float = 1e-3,
    feature_type: str = FEATURE_TYPE,
    prototype_mode: str = "global",
    covariance_mode: str = "full_cov_shrinkage",
) -> Dict[str, object]:
    """Extract statistics for every batch of ``loader`` and build a prototype.

    The model is put in eval mode. ``batch[0]`` is treated as audio and
    ``batch[1]``, when present, as the target. The remaining arguments are
    forwarded to ``make_prototype_entry``.

    Raises:
        RuntimeError: if the loader yields no batches.
    """
    model.eval()
    stats = []
    labels = []
    for batch in loader:
        batch_labels = _extract_labels_from_batch(batch)
        audio = batch[0].float().to(device)
        with torch.no_grad():
            stats.append(extract_logmel_stat(model, audio, feature_type=feature_type).detach().cpu())
        if batch_labels is not None:
            labels.append(batch_labels.detach().cpu())
    if not stats:
        raise RuntimeError(f"No samples available for {domain} prototype")
    label_tensor = torch.cat(labels, dim=0) if labels else None
    return make_prototype_entry(
        torch.cat(stats, dim=0),
        domain=domain,
        source_split=source_split,
        shrinkage=shrinkage,
        diag_weight=diag_weight,
        eps=eps,
        labels=label_tensor,
        feature_type=feature_type,
        prototype_mode=prototype_mode,
        covariance_mode=covariance_mode,
    )


def normalize_domain_key(key: object) -> Optional[str]:
    """Map a domain name (any case), task-index string or int to ``"D1"``..``"D3"``.

    Returns ``None`` for anything that cannot be mapped.
    """
    if isinstance(key, str):
        upper = key.upper()
        if upper in DOMAIN_TO_TASK:
            return upper
        # isdecimal() only accepts characters int() can parse; isdigit() also
        # accepts e.g. superscript digits, which make int() raise ValueError.
        if upper.isdecimal():
            return TASK_TO_DOMAIN.get(int(upper))
        return None
    if isinstance(key, int):
        return TASK_TO_DOMAIN.get(key)
    return None


def normalize_prototypes(obj: Mapping) -> Dict[str, Dict[str, object]]:
    """Return a ``{domain: entry}`` dict with canonical keys, types and defaults.

    Accepts either a flat mapping or one nested under ``"domains"``. Keys
    that do not name a domain and values that are not mappings are skipped.
    Known tensor fields are converted to tensors and, except
    ``class_available``, cast to float.

    Raises:
        ValueError: if an entry lacks ``mean`` or ``cov_inv``, or names an
            unknown feature type.
    """
    out: Dict[str, Dict[str, object]] = {}
    if "domains" in obj and isinstance(obj["domains"], Mapping):
        obj = obj["domains"]
    tensor_keys = (
        "mean",
        "raw_mean",
        "std",
        "cov",
        "cov_inv",
        "logdet",
        "dist_mean",
        "dist_std",
        "class_counts",
        "class_means",
        "class_covs",
        "class_cov_invs",
        "class_logdets",
        "class_dist_means",
        "class_dist_stds",
        "class_available",
    )
    for key, value in obj.items():
        if key in {"metadata", "meta", "domains"}:
            continue
        domain = normalize_domain_key(key)
        if domain is None or not isinstance(value, Mapping):
            continue
        entry = dict(value)
        entry["domain"] = domain
        entry["task_index"] = int(entry.get("task_index", DOMAIN_TO_TASK[domain]))
        entry["feature_type"] = canonical_feature_type(str(entry.get("feature_type", FEATURE_TYPE)))
        entry["prototype_mode"] = entry.get("prototype_mode", "global")
        entry["covariance_mode"] = entry.get("covariance_mode", "full_cov_shrinkage")
        entry["n"] = int(entry.get("n", entry.get("count", -1)))
        for tensor_key in tensor_keys:
            if tensor_key not in entry:
                continue
            tensor = entry[tensor_key]
            if not torch.is_tensor(tensor):
                tensor = torch.as_tensor(tensor)
            # The availability mask keeps its own dtype; everything else is float.
            if tensor_key != "class_available":
                tensor = tensor.float()
            entry[tensor_key] = tensor
        if "mean" not in entry or "cov_inv" not in entry:
            raise ValueError(f"Prototype for {domain} must contain mean and cov_inv")
        if "logdet" not in entry and "cov" in entry:
            entry["logdet"] = _logdet_psd(entry["cov"].float(), float(entry.get("cov_eps", 1e-3)))
        if "dist_mean" not in entry:
            entry["dist_mean"] = torch.tensor(0.0)
        if "dist_std" not in entry:
            entry["dist_std"] = torch.tensor(1.0)
        entry["dim"] = int(entry.get("dim", entry["mean"].numel()))
        out[domain] = entry
    return out


def load_prototypes(path: str | Path, map_location: str | torch.device = "cpu") -> Dict[str, Dict[str, object]]:
    """Load a prototype file with ``torch.load`` and normalise it.

    Raises:
        TypeError: if the file does not contain a mapping.
    """
    # SECURITY: torch.load unpickles the file, so only load prototype files
    # from a trusted source.
    obj = torch.load(str(path), map_location=map_location)
    if not isinstance(obj, Mapping):
        raise TypeError(f"Prototype file must contain a mapping, got {type(obj)!r}")
    return normalize_prototypes(obj)


def save_prototypes(
    prototypes: Mapping[str, Mapping[str, object]],
    path: str | Path,
    metadata: Optional[Mapping[str, object]] = None,
) -> None:
    """Normalise ``prototypes`` and write them with ``torch.save``.

    The payload is ``{"meta": ..., "domains": ...}``. ``metadata`` is merged
    over the default meta block, and missing parent directories are created.
    """
    normalized = normalize_prototypes(prototypes)
    if normalized:
        feature_type = next(iter(normalized.values())).get("feature_type", FEATURE_TYPE)
    else:
        feature_type = FEATURE_TYPE
    payload: MutableMapping[str, object] = {
        "meta": {
            "feature_type": feature_type,
            "sample_rate": 32000,
            "n_mels": 64,
            "note": "D1 prototype is not built unless a D1 train split is available.",
        },
        "domains": dict(normalized),
    }
    if metadata:
        payload["meta"].update(dict(metadata))
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, str(path))


def prototype_feature_type(prototypes: Mapping[str, Mapping[str, object]], fallback: str = FEATURE_TYPE) -> str:
    """Return the canonical feature type of the first entry, or of ``fallback``."""
    for entry in prototypes.values():
        return canonical_feature_type(str(entry.get("feature_type", fallback)))
    return canonical_feature_type(fallback)


def candidate_domain_list(
    prototypes: Mapping[str, Mapping[str, object]],
    candidate_domains: Optional[Sequence[str]] = None,
) -> List[str]:
    """Return the domains to score, restricted to those that have a prototype.

    An explicit ``candidate_domains`` keeps its own order; otherwise the
    order is D1, D2, D3.
    """
    if candidate_domains is not None:
        return [domain for domain in candidate_domains if domain in prototypes]
    return [domain for domain in ("D1", "D2", "D3") if domain in prototypes]


def _global_entry_distance(
    stats: torch.Tensor,
    entry: Mapping[str, object],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Distance to the domain-wide prototype plus its broadcast normalisers.

    Returns ``(distances, logdet, dist_mean, dist_std)``, each with one value
    per row of ``stats``. Missing normalisers default to 0, 0 and 1.
    """
    mean = entry["mean"].to(stats.device)
    cov_inv = entry["cov_inv"].to(stats.device)
    distances = _mahalanobis(stats, mean, cov_inv)
    logdet = entry.get("logdet", torch.tensor(0.0)).to(stats.device).expand_as(distances)
    dist_mean = entry.get("dist_mean", torch.tensor(0.0)).to(stats.device).expand_as(distances)
    dist_std = torch.clamp(
        entry.get("dist_std", torch.tensor(1.0)).to(stats.device), min=1e-6
    ).expand_as(distances)
    return distances, logdet, dist_mean, dist_std


def _class_available_mask(entry: Mapping[str, object], device: torch.device) -> Optional[torch.Tensor]:
    """Return the boolean per-class availability mask, or ``None``.

    ``None`` means no fallback is possible: either the entry has no
    ``class_available`` field or it lacks the global ``mean``/``cov_inv``.
    """
    available = entry.get("class_available")
    if available is None or "mean" not in entry or "cov_inv" not in entry:
        return None
    return torch.as_tensor(available).to(device).bool()


def _entry_distance(
    stats: torch.Tensor,
    entry: Mapping[str, object],
    class_ids: Optional[torch.Tensor] = None,
    class_conditional: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Distance of each sample to one domain entry, with its normalisers.

    Class-conditional scoring is used when requested, ``class_ids`` is given
    and the entry has per-class statistics; otherwise the global prototype is
    used. Samples whose class the entry marks as unavailable fall back to the
    global prototype: their per-class slot is all zeros, which would give a
    distance of exactly 0 and always win an argmin.
    """
    use_classes = (
        class_conditional
        and class_ids is not None
        and "class_means" in entry
        and "class_cov_invs" in entry
    )
    if not use_classes:
        return _global_entry_distance(stats, entry)

    class_ids = class_ids.to(stats.device).long()
    class_means = entry["class_means"].to(stats.device)
    class_cov_invs = entry["class_cov_invs"].to(stats.device)
    num_classes = class_means.shape[0]
    class_logdets = entry.get("class_logdets", torch.zeros(num_classes)).to(stats.device)
    class_dist_means = entry.get("class_dist_means", torch.zeros(num_classes)).to(stats.device)
    class_dist_stds = torch.clamp(
        entry.get("class_dist_stds", torch.ones(num_classes)).to(stats.device), min=1e-6
    )
    diff = stats - class_means[class_ids]
    distances = torch.einsum("bi,bij,bj->b", diff, class_cov_invs[class_ids], diff)
    logdet = class_logdets[class_ids]
    dist_mean = class_dist_means[class_ids]
    dist_std = class_dist_stds[class_ids]

    available = _class_available_mask(entry, stats.device)
    if available is not None:
        usable = available[class_ids]
        if not bool(usable.all()):
            g_dist, g_logdet, g_mean, g_std = _global_entry_distance(stats, entry)
            distances = torch.where(usable, distances, g_dist.to(distances.dtype))
            logdet = torch.where(usable, logdet, g_logdet.to(logdet.dtype))
            dist_mean = torch.where(usable, dist_mean, g_mean.to(dist_mean.dtype))
            dist_std = torch.where(usable, dist_std, g_std.to(dist_std.dtype))
    return distances, logdet, dist_mean, dist_std


def class_conditional_distance_matrix(stats: torch.Tensor, entry: Mapping[str, object]) -> torch.Tensor:
    """Squared Mahalanobis distance of every sample to every class prototype.

    Returns one row per sample and one column per class slot. Columns of
    classes the entry marks as unavailable hold the distance to the global
    prototype instead of the meaningless 0 their empty slot would produce.

    Raises:
        ValueError: if the entry has no per-class statistics.
    """
    if "class_means" not in entry or "class_cov_invs" not in entry:
        raise ValueError("class_conditional prototype requires class_means and class_cov_invs")
    class_means = entry["class_means"].to(stats.device)
    class_cov_invs = entry["class_cov_invs"].to(stats.device)
    diff = stats.unsqueeze(1) - class_means.unsqueeze(0)
    distances = torch.einsum("bci,cij,bcj->bc", diff, class_cov_invs, diff)
    available = _class_available_mask(entry, stats.device)
    if available is not None and not bool(available.all()):
        global_distances = _mahalanobis(
            stats, entry["mean"].to(stats.device), entry["cov_inv"].to(stats.device)
        ).to(distances.dtype)
        distances = torch.where(available.unsqueeze(0), distances, global_distances.unsqueeze(1))
    return distances


def prototype_scores(
    stats: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    candidate_domains: Optional[Sequence[str]] = None,
    score_type: str = "mahalanobis",
    class_ids: Optional[torch.Tensor] = None,
    class_conditional: bool = False,
) -> Tuple[torch.Tensor, List[str]]:
    """Score every sample against every candidate domain (lower is closer).

    ``score_type`` is ``"mahalanobis"`` (raw squared distance),
    ``"gaussian_nll"`` (``0.5 * distance + 0.5 * logdet``) or ``"zscore"``
    (distance standardised by the training distance mean/std).

    Returns:
        ``(scores, domains)``: one score column per returned domain.

    Raises:
        RuntimeError: if no candidate domain has a prototype.
        ValueError: if ``score_type`` is unknown.
    """
    domains = candidate_domain_list(prototypes, candidate_domains)
    if not domains:
        raise RuntimeError("No candidate domain prototypes available")
    scores = []
    for domain in domains:
        distances, logdet, dist_mean, dist_std = _entry_distance(
            stats,
            prototypes[domain],
            class_ids=class_ids,
            class_conditional=class_conditional,
        )
        if score_type == "mahalanobis":
            score = distances
        elif score_type == "gaussian_nll":
            score = 0.5 * distances + 0.5 * logdet
        elif score_type == "zscore":
            score = (distances - dist_mean) / dist_std
        else:
            raise ValueError(f"Unknown prototype score_type: {score_type}")
        scores.append(score)
    return torch.stack(scores, dim=1), domains


def prototype_distances(
    stats: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    candidate_domains: Optional[Sequence[str]] = None,
) -> Tuple[torch.Tensor, List[str]]:
    """Shortcut for ``prototype_scores`` with raw Mahalanobis distances."""
    return prototype_scores(stats, prototypes, candidate_domains, score_type="mahalanobis")


def select_tasks_by_mahalanobis(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    candidate_domains: Optional[Sequence[str]] = None,
) -> torch.Tensor:
    """Return, per sample, the task id of the nearest domain prototype."""
    feature_type = prototype_feature_type(prototypes)
    stats = extract_logmel_stat(model, inputs, feature_type=feature_type)
    distances, domains = prototype_distances(stats, prototypes, candidate_domains)
    indices = torch.argmin(distances, dim=1)
    task_ids = torch.tensor([DOMAIN_TO_TASK[d] for d in domains], device=inputs.device, dtype=torch.long)
    return task_ids[indices]


def task_probabilities(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    tasks: Sequence[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the model once per task and return probabilities, entropy and logits.

    The model is called as ``model(inputs, task_id)``. Results are stacked on
    a new axis 1 in the order of ``tasks``; the entropy is summed over the
    last axis.

    The entropy uses ``log_softmax`` and skips zero-probability terms. Adding
    a tiny epsilon before ``log`` does not work here: ``sys.float_info.min``
    rounds to 0 in float32, so a fully saturated softmax produced
    ``0 * log(0) = NaN`` and poisoned every later argmin.
    """
    logits = []
    probs = []
    log_probs = []
    for task_id in tasks:
        task_logits = model(inputs, int(task_id))
        logits.append(task_logits)
        probs.append(torch.softmax(task_logits, dim=1))
        log_probs.append(torch.log_softmax(task_logits, dim=1))
    logits_tensor = torch.stack(logits, dim=1)
    probs_tensor = torch.stack(probs, dim=1)
    log_probs_tensor = torch.stack(log_probs, dim=1)
    plogp = torch.where(
        probs_tensor > 0,
        probs_tensor * log_probs_tensor,
        torch.zeros_like(probs_tensor),
    )
    entropy = -torch.sum(plogp, dim=-1)
    return probs_tensor, entropy, logits_tensor


def select_tasks_by_entropy(model: torch.nn.Module, inputs: torch.Tensor, tasks: Sequence[int]) -> torch.Tensor:
    """Return, per sample, the task whose prediction has the lowest entropy."""
    _, entropy, _ = task_probabilities(model, inputs, tasks)
    indices = torch.argmin(entropy, dim=1)
    task_tensor = torch.tensor(list(tasks), device=inputs.device, dtype=torch.long)
    return task_tensor[indices]


def _gather_by_task(probs: torch.Tensor, tasks: Sequence[int], selected_tasks: torch.Tensor) -> torch.Tensor:
    """Pick, for each row, the probability vector of its selected task.

    Raises:
        ValueError: if a selected task id is not in ``tasks``.
    """
    task_tensor = torch.tensor(list(tasks), device=selected_tasks.device, dtype=torch.long)
    out = []
    for row_idx, task_id in enumerate(selected_tasks.view(-1)):
        match = torch.nonzero(task_tensor == task_id, as_tuple=False)
        if match.numel() == 0:
            raise ValueError(f"Selected task {int(task_id)} is not in available task list {tasks}")
        out.append(probs[row_idx, int(match[0, 0])])
    return torch.stack(out, dim=0)


def _task_ids_for_domains(domains: Sequence[str], device: torch.device) -> torch.Tensor:
    """Return the task ids of ``domains`` as a long tensor on ``device``."""
    return torch.tensor([DOMAIN_TO_TASK[domain] for domain in domains], device=device, dtype=torch.long)


def _mean_candidate_class(probs: torch.Tensor) -> torch.Tensor:
    """Return the argmax class of the probabilities averaged over axis 1."""
    return torch.argmax(torch.mean(probs, dim=1), dim=1)


def _apply_d1_fallback(
    selected_tasks: torch.Tensor,
    selected_probs: torch.Tensor,
    all_tasks: Sequence[int],
    all_probs: torch.Tensor,
    all_entropy: torch.Tensor,
    z_scores: torch.Tensor,
    d1_fallback: str,
    d1_z_threshold: float,
    d1_entropy_margin: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Optionally re-route samples to task 0 (D1).

    * ``"none"`` (or task 0 absent from ``all_tasks``): nothing changes.
    * ``"entropy"``: re-route when task 0 has the lowest entropy of all tasks.
    * ``"conservative"``: re-route only when every prototype z-score exceeds
      ``d1_z_threshold`` and task 0's entropy beats every other task's by
      more than ``d1_entropy_margin``.

    Returns:
        ``(selected_tasks, selected_probs, fallback_mask)``.

    Raises:
        ValueError: if ``d1_fallback`` is not one of the modes above.
    """
    fallback_mask = torch.zeros_like(selected_tasks, dtype=torch.bool)
    if d1_fallback == "none" or 0 not in all_tasks:
        return selected_tasks, selected_probs, fallback_mask
    task_tensor = torch.tensor(list(all_tasks), device=selected_tasks.device, dtype=torch.long)
    d1_col = int(torch.nonzero(task_tensor == 0, as_tuple=False)[0, 0])
    d1_probs = all_probs[:, d1_col, :]
    if d1_fallback == "entropy":
        entropy_tasks = task_tensor[torch.argmin(all_entropy, dim=1)]
        fallback_mask = entropy_tasks == 0
    elif d1_fallback == "conservative":
        # Only this mode needs the other tasks' entropies, so compute them
        # here; doing it up front fails when task 0 is the only task.
        d1_entropy = all_entropy[:, d1_col]
        other_cols = [idx for idx, task in enumerate(all_tasks) if task != 0]
        min_other_entropy = torch.min(all_entropy[:, other_cols], dim=1).values
        far_from_known = torch.min(z_scores, dim=1).values > d1_z_threshold
        d1_confident = d1_entropy + d1_entropy_margin < min_other_entropy
        fallback_mask = far_from_known & d1_confident
    else:
        raise ValueError(f"Unknown D1 fallback: {d1_fallback}")
    selected_tasks = torch.where(fallback_mask, torch.zeros_like(selected_tasks), selected_tasks)
    selected_probs = torch.where(fallback_mask.unsqueeze(1), d1_probs, selected_probs)
    return selected_tasks, selected_probs, fallback_mask


def predict_with_selector(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    selector: str = "mahalanobis_proto",
    candidate_domains: Optional[Sequence[str]] = None,
    d1_fallback: str = "none",
    tau: float = 1.0,
    alpha: float = 1.0,
    beta: float = 1.0,
    margin_threshold: float = 0.0,
    margin_fallback: str = "entropy",
    d1_z_threshold: float = 2.0,
    d1_entropy_margin: float = 0.05,
    feature_type: Optional[str] = None,
    class_conditional_strategy: str = "top1",
    class_distance_lambda: float = 1.0,
    return_details: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, object]] | torch.Tensor:
    """Route each sample to a task head and return its class probabilities.

    Args:
        model: called as ``model(inputs, task_id)``; must also expose the
            spectrogram/log-mel extractors used for the statistics.
        inputs: batch of inputs.
        prototypes: normalised ``{domain: entry}`` mapping.
        selector: member of ``SELECTOR_MODES``; ``hard_mahalanobis`` is an
            alias of ``mahalanobis_proto``. ``entropy`` ignores prototypes.
        candidate_domains: optional restriction / ordering of domains.
        d1_fallback: ``none``, ``entropy`` or ``conservative``.
        tau: temperature of the soft Mahalanobis weights.
        alpha, beta: weights of z-score and entropy in the hybrid selector.
        margin_threshold: minimum gap between the two best distances for the
            margin selector to trust the nearest prototype.
        margin_fallback: ``entropy`` or ``soft``; used below that margin.
        d1_z_threshold, d1_entropy_margin: conservative D1 fallback settings.
        feature_type: overrides the feature type stored in the prototypes.
        class_conditional_strategy: ``top1`` scores against the class
            predicted by the averaged heads; ``joint`` maximises
            ``logit - class_distance_lambda * distance`` over all
            (domain, class) pairs and returns one-hot probabilities.
        class_distance_lambda: distance weight for the joint strategy.
        return_details: also return the selected tasks and a details dict.

    Returns:
        The selected probabilities, or ``(probs, tasks, details)`` when
        ``return_details`` is true.

    Raises:
        ValueError: on an unknown selector, strategy or fallback name.
    """
    selector = "mahalanobis_proto" if selector == "hard_mahalanobis" else selector
    if selector not in SELECTOR_MODES:
        raise ValueError(f"Unknown selector {selector!r}; expected one of {SELECTOR_MODES}")
    if class_conditional_strategy not in ("top1", "joint"):
        raise ValueError("--class_conditional_strategy must be top1 or joint")
    if margin_fallback not in ("entropy", "soft"):
        raise ValueError("--margin_fallback must be entropy or soft")
    device = inputs.device

    if selector == "entropy":
        all_tasks = [0, 1, 2]
        all_probs, all_entropy, _ = task_probabilities(model, inputs, all_tasks)
        selected_tasks = torch.tensor(all_tasks, device=device, dtype=torch.long)[torch.argmin(all_entropy, dim=1)]
        selected_probs = _gather_by_task(all_probs, all_tasks, selected_tasks)
        details = {
            "selector": selector,
            "selected_tasks": selected_tasks.detach().cpu(),
            "selected_domains": [TASK_TO_DOMAIN[int(task)] for task in selected_tasks.detach().cpu().tolist()],
            "entropy": all_entropy.detach().cpu(),
            "tasks": all_tasks,
        }
        return (selected_probs, selected_tasks, details) if return_details else selected_probs

    domains = candidate_domain_list(prototypes, candidate_domains)
    tasks = [DOMAIN_TO_TASK[domain] for domain in domains]
    if not domains:
        # Without any usable prototype, fall back to pure entropy routing.
        return predict_with_selector(
            model,
            inputs,
            prototypes,
            selector="entropy",
            d1_fallback="none",
            return_details=return_details,
        )

    feature_type = canonical_feature_type(feature_type or prototype_feature_type(prototypes))
    stats = extract_logmel_stat(model, inputs, feature_type=feature_type)
    probs, entropy, logits = task_probabilities(model, inputs, tasks)
    class_conditional = any(
        prototypes[domain].get("prototype_mode") == "class_conditional" and "class_means" in prototypes[domain]
        for domain in domains
    )

    if class_conditional and class_conditional_strategy == "joint":
        best_scores = []
        for col_idx, domain in enumerate(domains):
            dist_matrix = class_conditional_distance_matrix(stats, prototypes[domain])
            best_scores.append(logits[:, col_idx, :] - class_distance_lambda * dist_matrix)
        joint_scores = torch.stack(best_scores, dim=1)
        # Decode the flat argmax with the real class-axis size.
        num_classes = joint_scores.shape[-1]
        flat = torch.argmax(joint_scores.view(joint_scores.shape[0], -1), dim=1)
        domain_indices = flat // num_classes
        pred_indices = flat % num_classes
        selected_tasks = _task_ids_for_domains(domains, device)[domain_indices]
        selected_probs = torch.zeros(inputs.shape[0], num_classes, device=device)
        selected_probs.scatter_(1, pred_indices.unsqueeze(1), 1.0)
        raw_scores, _ = prototype_scores(stats, prototypes, domains, score_type="mahalanobis")
        z_scores, _ = prototype_scores(stats, prototypes, domains, score_type="zscore")
        # The joint strategy reports the raw distances as its scores.
        scores = raw_scores
    else:
        class_ids = _mean_candidate_class(probs) if class_conditional else None
        if selector == "gaussian_nll_proto":
            selector_score_type = "gaussian_nll"
        elif selector == "mahalanobis_zscore_proto":
            selector_score_type = "zscore"
        else:
            selector_score_type = "mahalanobis"
        scores, _ = prototype_scores(
            stats, prototypes, domains, score_type=selector_score_type,
            class_ids=class_ids, class_conditional=class_conditional,
        )
        raw_scores, _ = prototype_scores(
            stats, prototypes, domains, score_type="mahalanobis",
            class_ids=class_ids, class_conditional=class_conditional,
        )
        z_scores, _ = prototype_scores(
            stats, prototypes, domains, score_type="zscore",
            class_ids=class_ids, class_conditional=class_conditional,
        )

        task_ids = _task_ids_for_domains(domains, device)
        rows = torch.arange(inputs.shape[0], device=device)
        if selector == "soft_mahalanobis_proto":
            weights = torch.softmax(-float(tau) * raw_scores, dim=1)
            selected_probs = torch.sum(weights.unsqueeze(-1) * probs, dim=1)
            selected_tasks = task_ids[torch.argmax(weights, dim=1)]
        elif selector == "hybrid_maha_entropy":
            hybrid_scores = float(alpha) * z_scores + float(beta) * entropy
            indices = torch.argmin(hybrid_scores, dim=1)
            selected_tasks = task_ids[indices]
            selected_probs = probs[rows, indices]
            scores = hybrid_scores
        elif selector == "maha_margin_fallback":
            sorted_scores, sorted_indices = torch.sort(raw_scores, dim=1)
            best = sorted_indices[:, 0]
            if raw_scores.shape[1] > 1:
                second = sorted_scores[:, 1]
            else:
                # A single candidate has an infinite margin and is always trusted.
                second = sorted_scores[:, 0] + float("inf")
            margin = second - sorted_scores[:, 0]
            confident = margin >= float(margin_threshold)
            maha_probs = probs[rows, best]
            maha_tasks = task_ids[best]
            if margin_fallback == "entropy":
                fallback_idx = torch.argmin(entropy, dim=1)
                fallback_probs = probs[rows, fallback_idx]
                fallback_tasks = task_ids[fallback_idx]
            else:
                weights = torch.softmax(-float(tau) * raw_scores, dim=1)
                fallback_probs = torch.sum(weights.unsqueeze(-1) * probs, dim=1)
                fallback_tasks = task_ids[torch.argmax(weights, dim=1)]
            selected_probs = torch.where(confident.unsqueeze(1), maha_probs, fallback_probs)
            selected_tasks = torch.where(confident, maha_tasks, fallback_tasks)
            scores = raw_scores
        else:
            indices = torch.argmin(scores, dim=1)
            selected_tasks = task_ids[indices]
            selected_probs = probs[rows, indices]

    all_entropy = None
    if d1_fallback == "none":
        d1_mask = torch.zeros_like(selected_tasks, dtype=torch.bool)
    else:
        all_tasks = [0, 1, 2]
        all_probs, all_entropy, _ = task_probabilities(model, inputs, all_tasks)
        selected_tasks, selected_probs, d1_mask = _apply_d1_fallback(
            selected_tasks,
            selected_probs,
            all_tasks,
            all_probs,
            all_entropy,
            z_scores,
            d1_fallback=d1_fallback,
            d1_z_threshold=d1_z_threshold,
            d1_entropy_margin=d1_entropy_margin,
        )

    details = {
        "selector": selector,
        "domains": domains,
        "tasks": tasks,
        "scores": scores.detach().cpu(),
        "raw_mahalanobis": raw_scores.detach().cpu(),
        "zscore": z_scores.detach().cpu(),
        "entropy": entropy.detach().cpu(),
        "selected_tasks": selected_tasks.detach().cpu(),
        "selected_domains": [TASK_TO_DOMAIN[int(task)] for task in selected_tasks.detach().cpu().tolist()],
        "d1_fallback_mask": d1_mask.detach().cpu(),
    }
    if all_entropy is not None:
        details["all_entropy"] = all_entropy.detach().cpu()
    return (selected_probs, selected_tasks, details) if return_details else selected_probs


def select_tasks_with_d1_fallback(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    prototypes: Mapping[str, Mapping[str, object]],
    d1_fallback: str = "entropy",
) -> torch.Tensor:
    """Return the task ids chosen by nearest-prototype routing plus D1 fallback."""
    _, task_ids, _ = predict_with_selector(
        model,
        inputs,
        prototypes,
        selector="mahalanobis_proto",
        d1_fallback=d1_fallback,
        return_details=True,
    )
    return task_ids


def predict_with_selected_tasks(model: torch.nn.Module, inputs: torch.Tensor, task_ids: torch.Tensor) -> torch.Tensor:
    """Run the model one row at a time with each row's own task id.

    Returns the per-row outputs concatenated along axis 0, in input order.
    """
    logits = [
        model(inputs[row_idx:row_idx + 1], int(task_id))
        for row_idx, task_id in enumerate(task_ids.view(-1).tolist())
    ]
    return torch.cat(logits, dim=0)


def manifest_json(data: Mapping[str, object]) -> str:
    """Serialise ``data`` as indented, key-sorted JSON without ASCII escaping."""
    return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=True)