from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, Set, Tuple

TARGETS = ["balance_sheet", "profit_and_loss", "cash_flow"]
SCOPES = ["consolidated", "standalone"]


def load_json(p: Path):
    with open(p, "r", encoding="utf-8") as fh:
        return json.load(fh)


def to_set_pages(obj) -> Set[int]:
    """Normalize a GT or predicted pages value into a set of ints."""
    if obj is None:
        return set()
    if isinstance(obj, (int, float)):
        return {int(obj)}
    if isinstance(obj, str):
        return {int(obj)} if obj.isdigit() else set()
    if isinstance(obj, (list, tuple, set)):
        return {
            int(x)
            for x in obj
            if isinstance(x, (int, float)) or (isinstance(x, str) and x.isdigit())
        }
    # Fallback: attempt to parse any other iterable.
    try:
        return {int(x) for x in obj}
    except Exception:
        return set()
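

# Illustrative behavior of to_set_pages (a quick sketch, not exhaustive):
#   to_set_pages(7)             -> {7}
#   to_set_pages("12")          -> {12}
#   to_set_pages([3, "4", "x"]) -> {3, 4}   # non-numeric "x" is silently dropped
#   to_set_pages(None)          -> set()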


def jaccard(a: Set[int], b: Set[int]) -> float:
    # Two empty sets count as a perfect match: nothing to find, nothing found.
    if not a and not b:
        return 1.0
    union = len(a | b)
    return len(a & b) / union if union > 0 else 0.0
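

# Example: jaccard({10, 11, 12}, {11, 12, 13}) == 2 / 4 == 0.5
# (two pages overlap out of four distinct pages overall).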


def precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
    return p, r, f1
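

# Worked example: with TP=8, FP=2, FN=4 this gives
#   P = 8/10 = 0.800, R = 8/12 ≈ 0.667, F1 = 2*0.800*0.667/(0.800+0.667) ≈ 0.727.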


def evaluate_file(gt_path: Path, pred_path: Path) -> Dict:
    """Score one prediction file against its ground-truth (GT) file."""
    gt = load_json(gt_path)
    pred = load_json(pred_path)
    per_stmt_scores = {}
    # Page-level confusion counts, aggregated by (statement, scope).
    counts = {
        (stmt, scope): {"tp": 0, "fp": 0, "fn": 0}
        for stmt in TARGETS for scope in SCOPES
    }
    for stmt in TARGETS:
        # GT sometimes uses the key synonym 'pnl' for profit_and_loss.
        raw_gt = gt.get(stmt)
        if raw_gt is None and stmt == "profit_and_loss":
            raw_gt = gt.get("pnl")
        # Normalize GT scopes -> page sets.
        gt_scopes: Dict[str, Set[int]] = {}
        if isinstance(raw_gt, dict):
            for scope in SCOPES:
                if scope in raw_gt and raw_gt[scope]:
                    gt_scopes[scope] = to_set_pages(raw_gt[scope])
        elif isinstance(raw_gt, list):
            # GT given as a bare list (no scope): treat it as 'consolidated'.
            gt_scopes["consolidated"] = to_set_pages(raw_gt)

        # Predictions: a list of predicted blocks per statement.
        pred_blocks = pred.get(stmt) or []
        pred_by_scope: Dict[str, Set[int]] = {
            "consolidated": set(), "standalone": set(), "unknown": set()
        }
        for b in pred_blocks:
            if not isinstance(b, dict):
                continue
            scope = (b.get("scope") or "unknown").lower()
            # Prefer an explicit 'pages' list; otherwise expand the
            # 'start_page'..'end_page' range.
            pages = to_set_pages(b.get("pages") or [])
            if not pages:
                sp = b.get("start_page")
                ep = b.get("end_page")
                if isinstance(sp, int) and isinstance(ep, int):
                    pages = set(range(sp, ep + 1))
            pred_by_scope.setdefault(scope, set())
            pred_by_scope[scope] |= pages
        pred_any_scope = set().union(*pred_by_scope.values())
        # Scoring logic per statement.
        stmt_scores = []
        if gt_scopes:
            if all(s in gt_scopes for s in SCOPES):
                # GT has both scopes: score each separately and average.
                for scope in SCOPES:
                    gt_pages = gt_scopes.get(scope, set())
                    pred_pages = pred_by_scope.get(scope, set())
                    stmt_scores.append(jaccard(gt_pages, pred_pages))
                    # Update page-level TP/FP/FN counts.
                    counts[(stmt, scope)]["tp"] += len(gt_pages & pred_pages)
                    counts[(stmt, scope)]["fp"] += len(pred_pages - gt_pages)
                    counts[(stmt, scope)]["fn"] += len(gt_pages - pred_pages)
            else:
                # Single scope in GT: compare its pages against all predicted
                # pages, scope-agnostically, and attribute the counts to the
                # GT scope.
                gt_scope = next(iter(gt_scopes))
                gt_pages = gt_scopes[gt_scope]
                pred_pages = pred_any_scope
                stmt_scores.append(jaccard(gt_pages, pred_pages))
                counts[(stmt, gt_scope)]["tp"] += len(gt_pages & pred_pages)
                counts[(stmt, gt_scope)]["fp"] += len(pred_pages - gt_pages)
                counts[(stmt, gt_scope)]["fn"] += len(gt_pages - pred_pages)
        else:
            # No GT for this statement: nothing to predict, so score it as a
            # neutral 1.0, but penalize any predicted pages as false positives
            # (counted under 'consolidated' by convention).
            if pred_any_scope:
                counts[(stmt, "consolidated")]["fp"] += len(pred_any_scope)
            stmt_scores.append(1.0)
        per_stmt_scores[stmt] = sum(stmt_scores) / max(1, len(stmt_scores))
    return {
        "gt_path": str(gt_path),
        "pred_path": str(pred_path),
        "per_stmt_scores": per_stmt_scores,
        "counts": counts,
    }
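

# Input shapes this evaluator expects, inferred from the parsing logic above
# (illustrative sketch only; real files may carry extra keys):
#
# GT file, e.g. dataset/eval/GTs/<doc>.json:
#   {"balance_sheet": {"consolidated": [12, 13], "standalone": [45]},
#    "pnl": [14, 15],        # 'pnl' is accepted as profit_and_loss;
#    "cash_flow": {...}}     # a bare list is treated as consolidated
#
# Prediction file, e.g. dataset/eval/classifier_output/<doc>.json:
#   {"balance_sheet": [
#        {"scope": "consolidated", "pages": [12, 13]},
#        {"scope": "standalone", "start_page": 45, "end_page": 45}]}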


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--split", default="eval", help="Which split folder under dataset/ to use (default: eval)")
    args = ap.parse_args()

    base = Path("./dataset")
    split = base / args.split
    gt_dir = split / "GTs"
    pred_dir = split / "classifier_output"
    if not gt_dir.exists():
        raise FileNotFoundError(f"GTs dir not found: {gt_dir}")
    if not pred_dir.exists():
        raise FileNotFoundError(f"Predictions dir not found: {pred_dir}")

    gt_files = sorted(p for p in gt_dir.iterdir() if p.suffix.lower() == ".json")
    if not gt_files:
        print("No GT files found.")
        return
    total_counts = {
        (stmt, scope): {"tp": 0, "fp": 0, "fn": 0}
        for stmt in TARGETS for scope in SCOPES
    }
    per_file_scores = []
    for gt_p in gt_files:
        stem = gt_p.stem
        pred_p = pred_dir / f"{stem}.json"
        if not pred_p.exists():
            print(f"WARN: prediction missing for {stem}, skipping")
            continue
        res = evaluate_file(gt_p, pred_p)
        per_file_scores.append((stem, res["per_stmt_scores"]))
        # Accumulate page-level counts across files.
        for k, v in res["counts"].items():
            total_counts[k]["tp"] += v["tp"]
            total_counts[k]["fp"] += v["fp"]
            total_counts[k]["fn"] += v["fn"]
        # Per-file breakdown.
        print(f"\nFile: {stem}")
        for stmt, score in res["per_stmt_scores"].items():
            print(f"  {stmt}: Jaccard={score:.3f}")
    # Aggregate metrics.
    print("\n=== Aggregate metrics ===")
    for stmt in TARGETS:
        for scope in SCOPES:
            tp = total_counts[(stmt, scope)]["tp"]
            fp = total_counts[(stmt, scope)]["fp"]
            fn = total_counts[(stmt, scope)]["fn"]
            p, r, f1 = precision_recall_f1(tp, fp, fn)
            print(f"{stmt}/{scope}: TP={tp} FP={fp} FN={fn} P={p:.3f} R={r:.3f} F1={f1:.3f}")

    # Mean Jaccard across files and statements.
    all_scores = [per[stmt] for _, per in per_file_scores for stmt in TARGETS if stmt in per]
    mean_jaccard = sum(all_scores) / len(all_scores) if all_scores else 0.0
    print(f"\nMean per-statement Jaccard (averaged over files and statements): {mean_jaccard:.3f}")


if __name__ == "__main__":
    main()
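

# Example invocation (assuming this script is saved as evaluate.py; the
# filename is illustrative). It expects dataset/<split>/GTs/*.json paired
# by file stem with dataset/<split>/classifier_output/*.json:
#
#   python evaluate.py --split eval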