| |
| """Validate DCASE Task 7 TSV prediction file format.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from collections import Counter |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
# Use the label list exported by statproto_utils when that module can be
# found; otherwise fall back to the default list defined right here so the
# script stays usable on its own.
try:
    from statproto_utils import CLASS_LABELS
except ModuleNotFoundError:
    CLASS_LABELS = [
        "alarm", "baby_cry", "dog_bark", "engine", "fire",
        "footsteps", "knocking", "telephone_ringing", "piano", "speech",
    ]
|
|
|
|
def read_expected_audio(path: Optional[Path]) -> Tuple[List[str], Dict[str, str]]:
    """Load the expected audio list and any ground-truth labels it carries.

    Every non-empty line is split on tabs. The first column is the audio
    path; an optional second column is kept as ground truth only when it is
    one of ``CLASS_LABELS``. Any further columns are ignored.

    Args:
        path: TSV file to read, or ``None`` when no expected list was given.

    Returns:
        A ``(filenames, truth)`` pair: ``filenames`` holds the audio paths in
        file order and ``truth`` maps an audio path to its recognised label.
        Both are empty when ``path`` is ``None``.

    Raises:
        ValueError: If a row has an empty first column, or the file contains
            no rows at all.
    """
    if path is None:
        return [], {}

    filenames: List[str] = []
    truth: Dict[str, str] = {}
    # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading byte-order
    # mark, which would otherwise become part of the first audio path and
    # make it look like a different file from the one in the predictions.
    with path.open("r", encoding="utf-8-sig") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.rstrip("\n")
            if not line:
                continue
            parts = line.split("\t")
            filename = parts[0]
            if not filename:
                raise ValueError(f"{path}:{line_no}: empty audio_path in expected list")
            filenames.append(filename)
            # Rows without a label column carry no ground truth, so the
            # label lookup is only needed when a second column exists.
            if len(parts) >= 2 and parts[1] in CLASS_LABELS:
                truth[filename] = parts[1]
    if not filenames:
        raise ValueError(f"No rows found in {path}")
    return filenames, truth
|
|
|
|
def validate(prediction_path: Path, expected_path: Optional[Path]) -> Tuple[List[str], List[str]]:
    """Check a prediction TSV for format problems and suspicious label matches.

    Each non-empty line of the prediction file must have exactly two
    tab-separated columns: an audio path and a label from ``CLASS_LABELS``.
    When an expected audio list is supplied, the set of predicted files is
    compared against it, and any ground-truth labels found in that list are
    used to flag prediction files that look like copied labels.

    Args:
        prediction_path: TSV file with one ``audio_path<TAB>label`` row per line.
        expected_path: Optional expected audio list, read with
            ``read_expected_audio``; ``None`` skips the coverage checks.

    Returns:
        An ``(errors, warnings)`` pair of human-readable messages. An empty
        ``errors`` list means the file passed.

    Raises:
        ValueError: Propagated from ``read_expected_audio`` for a malformed
            or empty expected list.
    """
    errors: List[str] = []
    warnings: List[str] = []
    expected_files, truth = read_expected_audio(expected_path)
    expected_set = set(expected_files)
    allowed_labels = frozenset(CLASS_LABELS)

    seen: Counter = Counter()
    predictions: Dict[str, str] = {}
    row_count = 0

    # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading byte-order
    # mark, which would otherwise be glued onto the first audio path.
    with prediction_path.open("r", encoding="utf-8-sig") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.rstrip("\n")
            if not line:
                warnings.append(f"{prediction_path}:{line_no}: empty line ignored")
                continue
            row_count += 1
            parts = line.split("\t")
            if len(parts) != 2:
                errors.append(f"{prediction_path}:{line_no}: expected exactly 2 TSV columns, got {len(parts)}")
                continue
            filename, label = parts
            if not filename:
                errors.append(f"{prediction_path}:{line_no}: empty audio_path")
            if label not in allowed_labels:
                errors.append(f"{prediction_path}:{line_no}: invalid predicted_label {label!r}")
            if not filename:
                # Already reported above; recording "" as a file would only
                # add misleading duplicate / unexpected-file errors.
                continue
            seen[filename] += 1
            predictions[filename] = label

    if row_count == 0:
        # Without an expected list nothing else would notice an empty file.
        errors.append(f"No prediction rows found in {prediction_path}")

    duplicates = [filename for filename, count in seen.items() if count > 1]
    if duplicates:
        errors.append(f"Duplicate audio_path rows: {len(duplicates)}, first={duplicates[0]}")

    if expected_files:
        missing = [filename for filename in expected_files if filename not in predictions]
        extra = [filename for filename in predictions if filename not in expected_set]
        if missing:
            errors.append(f"Missing predictions: {len(missing)}, first={missing[0]}")
        if extra:
            errors.append(f"Predictions for files outside expected list: {len(extra)}, first={extra[0]}")

    # The ground-truth comparison only runs on an otherwise clean file.
    if truth and expected_files and not errors:
        compared = [filename for filename in expected_files if filename in predictions and filename in truth]
        if compared:
            matches = sum(1 for filename in compared if predictions[filename] == truth[filename])
            if matches == len(compared):
                errors.append(
                    "Predictions exactly match every available ground-truth label; "
                    "this strongly suggests the output file contains labels instead of model predictions."
                )
            elif matches / len(compared) > 0.95:
                warnings.append(
                    f"Predictions match {matches}/{len(compared)} ground-truth labels; "
                    "double-check that the file was generated by inference."
                )

    return errors, warnings
|
|
|
|
def main() -> None:
    """Command-line entry point: validate a prediction file and report the result.

    Warnings are printed to stdout. If any errors are found they are printed
    to stderr and the process exits with status 1; otherwise a success
    message and the allowed labels are printed.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--predictions",
        "--submission",
        dest="predictions",
        required=True,
        type=Path,
    )
    cli.add_argument(
        "--expected_audio_list",
        "--audio_list",
        dest="expected_audio_list",
        type=Path,
    )
    options = cli.parse_args()

    problems, notices = validate(options.predictions, options.expected_audio_list)

    for notice in notices:
        print(f"WARNING: {notice}")

    # Report every error before exiting so the user sees the full list.
    for problem in problems:
        print(f"ERROR: {problem}", file=sys.stderr)
    if problems:
        raise SystemExit(1)

    print("Submission format check passed.")
    print(f"Allowed labels: {', '.join(CLASS_LABELS)}")


# Run the validator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|