#!/usr/bin/env python3
# Source: DCASE-Task7-model4/code/scripts/apply_prediction_pair_resolver.py
# (uploaded by OrigamiShido in "Upload 17 files", commit e192899).
"""Apply an internally selected two-model pair resolver.
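The pair table is built from an internal validation split. Target labels are
not used for decisions; if the target file list carries labels, they are only
read afterwards to report metrics. By default each output depends only on the
current sample's base/advisor prediction pair and the fixed table. When
--max_change_rate_per_base_class is positive, changes are also capped per
base-predicted class using counts over the whole target list, so decisions are
then no longer independent per sample.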
"""
from __future__ import annotations
import argparse
from collections import Counter
import math
import time
from pathlib import Path
import numpy as np
import pandas as pd
from statproto_utils import CLASS_LABELS, INDEX_TO_LABEL, LABEL_TO_INDEX, manifest_json
from strict_prediction_postprocess import (
build_pair_table,
metric,
pair_resolver_predict,
parse_pair,
read_metadata,
read_prediction,
target_array,
write_predictions,
)
def read_prediction_order(path: Path) -> list[str]:
    """Return the filenames listed in a prediction TSV, in file order.

    Every non-empty line must consist of exactly two tab-separated fields,
    ``filename`` and ``label``, where ``label`` is a key of ``LABEL_TO_INDEX``.
    Empty lines are skipped (whitespace-only lines are not considered empty).

    Raises:
        ValueError: if a non-empty line has the wrong number of fields or an
            unknown label.
    """
    ordered: list[str] = []
    with path.open("r", encoding="utf-8") as stream:
        for raw_line in stream:
            text = raw_line.rstrip("\n")
            if text == "":
                continue
            fields = text.split("\t")
            well_formed = len(fields) == 2 and fields[1] in LABEL_TO_INDEX
            if not well_formed:
                raise ValueError(f"Bad prediction line in {path}: {text!r}")
            ordered.append(fields[0])
    return ordered
def read_target_filenames(path: Path | None, fallback_prediction: Path) -> list[str]:
if path is None:
return read_prediction_order(fallback_prediction)
df = pd.read_csv(path, sep="\t", header=None, dtype=str)
if df.shape[1] < 1:
raise ValueError(f"Target file list has no filename column: {path}")
return df.iloc[:, 0].tolist()
def maybe_target_metrics(target_file_list: Path | None, pred) -> dict | None:
if target_file_list is None:
return None
df = pd.read_csv(target_file_list, sep="\t", header=None, dtype=str)
if df.shape[1] < 4:
return None
metadata = read_metadata(target_file_list)
if not set(metadata["target"].tolist()).issubset(set(CLASS_LABELS)):
return None
return metric(pred, metadata, target_array(metadata))
def label_table(table: dict[tuple[int, int], int], base_name: str, advisor_name: str) -> list[dict]:
    """Render the pair table with class labels instead of indices.

    Entries are sorted by their ``(base_index, advisor_index)`` key. Each row
    maps ``base_name``/``advisor_name`` to their predicted labels and gives the
    resolved ``output`` label. If the two names are equal, the inner dict keeps
    only the advisor entry.
    """
    return [
        {
            "prediction_pair": {
                base_name: INDEX_TO_LABEL[int(base_cls)],
                advisor_name: INDEX_TO_LABEL[int(advisor_cls)],
            },
            "output": INDEX_TO_LABEL[int(resolved)],
        }
        for (base_cls, advisor_cls), resolved in sorted(table.items())
    ]
def full_output_table(table: dict[tuple[int, int], int], base_name: str, advisor_name: str) -> list[dict]:
    """Expand the resolver over every (base label, advisor label) combination.

    Rows are ordered by base label, then advisor label, following the order of
    ``CLASS_LABELS``. Pairs present in ``table`` report the table's output with
    source ``"pair_table"``. All other pairs keep the base label with source
    ``"base_prediction"``, which matches the capped resolver in this file leaving
    unmatched samples at their base prediction.
    """
    indexed_labels = list(enumerate(CLASS_LABELS))

    def make_row(b_idx: int, b_label, a_idx: int, a_label) -> dict:
        pair_key = (b_idx, a_idx)
        in_table = pair_key in table
        resolved = int(table[pair_key]) if in_table else b_idx
        return {
            "prediction_pair": {base_name: b_label, advisor_name: a_label},
            "output": INDEX_TO_LABEL[resolved],
            "source": "pair_table" if in_table else "base_prediction",
        }

    return [
        make_row(b_idx, b_label, a_idx, a_label)
        for b_idx, b_label in indexed_labels
        for a_idx, a_label in indexed_labels
    ]
def capped_pair_resolver_predict(
predictions: np.ndarray,
table: dict[tuple[int, int], int],
base_idx: int,
advisor_idx: int,
max_change_rate_per_base_class: float,
) -> tuple[np.ndarray, dict[str, int]]:
if max_change_rate_per_base_class <= 0.0:
pred = pair_resolver_predict(predictions, table, base_idx=base_idx, advisor_idx=advisor_idx)
stats = Counter()
base = predictions[base_idx]
for old, new in zip(base.tolist(), pred.tolist()):
if int(old) != int(new):
stats[f"{INDEX_TO_LABEL[int(old)]}->{INDEX_TO_LABEL[int(new)]}"] += 1
return pred, dict(stats)
base = predictions[base_idx]
out = base.copy()
candidates: dict[int, list[tuple[int, int]]] = {}
for sample_idx in range(predictions.shape[1]):
key = (int(predictions[base_idx, sample_idx]), int(predictions[advisor_idx, sample_idx]))
if key not in table:
continue
output = int(table[key])
if output == int(base[sample_idx]):
continue
candidates.setdefault(int(base[sample_idx]), []).append((sample_idx, output))
stats = Counter()
for base_label, rows in sorted(candidates.items()):
n_base = int((base == base_label).sum())
cap = int(math.floor(float(max_change_rate_per_base_class) * n_base))
if n_base > 0:
cap = max(1, cap)
cap = min(cap, len(rows))
for sample_idx, output in rows[:cap]:
out[sample_idx] = output
stats[f"{INDEX_TO_LABEL[int(base_label)]}->{INDEX_TO_LABEL[int(output)]}"] += 1
return out, dict(stats)
def main() -> None:
    """Command-line entry point: fit the pair table internally, then apply it to the targets.

    Steps:
      1. Read the internal (labelled) metadata and both models' internal
         predictions, and build the (base, advisor) -> output table with
         ``build_pair_table``.
      2. Read both models' final predictions for the target filenames and
         resolve them with ``capped_pair_resolver_predict``.
      3. Write the predictions to ``--output``, write a JSON manifest to
         ``<output name>.manifest.json`` next to it, and print a short JSON
         summary to stdout.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--internal_metadata", type=Path, required=True)
    parser.add_argument("--target_file_list", type=Path)
    parser.add_argument("--output", type=Path, required=True)
    parser.add_argument("--base_pair", required=True)
    parser.add_argument("--advisor_pair", required=True)
    parser.add_argument("--alpha", type=float, default=0.05)
    parser.add_argument("--min_count", type=int, default=2)
    parser.add_argument("--confidence", type=float, default=0.4)
    parser.add_argument("--max_change_rate_per_base_class", type=float, default=0.0)
    parser.add_argument("--skip_base_noop_rules", action="store_true")
    args = parser.parse_args()

    base_pair = parse_pair(args.base_pair)
    advisor_pair = parse_pair(args.advisor_pair)
    pairs = [base_pair, advisor_pair]
    target_file_list = args.target_file_list.resolve() if args.target_file_list else None

    # 1. Build the table on the internal split (the only labels that drive decisions).
    internal_metadata = read_metadata(args.internal_metadata.resolve())
    internal_target = target_array(internal_metadata)
    internal_filenames = internal_metadata["filename"].tolist()
    internal_predictions = np.stack(
        [read_prediction(Path(pair["internal_path"]).resolve(), internal_filenames) for pair in pairs],
        axis=0,
    )
    table = build_pair_table(
        internal_predictions,
        internal_target,
        base_idx=0,
        advisor_idx=1,
        min_count=args.min_count,
        confidence=args.confidence,
        alpha=args.alpha,
    )
    if args.skip_base_noop_rules:
        # Drop rules that map a pair back to its base label.
        table = {key: value for key, value in table.items() if int(value) != int(key[0])}

    # 2. Resolve the target predictions (row 0 = base model, row 1 = advisor).
    target_filenames = read_target_filenames(target_file_list, Path(base_pair["final_path"]).resolve())
    target_predictions = np.stack(
        [read_prediction(Path(pair["final_path"]).resolve(), target_filenames) for pair in pairs],
        axis=0,
    )
    pred, change_stats = capped_pair_resolver_predict(
        target_predictions,
        table,
        base_idx=0,
        advisor_idx=1,
        max_change_rate_per_base_class=args.max_change_rate_per_base_class,
    )

    # 3. Write predictions. Target labels (if any) are read only now, for reporting.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    write_predictions(args.output, target_filenames, pred)
    target_metrics = maybe_target_metrics(target_file_list, pred)

    # Same condition capped_pair_resolver_predict uses to enter its capped branch.
    # That branch counts base-predicted classes over the whole target list and
    # applies changes in file order, so the compliance flags must reflect it.
    cap_active = not (args.max_change_rate_per_base_class <= 0.0)

    def describe_pair(pair) -> dict:
        return {
            "name": pair["name"],
            "internal_path": str(Path(pair["internal_path"]).resolve()),
            "final_path": str(Path(pair["final_path"]).resolve()),
        }

    output_path = args.output.resolve()
    manifest = {
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "method": "prediction_pair_resolver_apply_only",
        "output": str(output_path),
        "base_pair": describe_pair(base_pair),
        "advisor_pair": describe_pair(advisor_pair),
        "selection": {
            "method": "pair_resolver",
            "alpha": args.alpha,
            "min_count": args.min_count,
            "confidence": args.confidence,
            "max_change_rate_per_base_class": args.max_change_rate_per_base_class,
            "skip_base_noop_rules": args.skip_base_noop_rules,
            "base_idx": 0,
            "advisor_idx": 1,
        },
        "table_size": len(table),
        "change_count": int(sum(change_stats.values())),
        "change_stats": change_stats,
        "table": label_table(table, base_pair["name"], advisor_pair["name"]),
        "full_output_table": full_output_table(table, base_pair["name"], advisor_pair["name"]),
        "target_file_list": str(target_file_list) if target_file_list is not None else None,
        "target_metrics_if_labeled": target_metrics,
        "compliance": {
            "selection_metadata": str(args.internal_metadata.resolve()),
            "target_labels_used": False,
            "target_set_statistics_used": cap_active,
            "per_sample_independent_decisions": not cap_active,
            "external_data_used": False,
        },
    }
    # e.g. "pred.txt" -> "pred.txt.manifest.json", "pred" -> "pred.manifest.json"
    manifest_path = args.output.with_name(args.output.name + ".manifest.json")
    manifest_path.write_text(manifest_json(manifest) + "\n", encoding="utf-8")
    print(manifest_json({
        "output": str(output_path),
        "manifest": str(manifest_path.resolve()),
        "table_size": len(table),
        "target_metrics_if_labeled": target_metrics,
    }))


if __name__ == "__main__":
    main()