#!/usr/bin/env python3 """Apply an internally selected two-model pair resolver. The pair table is built from an internal validation split. Target labels and target-set statistics are not used; each output depends only on the current sample's base/advisor prediction pair and a fixed table. """ from __future__ import annotations import argparse from collections import Counter import math import time from pathlib import Path import numpy as np import pandas as pd from statproto_utils import CLASS_LABELS, INDEX_TO_LABEL, LABEL_TO_INDEX, manifest_json from strict_prediction_postprocess import ( build_pair_table, metric, pair_resolver_predict, parse_pair, read_metadata, read_prediction, target_array, write_predictions, ) def read_prediction_order(path: Path) -> list[str]: filenames = [] with path.open("r", encoding="utf-8") as handle: for raw in handle: line = raw.rstrip("\n") if not line: continue parts = line.split("\t") if len(parts) != 2 or parts[1] not in LABEL_TO_INDEX: raise ValueError(f"Bad prediction line in {path}: {line!r}") filenames.append(parts[0]) return filenames def read_target_filenames(path: Path | None, fallback_prediction: Path) -> list[str]: if path is None: return read_prediction_order(fallback_prediction) df = pd.read_csv(path, sep="\t", header=None, dtype=str) if df.shape[1] < 1: raise ValueError(f"Target file list has no filename column: {path}") return df.iloc[:, 0].tolist() def maybe_target_metrics(target_file_list: Path | None, pred) -> dict | None: if target_file_list is None: return None df = pd.read_csv(target_file_list, sep="\t", header=None, dtype=str) if df.shape[1] < 4: return None metadata = read_metadata(target_file_list) if not set(metadata["target"].tolist()).issubset(set(CLASS_LABELS)): return None return metric(pred, metadata, target_array(metadata)) def label_table(table: dict[tuple[int, int], int], base_name: str, advisor_name: str) -> list[dict]: rows = [] for key, value in sorted(table.items()): rows.append({ "prediction_pair": { base_name: INDEX_TO_LABEL[int(key[0])], advisor_name: INDEX_TO_LABEL[int(key[1])], }, "output": INDEX_TO_LABEL[int(value)], }) return rows def full_output_table(table: dict[tuple[int, int], int], base_name: str, advisor_name: str) -> list[dict]: rows = [] for base_idx, base_label in enumerate(CLASS_LABELS): for advisor_idx, advisor_label in enumerate(CLASS_LABELS): key = (base_idx, advisor_idx) if key in table: output = int(table[key]) source = "pair_table" else: output = base_idx source = "base_prediction" rows.append({ "prediction_pair": { base_name: base_label, advisor_name: advisor_label, }, "output": INDEX_TO_LABEL[output], "source": source, }) return rows def capped_pair_resolver_predict( predictions: np.ndarray, table: dict[tuple[int, int], int], base_idx: int, advisor_idx: int, max_change_rate_per_base_class: float, ) -> tuple[np.ndarray, dict[str, int]]: if max_change_rate_per_base_class <= 0.0: pred = pair_resolver_predict(predictions, table, base_idx=base_idx, advisor_idx=advisor_idx) stats = Counter() base = predictions[base_idx] for old, new in zip(base.tolist(), pred.tolist()): if int(old) != int(new): stats[f"{INDEX_TO_LABEL[int(old)]}->{INDEX_TO_LABEL[int(new)]}"] += 1 return pred, dict(stats) base = predictions[base_idx] out = base.copy() candidates: dict[int, list[tuple[int, int]]] = {} for sample_idx in range(predictions.shape[1]): key = (int(predictions[base_idx, sample_idx]), int(predictions[advisor_idx, sample_idx])) if key not in table: continue output = int(table[key]) if output == int(base[sample_idx]): continue candidates.setdefault(int(base[sample_idx]), []).append((sample_idx, output)) stats = Counter() for base_label, rows in sorted(candidates.items()): n_base = int((base == base_label).sum()) cap = int(math.floor(float(max_change_rate_per_base_class) * n_base)) if n_base > 0: cap = max(1, cap) cap = min(cap, len(rows)) for sample_idx, output in rows[:cap]: out[sample_idx] = output stats[f"{INDEX_TO_LABEL[int(base_label)]}->{INDEX_TO_LABEL[int(output)]}"] += 1 return out, dict(stats) def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--internal_metadata", type=Path, required=True) parser.add_argument("--target_file_list", type=Path) parser.add_argument("--output", type=Path, required=True) parser.add_argument("--base_pair", required=True) parser.add_argument("--advisor_pair", required=True) parser.add_argument("--alpha", type=float, default=0.05) parser.add_argument("--min_count", type=int, default=2) parser.add_argument("--confidence", type=float, default=0.4) parser.add_argument("--max_change_rate_per_base_class", type=float, default=0.0) parser.add_argument("--skip_base_noop_rules", action="store_true") args = parser.parse_args() base_pair = parse_pair(args.base_pair) advisor_pair = parse_pair(args.advisor_pair) pairs = [base_pair, advisor_pair] internal_metadata = read_metadata(args.internal_metadata.resolve()) internal_target = target_array(internal_metadata) internal_filenames = internal_metadata["filename"].tolist() internal_predictions = np.stack([ read_prediction(Path(pair["internal_path"]).resolve(), internal_filenames) for pair in pairs ], axis=0) table = build_pair_table( internal_predictions, internal_target, base_idx=0, advisor_idx=1, min_count=args.min_count, confidence=args.confidence, alpha=args.alpha, ) if args.skip_base_noop_rules: table = {key: value for key, value in table.items() if int(value) != int(key[0])} target_filenames = read_target_filenames( args.target_file_list.resolve() if args.target_file_list else None, Path(base_pair["final_path"]).resolve(), ) target_predictions = np.stack([ read_prediction(Path(pair["final_path"]).resolve(), target_filenames) for pair in pairs ], axis=0) pred, change_stats = capped_pair_resolver_predict( target_predictions, table, base_idx=0, advisor_idx=1, max_change_rate_per_base_class=args.max_change_rate_per_base_class, ) args.output.parent.mkdir(parents=True, exist_ok=True) write_predictions(args.output, target_filenames, pred) target_metrics = maybe_target_metrics(args.target_file_list.resolve() if args.target_file_list else None, pred) manifest = { "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "method": "prediction_pair_resolver_apply_only", "output": str(args.output.resolve()), "base_pair": { "name": base_pair["name"], "internal_path": str(Path(base_pair["internal_path"]).resolve()), "final_path": str(Path(base_pair["final_path"]).resolve()), }, "advisor_pair": { "name": advisor_pair["name"], "internal_path": str(Path(advisor_pair["internal_path"]).resolve()), "final_path": str(Path(advisor_pair["final_path"]).resolve()), }, "selection": { "method": "pair_resolver", "alpha": args.alpha, "min_count": args.min_count, "confidence": args.confidence, "max_change_rate_per_base_class": args.max_change_rate_per_base_class, "skip_base_noop_rules": args.skip_base_noop_rules, "base_idx": 0, "advisor_idx": 1, }, "table_size": len(table), "change_count": int(sum(change_stats.values())), "change_stats": change_stats, "table": label_table(table, base_pair["name"], advisor_pair["name"]), "full_output_table": full_output_table(table, base_pair["name"], advisor_pair["name"]), "target_file_list": str(args.target_file_list.resolve()) if args.target_file_list else None, "target_metrics_if_labeled": target_metrics, "compliance": { "selection_metadata": str(args.internal_metadata.resolve()), "target_labels_used": False, "target_set_statistics_used": False, "per_sample_independent_decisions": True, "external_data_used": False, }, } manifest_path = args.output.with_suffix(args.output.suffix + ".manifest.json") manifest_path.write_text(manifest_json(manifest) + "\n", encoding="utf-8") print(manifest_json({ "output": str(args.output.resolve()), "manifest": str(manifest_path.resolve()), "table_size": len(table), "target_metrics_if_labeled": target_metrics, })) if __name__ == "__main__": main()