| |
| """Apply an internally selected two-model pair resolver. |
| |
| The pair table is built from an internal validation split. Target labels and |
| target-set statistics are not used; each output depends only on the current |
| sample's base/advisor prediction pair and a fixed table. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| from collections import Counter |
| import math |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
|
|
| from statproto_utils import CLASS_LABELS, INDEX_TO_LABEL, LABEL_TO_INDEX, manifest_json |
| from strict_prediction_postprocess import ( |
| build_pair_table, |
| metric, |
| pair_resolver_predict, |
| parse_pair, |
| read_metadata, |
| read_prediction, |
| target_array, |
| write_predictions, |
| ) |
|
|
|
|
def read_prediction_order(path: Path) -> list[str]:
    """Return the filenames of a prediction file in the order they appear.

    Each non-empty line must consist of exactly two tab-separated fields,
    ``filename<TAB>label``, where ``label`` is a key of ``LABEL_TO_INDEX``.
    Empty lines are skipped.

    Raises:
        ValueError: if a non-empty line does not have exactly two fields or
            its second field is not a known label.
    """
    ordered = []
    with path.open("r", encoding="utf-8") as stream:
        for record in stream:
            # Only the trailing newline is removed; other whitespace is kept
            # so that it takes part in validation.
            record = record.rstrip("\n")
            if record:
                fields = record.split("\t")
                if len(fields) == 2 and fields[1] in LABEL_TO_INDEX:
                    ordered.append(fields[0])
                else:
                    raise ValueError(f"Bad prediction line in {path}: {record!r}")
    return ordered
|
|
|
|
| def read_target_filenames(path: Path | None, fallback_prediction: Path) -> list[str]: |
| if path is None: |
| return read_prediction_order(fallback_prediction) |
| df = pd.read_csv(path, sep="\t", header=None, dtype=str) |
| if df.shape[1] < 1: |
| raise ValueError(f"Target file list has no filename column: {path}") |
| return df.iloc[:, 0].tolist() |
|
|
|
|
| def maybe_target_metrics(target_file_list: Path | None, pred) -> dict | None: |
| if target_file_list is None: |
| return None |
| df = pd.read_csv(target_file_list, sep="\t", header=None, dtype=str) |
| if df.shape[1] < 4: |
| return None |
| metadata = read_metadata(target_file_list) |
| if not set(metadata["target"].tolist()).issubset(set(CLASS_LABELS)): |
| return None |
| return metric(pred, metadata, target_array(metadata)) |
|
|
|
|
def label_table(table: dict[tuple[int, int], int], base_name: str, advisor_name: str) -> list[dict]:
    """Render the pair table with class labels instead of class indices.

    Produces one row per table entry, ordered by the ``(base, advisor)``
    index pair. Each row names the base and advisor labels under
    ``base_name`` / ``advisor_name`` and gives the resolved output label.
    """
    return [
        {
            "prediction_pair": {
                base_name: INDEX_TO_LABEL[int(pair[0])],
                advisor_name: INDEX_TO_LABEL[int(pair[1])],
            },
            "output": INDEX_TO_LABEL[int(table[pair])],
        }
        for pair in sorted(table)
    ]
|
|
|
|
def full_output_table(table: dict[tuple[int, int], int], base_name: str, advisor_name: str) -> list[dict]:
    """List the resolved output for every (base, advisor) label combination.

    Pairs present in ``table`` take the table's output and are tagged
    ``"pair_table"``; every other pair keeps the base prediction and is
    tagged ``"base_prediction"``. Rows follow the order of ``CLASS_LABELS``,
    base label outermost.
    """
    expanded = []
    for base_idx, base_label in enumerate(CLASS_LABELS):
        for advisor_idx, advisor_label in enumerate(CLASS_LABELS):
            pair = (base_idx, advisor_idx)
            overridden = pair in table
            resolved = int(table[pair]) if overridden else base_idx
            expanded.append({
                "prediction_pair": {
                    base_name: base_label,
                    advisor_name: advisor_label,
                },
                "output": INDEX_TO_LABEL[resolved],
                "source": "pair_table" if overridden else "base_prediction",
            })
    return expanded
|
|
|
|
def capped_pair_resolver_predict(
    predictions: np.ndarray,
    table: dict[tuple[int, int], int],
    base_idx: int,
    advisor_idx: int,
    max_change_rate_per_base_class: float,
) -> tuple[np.ndarray, dict[str, int]]:
    """Apply the pair table, optionally limiting changes per base class.

    ``predictions`` is indexed as ``[model, sample]``. With a non-positive
    ``max_change_rate_per_base_class`` the result is exactly what
    ``pair_resolver_predict`` returns. With a positive rate, each base class
    may have at most ``max(1, floor(rate * n))`` of its samples changed,
    where ``n`` is the number of samples whose base prediction is that
    class; candidates are taken in ascending sample index.

    Note that in capped mode the output for a sample therefore depends on
    the class counts within ``predictions`` and on sample order, not only on
    that sample's own prediction pair.

    Returns:
        The resolved predictions and a ``{"old_label->new_label": count}``
        mapping of the changes made relative to the base predictions.
    """
    transitions: Counter = Counter()

    if max_change_rate_per_base_class <= 0.0:
        # Uncapped: delegate to the plain resolver and just tally the diffs.
        resolved = pair_resolver_predict(predictions, table, base_idx=base_idx, advisor_idx=advisor_idx)
        originals = predictions[base_idx].tolist()
        for before, after in zip(originals, resolved.tolist()):
            if int(before) == int(after):
                continue
            transitions[f"{INDEX_TO_LABEL[int(before)]}->{INDEX_TO_LABEL[int(after)]}"] += 1
        return resolved, dict(transitions)

    base_row = predictions[base_idx]
    result = base_row.copy()

    # Collect, per base class, the samples the table would actually change.
    pending: dict[int, list[tuple[int, int]]] = {}
    for position in range(predictions.shape[1]):
        current = int(base_row[position])
        pair = (current, int(predictions[advisor_idx, position]))
        if pair in table:
            replacement = int(table[pair])
            if replacement != current:
                pending.setdefault(current, []).append((position, replacement))

    for base_class in sorted(pending):
        rows = pending[base_class]
        class_size = int((base_row == base_class).sum())
        limit = int(math.floor(float(max_change_rate_per_base_class) * class_size))
        if class_size > 0:
            # A non-empty class is always allowed at least one change.
            limit = max(1, limit)
        limit = min(limit, len(rows))
        for position, replacement in rows[:limit]:
            result[position] = replacement
            transitions[f"{INDEX_TO_LABEL[int(base_class)]}->{INDEX_TO_LABEL[int(replacement)]}"] += 1
    return result, dict(transitions)
|
|
|
|
def main() -> None:
    """Build the pair table on the internal split, apply it, and write outputs.

    Steps:
      1. Parse the command line and the base/advisor pair specifications.
      2. Read both models' internal predictions and build the pair table
         from them and the internal targets.
      3. Read both models' final predictions for the target filenames and
         resolve them with ``capped_pair_resolver_predict``.
      4. Write the resolved predictions, a ``<output>.manifest.json`` file
         describing the run, and print a short JSON summary.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--internal_metadata", type=Path, required=True)
    parser.add_argument("--target_file_list", type=Path)
    parser.add_argument("--output", type=Path, required=True)
    parser.add_argument("--base_pair", required=True)
    parser.add_argument("--advisor_pair", required=True)
    parser.add_argument("--alpha", type=float, default=0.05)
    parser.add_argument("--min_count", type=int, default=2)
    parser.add_argument("--confidence", type=float, default=0.4)
    parser.add_argument("--max_change_rate_per_base_class", type=float, default=0.0)
    parser.add_argument("--skip_base_noop_rules", action="store_true")
    args = parser.parse_args()

    base_pair = parse_pair(args.base_pair)
    advisor_pair = parse_pair(args.advisor_pair)
    pairs = [base_pair, advisor_pair]  # row 0 = base, row 1 = advisor

    internal_metadata = read_metadata(args.internal_metadata.resolve())
    internal_target = target_array(internal_metadata)
    internal_filenames = internal_metadata["filename"].tolist()
    internal_predictions = np.stack([
        read_prediction(Path(pair["internal_path"]).resolve(), internal_filenames)
        for pair in pairs
    ], axis=0)

    table = build_pair_table(
        internal_predictions,
        internal_target,
        base_idx=0,
        advisor_idx=1,
        min_count=args.min_count,
        confidence=args.confidence,
        alpha=args.alpha,
    )
    if args.skip_base_noop_rules:
        # Drop rules whose output equals the base prediction anyway.
        table = {key: value for key, value in table.items() if int(value) != int(key[0])}

    # Resolved once and reused for reading filenames, metrics and the manifest.
    target_file_list = args.target_file_list.resolve() if args.target_file_list else None

    target_filenames = read_target_filenames(
        target_file_list,
        Path(base_pair["final_path"]).resolve(),
    )
    target_predictions = np.stack([
        read_prediction(Path(pair["final_path"]).resolve(), target_filenames)
        for pair in pairs
    ], axis=0)
    pred, change_stats = capped_pair_resolver_predict(
        target_predictions,
        table,
        base_idx=0,
        advisor_idx=1,
        max_change_rate_per_base_class=args.max_change_rate_per_base_class,
    )

    args.output.parent.mkdir(parents=True, exist_ok=True)
    write_predictions(args.output, target_filenames, pred)
    target_metrics = maybe_target_metrics(target_file_list, pred)

    # Mirrors the branch in capped_pair_resolver_predict: only a non-positive
    # rate yields the plain table lookup. A positive rate caps changes using
    # per-class counts over the target predictions and the sample order, so
    # the compliance flags below must not claim per-sample independence then.
    uncapped = args.max_change_rate_per_base_class <= 0.0

    manifest = {
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "method": "prediction_pair_resolver_apply_only",
        "output": str(args.output.resolve()),
        "base_pair": {
            "name": base_pair["name"],
            "internal_path": str(Path(base_pair["internal_path"]).resolve()),
            "final_path": str(Path(base_pair["final_path"]).resolve()),
        },
        "advisor_pair": {
            "name": advisor_pair["name"],
            "internal_path": str(Path(advisor_pair["internal_path"]).resolve()),
            "final_path": str(Path(advisor_pair["final_path"]).resolve()),
        },
        "selection": {
            "method": "pair_resolver",
            "alpha": args.alpha,
            "min_count": args.min_count,
            "confidence": args.confidence,
            "max_change_rate_per_base_class": args.max_change_rate_per_base_class,
            "skip_base_noop_rules": args.skip_base_noop_rules,
            "base_idx": 0,
            "advisor_idx": 1,
        },
        "table_size": len(table),
        "change_count": int(sum(change_stats.values())),
        "change_stats": change_stats,
        "table": label_table(table, base_pair["name"], advisor_pair["name"]),
        "full_output_table": full_output_table(table, base_pair["name"], advisor_pair["name"]),
        "target_file_list": str(target_file_list) if target_file_list else None,
        "target_metrics_if_labeled": target_metrics,
        "compliance": {
            "selection_metadata": str(args.internal_metadata.resolve()),
            "target_labels_used": False,
            "target_set_statistics_used": not uncapped,
            "per_sample_independent_decisions": uncapped,
            "external_data_used": False,
        },
    }
    # e.g. "preds.tsv" -> "preds.tsv.manifest.json"
    manifest_path = args.output.with_suffix(args.output.suffix + ".manifest.json")
    manifest_path.write_text(manifest_json(manifest) + "\n", encoding="utf-8")
    print(manifest_json({
        "output": str(args.output.resolve()),
        "manifest": str(manifest_path.resolve()),
        "table_size": len(table),
        "target_metrics_if_labeled": target_metrics,
    }))
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|