DCASE-Task7-model4 / code /scripts /check_submission_format.py
OrigamiShido's picture
Upload 17 files
e192899 verified
Raw
History Blame Contribute Delete
4.85 kB
#!/usr/bin/env python3
"""Validate DCASE Task 7 TSV prediction file format."""
from __future__ import annotations
import argparse
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple
try:
    from statproto_utils import CLASS_LABELS
except ModuleNotFoundError:
    # statproto_utils is not importable here, so fall back to a built-in
    # copy of the label list; this keeps the checker usable as a
    # standalone script.
    CLASS_LABELS = [
        "alarm", "baby_cry", "dog_bark", "engine", "fire",
        "footsteps", "knocking", "telephone_ringing", "piano", "speech",
    ]
def read_expected_audio(path: Optional[Path]) -> Tuple[List[str], Dict[str, str]]:
    """Load the expected audio list from a tab-separated file.

    The first column of each non-empty line is taken as the audio path.
    When a second column is present and its value is one of
    ``CLASS_LABELS``, it is recorded as the known label for that path;
    any other second-column value is ignored.

    Args:
        path: Location of the list file, or ``None`` when no list is given.

    Returns:
        A ``(filenames, truth)`` pair: the audio paths in file order, and a
        mapping from audio path to its recognised label. Both are empty when
        ``path`` is ``None``.

    Raises:
        ValueError: If a line has an empty first column, or if the file
            contains no rows at all.
    """
    audio_names: List[str] = []
    known_labels: Dict[str, str] = {}
    if path is None:
        return audio_names, known_labels

    with path.open("r", encoding="utf-8") as stream:
        for row_number, text in enumerate(stream, start=1):
            row = text.rstrip("\n")
            if row == "":
                # Blank lines are skipped silently in the expected list.
                continue
            columns = row.split("\t")
            audio_name = columns[0]
            if audio_name == "":
                raise ValueError(f"{path}:{row_number}: empty audio_path in expected list")
            audio_names.append(audio_name)
            # A missing second column is treated like an unrecognised label.
            candidate = columns[1] if len(columns) > 1 else ""
            if candidate in CLASS_LABELS:
                known_labels[audio_name] = candidate

    if len(audio_names) == 0:
        raise ValueError(f"No rows found in {path}")
    return audio_names, known_labels
def validate(prediction_path: Path, expected_path: Optional[Path]) -> Tuple[List[str], List[str]]:
    """Check a TSV prediction file and collect the problems found.

    Each non-empty line must hold exactly two tab-separated columns: an
    audio path and a predicted label drawn from ``CLASS_LABELS``. When an
    expected audio list is supplied, the predictions must cover exactly the
    listed files. If that list also carries ground-truth labels and no
    other error was found, a suspiciously high agreement with the ground
    truth is reported.

    Args:
        prediction_path: TSV file with the predictions to check.
        expected_path: Optional expected audio list (see
            ``read_expected_audio``); ``None`` disables the coverage and
            ground-truth checks.

    Returns:
        A ``(errors, warnings)`` pair of message lists. The submission is
        acceptable only when ``errors`` is empty.

    Raises:
        ValueError: Propagated from ``read_expected_audio`` when the
            expected list is malformed or empty.
    """
    errors: List[str] = []
    warnings: List[str] = []

    # Only consult the expected list when one was actually given; without it
    # there is nothing to compare coverage or labels against.
    if expected_path is None:
        expected_files: List[str] = []
        truth: Dict[str, str] = {}
    else:
        expected_files, truth = read_expected_audio(expected_path)
    expected_set = set(expected_files)

    seen: List[str] = []
    predictions: Dict[str, str] = {}
    row_count = 0  # non-empty lines, whether well-formed or not
    with prediction_path.open("r", encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.rstrip("\n")
            if not line:
                warnings.append(f"{prediction_path}:{line_no}: empty line ignored")
                continue
            row_count += 1
            parts = line.split("\t")
            if len(parts) != 2:
                errors.append(f"{prediction_path}:{line_no}: expected exactly 2 TSV columns, got {len(parts)}")
                continue
            filename, label = parts
            if not filename:
                errors.append(f"{prediction_path}:{line_no}: empty audio_path")
            if label not in CLASS_LABELS:
                errors.append(f"{prediction_path}:{line_no}: invalid predicted_label {label!r}")
            if not filename:
                # Already reported above; recording it under the key "" would
                # only add misleading "duplicate" / "outside expected list"
                # errors for a file that does not exist.
                continue
            seen.append(filename)
            predictions[filename] = label

    if row_count == 0:
        # An empty submission must never be reported as valid, even when no
        # expected list is available to reveal the missing predictions.
        errors.append(f"No prediction rows found in {prediction_path}")

    counts = Counter(seen)
    duplicates = [filename for filename, count in counts.items() if count > 1]
    if duplicates:
        errors.append(f"Duplicate audio_path rows: {len(duplicates)}, first={duplicates[0]}")

    if expected_files:
        missing = [filename for filename in expected_files if filename not in predictions]
        extra = [filename for filename in predictions if filename not in expected_set]
        if missing:
            errors.append(f"Missing predictions: {len(missing)}, first={missing[0]}")
        if extra:
            errors.append(f"Predictions for files outside expected list: {len(extra)}, first={extra[0]}")

    # The ground-truth comparison runs only on an otherwise clean file.
    if truth and expected_files and not errors:
        compared = [filename for filename in expected_files if filename in predictions and filename in truth]
        if compared:
            matches = sum(1 for filename in compared if predictions[filename] == truth[filename])
            if matches == len(compared):
                errors.append(
                    "Predictions exactly match every available ground-truth label; "
                    "this strongly suggests the output file contains labels instead of model predictions."
                )
            elif matches / len(compared) > 0.95:
                warnings.append(
                    f"Predictions match {matches}/{len(compared)} ground-truth labels; "
                    "double-check that the file was generated by inference."
                )
    return errors, warnings
def main() -> None:
    """Parse the command line, run ``validate`` and report the outcome.

    Warnings go to standard output and errors to standard error. The
    process exits with status 1 when at least one error was found;
    otherwise a success message and the allowed labels are printed.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--predictions", "--submission", dest="predictions", required=True, type=Path)
    cli.add_argument("--expected_audio_list", "--audio_list", dest="expected_audio_list", type=Path)
    options = cli.parse_args()

    problems, notices = validate(options.predictions, options.expected_audio_list)

    for notice in notices:
        print(f"WARNING: {notice}")

    for problem in problems:
        print(f"ERROR: {problem}", file=sys.stderr)
    if problems:
        # A non-zero exit status lets calling scripts detect a bad submission.
        raise SystemExit(1)

    print("Submission format check passed.")
    print(f"Allowed labels: {', '.join(CLASS_LABELS)}")


if __name__ == "__main__":
    main()