#!/usr/bin/env python3
"""Validate DCASE Task 7 TSV prediction file format."""

from __future__ import annotations

import argparse
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple

try:
    from statproto_utils import CLASS_LABELS
except ModuleNotFoundError:
    # Fallback label set used when the project helper module cannot be imported.
    CLASS_LABELS = [
        "alarm",
        "baby_cry",
        "dog_bark",
        "engine",
        "fire",
        "footsteps",
        "knocking",
        "telephone_ringing",
        "piano",
        "speech",
    ]


def read_expected_audio(path: Optional[Path]) -> Tuple[List[str], Dict[str, str]]:
    """Read the optional list of audio files that predictions must cover.

    Every non-empty line is split on tabs. The first column is the audio
    path; a second column, when present and one of ``CLASS_LABELS``, is kept
    as the ground-truth label for that file.

    Args:
        path: TSV file to read, or ``None`` when no expected list was given.

    Returns:
        A pair ``(filenames, truth)``: the audio paths in file order, and a
        mapping from audio path to label for rows carrying a recognised
        label. Both are empty when ``path`` is ``None``.

    Raises:
        ValueError: If a row has an empty first column, or the file contains
            no rows at all.
    """
    if path is None:
        return [], {}

    filenames: List[str] = []
    truth: Dict[str, str] = {}
    with path.open("r", encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.rstrip("\n")
            if not line:
                continue
            parts = line.split("\t")
            # str.split always yields at least one element, so only the
            # content of the first column needs checking.
            if not parts[0]:
                raise ValueError(f"{path}:{line_no}: empty audio_path in expected list")
            filename = parts[0]
            filenames.append(filename)
            label = parts[1] if len(parts) >= 2 else ""
            # Unrecognised or missing labels are ignored rather than
            # rejected: the second column is optional in this file.
            if label in CLASS_LABELS:
                truth[filename] = label

    if not filenames:
        raise ValueError(f"No rows found in {path}")
    return filenames, truth


def validate(prediction_path: Path, expected_path: Optional[Path]) -> Tuple[List[str], List[str]]:
    """Check a prediction TSV and collect human-readable problems.

    Each non-empty line must hold exactly two tab-separated columns: an audio
    path and a label drawn from ``CLASS_LABELS``. When an expected list is
    supplied, the predictions must cover exactly those files. When that list
    also carries ground-truth labels and no other error was found, a perfect
    match is reported as an error and a >95% match as a warning.

    Args:
        prediction_path: TSV file with one ``audio_path<TAB>label`` row per line.
        expected_path: Optional expected-file list, see ``read_expected_audio``.

    Returns:
        A pair ``(errors, warnings)`` of message lists; the file is
        acceptable when ``errors`` is empty.

    Raises:
        ValueError: Propagated from ``read_expected_audio`` for a malformed
            expected list.
        OSError: If either file cannot be opened.
    """
    errors: List[str] = []
    warnings: List[str] = []
    expected_files, truth = read_expected_audio(expected_path)
    expected_set = set(expected_files)

    seen: List[str] = []
    predictions: Dict[str, str] = {}
    with prediction_path.open("r", encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.rstrip("\n")
            if not line:
                warnings.append(f"{prediction_path}:{line_no}: empty line ignored")
                continue
            parts = line.split("\t")
            if len(parts) != 2:
                errors.append(
                    f"{prediction_path}:{line_no}: expected exactly 2 TSV columns, got {len(parts)}"
                )
                continue
            filename, label = parts
            if not filename:
                errors.append(f"{prediction_path}:{line_no}: empty audio_path")
            if label not in CLASS_LABELS:
                errors.append(f"{prediction_path}:{line_no}: invalid predicted_label {label!r}")
            # Rows are recorded even when flagged so that the duplicate and
            # coverage checks below still see them.
            seen.append(filename)
            predictions[filename] = label

    # An empty (or blank-only) file is never a valid submission. Only report
    # it when nothing else already explains why no row was accepted.
    if not seen and not errors:
        errors.append(f"No prediction rows found in {prediction_path}")

    counts = Counter(seen)
    duplicates = [filename for filename, count in counts.items() if count > 1]
    if duplicates:
        errors.append(f"Duplicate audio_path rows: {len(duplicates)}, first={duplicates[0]}")

    if expected_files:
        missing = [filename for filename in expected_files if filename not in predictions]
        extra = [filename for filename in predictions if filename not in expected_set]
        if missing:
            errors.append(f"Missing predictions: {len(missing)}, first={missing[0]}")
        if extra:
            errors.append(
                f"Predictions for files outside expected list: {len(extra)}, first={extra[0]}"
            )

    # Label-leak heuristic: only meaningful once the file is otherwise clean.
    if truth and expected_files and not errors:
        compared = [
            filename
            for filename in expected_files
            if filename in predictions and filename in truth
        ]
        if compared:
            matches = sum(1 for filename in compared if predictions[filename] == truth[filename])
            if matches == len(compared):
                errors.append(
                    "Predictions exactly match every available ground-truth label; "
                    "this strongly suggests the output file contains labels instead of model predictions."
                )
            elif matches / len(compared) > 0.95:
                warnings.append(
                    f"Predictions match {matches}/{len(compared)} ground-truth labels; "
                    "double-check that the file was generated by inference."
                )

    return errors, warnings


def main() -> None:
    """Parse command-line arguments, run the validation and report the result.

    Warnings go to stdout and errors to stderr. The process exits with
    status 1 when any error is found or an input file cannot be read.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--predictions", "--submission", dest="predictions", required=True, type=Path)
    parser.add_argument("--expected_audio_list", "--audio_list", dest="expected_audio_list", type=Path)
    args = parser.parse_args()

    try:
        errors, warnings = validate(args.predictions, args.expected_audio_list)
    except (OSError, ValueError) as exc:
        # Unreadable files, undecodable bytes (UnicodeDecodeError is a
        # ValueError) and malformed expected lists are user errors: report
        # them like any other validation error instead of a traceback.
        print(f"ERROR: {exc}", file=sys.stderr)
        raise SystemExit(1) from exc

    for warning in warnings:
        print(f"WARNING: {warning}")
    if errors:
        for error in errors:
            print(f"ERROR: {error}", file=sys.stderr)
        raise SystemExit(1)

    print("Submission format check passed.")
    print(f"Allowed labels: {', '.join(CLASS_LABELS)}")


if __name__ == "__main__":
    main()