Spaces:
Running
Running
| """CLI entrypoint for ocr-bench.""" | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| import structlog | |
| from rich.console import Console | |
| from rich.table import Table | |
| from ocr_bench.backends import ( | |
| DEFAULT_JUDGE, | |
| DEFAULT_MAX_TOKENS, | |
| aggregate_jury_votes, | |
| parse_judge_spec, | |
| ) | |
| from ocr_bench.dataset import ( | |
| DatasetError, | |
| discover_configs, | |
| discover_pr_configs, | |
| load_config_dataset, | |
| load_flat_dataset, | |
| ) | |
| from ocr_bench.elo import ComparisonResult, Leaderboard, compute_elo, rankings_resolved | |
| from ocr_bench.judge import Comparison, _normalize_pair, build_comparisons, sample_indices | |
| from ocr_bench.publish import ( | |
| EvalMetadata, | |
| load_existing_comparisons, | |
| load_existing_metadata, | |
| publish_results, | |
| ) | |
| logger = structlog.get_logger() | |
| console = Console() | |
def _register_judge(sub) -> None:
    """Attach the `judge` subcommand (pairwise VLM judging) and its flags."""
    judge = sub.add_parser("judge", help="Run pairwise VLM judge on OCR outputs")
    # Dataset selection
    judge.add_argument("dataset", help="HF dataset repo id")
    judge.add_argument("--split", default="train", help="Dataset split (default: train)")
    judge.add_argument("--columns", nargs="+", default=None, help="Explicit OCR column names")
    judge.add_argument(
        "--configs", nargs="+", default=None, help="Config-per-model: list of config names"
    )
    judge.add_argument("--from-prs", action="store_true", help="Force PR-based config discovery")
    judge.add_argument(
        "--merge",
        action="store_true",
        help="Merge PRs to main after discovery (default: load via revision)",
    )
    # Judge model(s); repeating --model assembles a jury
    judge.add_argument(
        "--model",
        action="append",
        dest="models",
        help=f"Judge model spec (repeatable for jury). Default: {DEFAULT_JUDGE}",
    )
    # Evaluation knobs
    judge.add_argument("--max-samples", type=int, default=None, help="Max samples to evaluate")
    judge.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    judge.add_argument(
        "--max-tokens",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help=f"Max tokens for judge response (default: {DEFAULT_MAX_TOKENS})",
    )
    # Output / publishing
    judge.add_argument(
        "--save-results",
        default=None,
        help="HF repo id to publish results to (default: {dataset}-results)",
    )
    judge.add_argument(
        "--no-publish",
        action="store_true",
        help="Don't publish results (default: publish to {dataset}-results)",
    )
    judge.add_argument(
        "--full-rejudge",
        action="store_true",
        help="Re-judge all pairs, ignoring existing comparisons in --save-results repo",
    )
    judge.add_argument(
        "--no-adaptive",
        action="store_true",
        help="Disable adaptive stopping (default: adaptive is on)",
    )
    judge.add_argument(
        "--concurrency",
        type=int,
        default=1,
        help="Number of concurrent judge API calls (default: 1)",
    )


def _register_run(sub) -> None:
    """Attach the `run` subcommand (HF Jobs OCR launches) and its flags."""
    run = sub.add_parser("run", help="Launch OCR models on a dataset via HF Jobs")
    run.add_argument("input_dataset", help="HF dataset repo id with images")
    run.add_argument("output_repo", help="Output dataset repo (all models push here)")
    run.add_argument(
        "--models", nargs="+", default=None, help="Model slugs to run (default: all 4 core)"
    )
    run.add_argument("--max-samples", type=int, default=None, help="Per-model sample limit")
    run.add_argument("--split", default="train", help="Dataset split (default: train)")
    run.add_argument("--flavor", default=None, help="Override GPU flavor for all models")
    run.add_argument("--timeout", default="4h", help="Per-job timeout (default: 4h)")
    run.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    run.add_argument("--shuffle", action="store_true", help="Shuffle source dataset")
    run.add_argument("--list-models", action="store_true", help="Print available models and exit")
    run.add_argument(
        "--dry-run", action="store_true", help="Show what would launch without launching"
    )
    run.add_argument(
        "--no-wait", action="store_true", help="Launch and exit without polling (default: wait)"
    )


def _register_view(sub) -> None:
    """Attach the `view` subcommand (web results browser) and its flags."""
    view = sub.add_parser("view", help="Browse and validate results in a web UI")
    view.add_argument("results", help="HF dataset repo id with published results")
    view.add_argument("--port", type=int, default=7860, help="Port (default: 7860)")
    view.add_argument("--host", default="127.0.0.1", help="Host (default: 127.0.0.1)")
    view.add_argument("--output", default=None, help="Path to save annotations JSON")


def build_parser() -> argparse.ArgumentParser:
    """Build the top-level argument parser with its three subcommands."""
    parser = argparse.ArgumentParser(
        prog="ocr-bench",
        description="OCR model evaluation toolkit — VLM-as-judge with per-dataset leaderboards",
    )
    sub = parser.add_subparsers(dest="command")
    # Registration order controls help-text ordering: judge, run, view.
    _register_judge(sub)
    _register_run(sub)
    _register_view(sub)
    return parser
def print_leaderboard(board: Leaderboard) -> None:
    """Render *board* to the console as a Rich table, best rank first."""
    show_ci = bool(board.elo_ci)
    table = Table(title="OCR Model Leaderboard")
    table.add_column("Rank", style="bold")
    table.add_column("Model")
    # ELO column title reflects whether confidence intervals are available.
    table.add_column("ELO (95% CI)" if show_ci else "ELO", justify="right")
    for heading in ("Wins", "Losses", "Ties", "Win%"):
        table.add_column(heading, justify="right")
    for position, (name, rating) in enumerate(board.ranked, start=1):
        ci = board.elo_ci.get(name) if show_ci else None
        if ci is not None:
            low, high = ci
            rating_str = f"{round(rating)} ({round(low)}\u2013{round(high)})"
        else:
            rating_str = str(round(rating))
        win_rate = board.win_pct(name)
        rate_str = f"{win_rate:.0f}%" if win_rate is not None else "-"
        table.add_row(
            str(position),
            name,
            rating_str,
            str(board.wins[name]),
            str(board.losses[name]),
            str(board.ties[name]),
            rate_str,
        )
    console.print(table)
def _convert_results(
    comparisons: list[Comparison], aggregated: list[dict]
) -> list[ComparisonResult]:
    """Pair each comparison with its judge verdict and build ComparisonResults.

    Entries whose aggregated verdict is falsy (e.g. the judge produced no
    parseable output) are dropped rather than recorded.
    """
    return [
        ComparisonResult(
            sample_idx=comp.sample_idx,
            model_a=comp.model_a,
            model_b=comp.model_b,
            winner=verdict.get("winner", "tie"),
            reason=verdict.get("reason", ""),
            agreement=verdict.get("agreement", "1/1"),
            swapped=comp.swapped,
            text_a=comp.text_a,
            text_b=comp.text_b,
            col_a=comp.col_a,
            col_b=comp.col_b,
        )
        for comp, verdict in zip(comparisons, aggregated)
        if verdict
    ]
| def _resolve_results_repo(dataset: str, save_results: str | None, no_publish: bool) -> str | None: | |
| """Derive the results repo id. Returns None if publishing is disabled.""" | |
| if no_publish: | |
| return None | |
| if save_results: | |
| return save_results | |
| return f"{dataset}-results" | |
def cmd_judge(args: argparse.Namespace) -> None:
    """Orchestrate the judge pipeline: load → compare → judge → elo → print → publish.

    Dataset loading cascades: explicit --configs, then explicit --columns,
    then forced PR discovery (--from-prs), then auto-detection combining PR
    configs with main-branch configs, finally flat loading. Supports an
    incremental mode that skips already-judged model pairs, and an adaptive
    mode that stops once ELO rankings have converged.
    """
    # --- Resolve flags ---
    adaptive = not args.no_adaptive
    merge = args.merge
    results_repo = _resolve_results_repo(args.dataset, args.save_results, args.no_publish)
    from_prs = False  # track whether any configs came from PRs, for metadata
    if results_repo:
        console.print(f"Results will be published to [bold]{results_repo}[/bold]")
    # --- Load dataset (cascading auto-detection) ---
    if args.configs:
        # Explicit configs — use them directly
        config_names = args.configs
        ds, ocr_columns = load_config_dataset(args.dataset, config_names, split=args.split)
    elif args.columns:
        # Explicit columns — flat loading
        ds, ocr_columns = load_flat_dataset(args.dataset, split=args.split, columns=args.columns)
    elif args.from_prs:
        # Forced PR discovery
        config_names, pr_revisions = discover_pr_configs(args.dataset, merge=merge)
        if not config_names:
            raise DatasetError("No configs found in open PRs")
        from_prs = True
        console.print(f"Discovered {len(config_names)} configs from PRs: {config_names}")
        ds, ocr_columns = load_config_dataset(
            args.dataset,
            config_names,
            split=args.split,
            # After a merge the configs live on main, so no revisions are needed.
            pr_revisions=pr_revisions if not merge else None,
        )
    else:
        # Auto-detect: PRs + main branch configs combined, fall back to flat
        pr_configs, pr_revisions = discover_pr_configs(args.dataset, merge=merge)
        main_configs = discover_configs(args.dataset)
        # Combine: PR configs + main configs not already in PRs
        config_names = list(pr_configs)
        for mc in main_configs:
            if mc not in pr_configs:
                config_names.append(mc)
        if config_names:
            if pr_configs:
                from_prs = True
                console.print(f"Auto-detected {len(pr_configs)} configs from PRs: {pr_configs}")
            if main_configs:
                main_only = [c for c in main_configs if c not in pr_configs]
                if main_only:
                    console.print(f"Auto-detected {len(main_only)} configs on main: {main_only}")
            ds, ocr_columns = load_config_dataset(
                args.dataset,
                config_names,
                split=args.split,
                pr_revisions=pr_revisions if pr_configs else None,
            )
        else:
            # No configs anywhere — fall back to flat loading
            ds, ocr_columns = load_flat_dataset(args.dataset, split=args.split)
    console.print(f"Loaded {len(ds)} samples with {len(ocr_columns)} models:")
    for col, model in ocr_columns.items():
        console.print(f" {col} → {model}")
    # --- Incremental: load existing comparisons ---
    existing_results: list[ComparisonResult] = []
    existing_meta_rows: list[dict] = []
    # skip_pairs stays None (judge everything) unless prior results exist.
    skip_pairs: set[tuple[str, str]] | None = None
    if results_repo and not args.full_rejudge:
        existing_results = load_existing_comparisons(results_repo)
        if existing_results:
            # Normalize (a, b) ordering so pair identity is order-independent.
            judged_pairs = {_normalize_pair(r.model_a, r.model_b) for r in existing_results}
            skip_pairs = judged_pairs
            console.print(
                f"\nIncremental mode: {len(existing_results)} existing comparisons "
                f"across {len(judged_pairs)} model pairs — skipping those."
            )
            existing_meta_rows = load_existing_metadata(results_repo)
        else:
            console.print("\nNo existing comparisons found — full judge run.")
    model_names = list(set(ocr_columns.values()))
    # --- Judge setup (shared by both paths) ---
    model_specs = args.models or [DEFAULT_JUDGE]
    judges = [
        parse_judge_spec(spec, max_tokens=args.max_tokens, concurrency=args.concurrency)
        for spec in model_specs
    ]
    is_jury = len(judges) > 1

    def _judge_batch(batch_comps: list[Comparison]) -> list[ComparisonResult]:
        """Run judge(s) on a batch of comparisons and return ComparisonResults.

        Closes over `judges` and `is_jury`; with multiple judges the per-judge
        verdicts are aggregated by majority vote before conversion.
        """
        all_judge_outputs: list[list[dict]] = []
        for judge in judges:
            results = judge.judge(batch_comps)
            all_judge_outputs.append(results)
        if is_jury:
            judge_names = [j.name for j in judges]
            aggregated = aggregate_jury_votes(all_judge_outputs, judge_names)
        else:
            aggregated = all_judge_outputs[0]
        return _convert_results(batch_comps, aggregated)

    if adaptive:
        # --- Adaptive stopping: batch-by-batch with convergence check ---
        from itertools import combinations as _combs

        all_indices = sample_indices(len(ds), args.max_samples, args.seed)
        n_pairs = len(list(_combs(model_names, 2)))
        batch_samples = 5  # samples judged per batch before re-checking
        # Require a minimum evidence base before testing for convergence.
        min_before_check = max(3 * n_pairs, 20)
        if is_jury:
            console.print(f"\nJury mode: {len(judges)} judges")
        console.print(
            f"\n[bold]Adaptive mode[/bold]: {len(all_indices)} samples, "
            f"{n_pairs} pairs, batch size {batch_samples}, "
            f"checking after {min_before_check} comparisons"
        )
        new_results: list[ComparisonResult] = []
        total_comparisons = 0
        for batch_num, batch_start in enumerate(range(0, len(all_indices), batch_samples)):
            batch_indices = all_indices[batch_start : batch_start + batch_samples]
            batch_comps = build_comparisons(
                ds,
                ocr_columns,
                skip_pairs=skip_pairs,
                indices=batch_indices,
                seed=args.seed,
            )
            if not batch_comps:
                continue
            batch_results = _judge_batch(batch_comps)
            new_results.extend(batch_results)
            total_comparisons += len(batch_comps)
            # batch_comps goes out of scope → GC can free images
            total = len(existing_results) + len(new_results)
            console.print(f" Batch {batch_num + 1}: {len(batch_results)} new, {total} total")
            if total >= min_before_check:
                # Refit ELO over everything seen so far (old + new).
                board = compute_elo(existing_results + new_results, model_names)
                # Show CI gaps for each adjacent pair
                ranked = board.ranked
                if board.elo_ci:
                    gaps: list[str] = []
                    for i in range(len(ranked) - 1):
                        hi_model, _ = ranked[i]
                        lo_model, _ = ranked[i + 1]
                        hi_ci = board.elo_ci.get(hi_model)
                        lo_ci = board.elo_ci.get(lo_model)
                        if hi_ci and lo_ci:
                            gap = hi_ci[0] - lo_ci[1]  # positive = resolved
                            if gap > 0:
                                status = "[green]ok[/green]"
                            else:
                                status = f"[yellow]overlap {-gap:.0f}[/yellow]"
                            gaps.append(f" {hi_model} vs {lo_model}: gap={gap:+.0f} {status}")
                    if gaps:
                        console.print(" CI gaps:")
                        for g in gaps:
                            console.print(g)
                if rankings_resolved(board):
                    # Remaining samples we no longer need to judge.
                    remaining = len(all_indices) - batch_start - len(batch_indices)
                    console.print(
                        f"[green]Rankings converged after {total} comparisons! "
                        f"Skipped ~{remaining * n_pairs} remaining.[/green]"
                    )
                    break
        console.print(f"\n{len(new_results)}/{total_comparisons} valid comparisons")
    else:
        # --- Standard single-pass flow ---
        comparisons = build_comparisons(
            ds,
            ocr_columns,
            max_samples=args.max_samples,
            seed=args.seed,
            skip_pairs=skip_pairs,
        )
        console.print(f"\nBuilt {len(comparisons)} new pairwise comparisons")
        if not comparisons and not existing_results:
            console.print(
                "[yellow]No valid comparisons — check that OCR columns have text.[/yellow]"
            )
            return
        if not comparisons:
            # Everything already judged: refit + republish from existing data only.
            console.print("[green]All pairs already judged — refitting leaderboard.[/green]")
            board = compute_elo(existing_results, model_names)
            console.print()
            print_leaderboard(board)
            if results_repo:
                metadata = EvalMetadata(
                    source_dataset=args.dataset,
                    judge_models=[],
                    seed=args.seed,
                    max_samples=args.max_samples or len(ds),
                    total_comparisons=0,
                    valid_comparisons=0,
                    from_prs=from_prs,
                )
                publish_results(
                    results_repo,
                    board,
                    metadata,
                    existing_metadata=existing_meta_rows,
                )
                console.print(f"\nResults published to [bold]{results_repo}[/bold]")
            return
        if is_jury:
            console.print(f"\nJury mode: {len(judges)} judges")
        # Announce all judges up front; _judge_batch runs them sequentially.
        for judge in judges:
            console.print(f"\nRunning judge: {judge.name}")
        new_results = _judge_batch(comparisons)
        total_comparisons = len(comparisons)
        console.print(f"\n{len(new_results)}/{total_comparisons} valid comparisons")
    # --- Merge existing + new, compute ELO ---
    all_results = existing_results + new_results
    board = compute_elo(all_results, model_names)
    console.print()
    print_leaderboard(board)
    # --- Publish ---
    if results_repo:
        metadata = EvalMetadata(
            source_dataset=args.dataset,
            judge_models=[j.name for j in judges],
            seed=args.seed,
            max_samples=args.max_samples or len(ds),
            total_comparisons=total_comparisons,
            valid_comparisons=len(new_results),
            from_prs=from_prs,
        )
        publish_results(results_repo, board, metadata, existing_metadata=existing_meta_rows)
        console.print(f"\nResults published to [bold]{results_repo}[/bold]")
def cmd_run(args: argparse.Namespace) -> None:
    """Launch OCR models on a dataset via HF Jobs.

    Short-circuits on --list-models (print registry and exit) and --dry-run
    (show planned jobs without launching); otherwise launches one job per
    selected model and, unless --no-wait is given, polls until completion.
    """
    # NOTE: imported inside the function — presumably to keep the jobs
    # dependencies off the `judge`/`view` code paths; confirm if relocating.
    from ocr_bench.run import (
        DEFAULT_MODELS,
        MODEL_REGISTRY,
        build_script_args,
        launch_ocr_jobs,
        poll_jobs,
    )
    # --list-models: print the registry table and exit without launching.
    if args.list_models:
        table = Table(title="Available OCR Models", show_lines=True)
        table.add_column("Slug", style="cyan bold")
        table.add_column("Model ID")
        table.add_column("Size", justify="right")
        table.add_column("Default GPU", justify="center")
        for slug in sorted(MODEL_REGISTRY):
            cfg = MODEL_REGISTRY[slug]
            default = " (default)" if slug in DEFAULT_MODELS else ""
            table.add_row(slug + default, cfg.model_id, cfg.size, cfg.default_flavor)
        console.print(table)
        console.print(f"\nDefault set: {', '.join(DEFAULT_MODELS)}")
        return
    # Validate requested slugs before doing anything expensive.
    selected = args.models or DEFAULT_MODELS
    for slug in selected:
        if slug not in MODEL_REGISTRY:
            console.print(f"[red]Unknown model: {slug}[/red]")
            console.print(f"Available: {', '.join(MODEL_REGISTRY.keys())}")
            sys.exit(1)
    console.print("\n[bold]OCR Benchmark Run[/bold]")
    console.print(f" Source: {args.input_dataset}")
    console.print(f" Output: {args.output_repo}")
    console.print(f" Models: {', '.join(selected)}")
    if args.max_samples:
        console.print(f" Samples: {args.max_samples} per model")
    console.print()
    # Dry run: show per-model flavor/timeout/script/args, then exit.
    if args.dry_run:
        console.print("[bold yellow]DRY RUN[/bold yellow] — no jobs will be launched\n")
        for slug in selected:
            cfg = MODEL_REGISTRY[slug]
            # Per-model default flavor unless overridden globally via --flavor.
            flavor = args.flavor or cfg.default_flavor
            script_args = build_script_args(
                args.input_dataset,
                args.output_repo,
                slug,
                max_samples=args.max_samples,
                shuffle=args.shuffle,
                seed=args.seed,
                extra_args=cfg.default_args or None,
            )
            console.print(f"[cyan]{slug}[/cyan] ({cfg.model_id})")
            console.print(f" Flavor: {flavor}")
            console.print(f" Timeout: {args.timeout}")
            console.print(f" Script: {cfg.script}")
            console.print(f" Args: {' '.join(script_args)}")
            console.print()
        console.print("Remove --dry-run to launch these jobs.")
        return
    # Launch one HF Job per selected model.
    jobs = launch_ocr_jobs(
        args.input_dataset,
        args.output_repo,
        models=selected,
        max_samples=args.max_samples,
        split=args.split,
        shuffle=args.shuffle,
        seed=args.seed,
        flavor_override=args.flavor,
        timeout=args.timeout,
    )
    console.print(f"\n[green]{len(jobs)} jobs launched.[/green]")
    for job in jobs:
        console.print(f" [cyan]{job.model_slug}[/cyan]: {job.job_url}")
    if not args.no_wait:
        # Block until all jobs finish, then print the follow-up judge command.
        console.print("\n[bold]Waiting for jobs to complete...[/bold]")
        poll_jobs(jobs)
        console.print("\n[bold green]All jobs finished![/bold green]")
        console.print("\nEvaluate:")
        console.print(f" ocr-bench judge {args.output_repo}")
    else:
        console.print("\nJobs running in background.")
        console.print("Check status at: https://huggingface.co/settings/jobs")
        console.print(f"When complete: ocr-bench judge {args.output_repo}")
def cmd_view(args: argparse.Namespace) -> None:
    """Launch the FastAPI + HTMX results viewer."""
    # The viewer stack is an optional extra — import lazily and exit with
    # an install hint if it is missing.
    try:
        import uvicorn

        from ocr_bench.web import create_app
    except ImportError:
        console.print(
            "[red]Error:[/red] FastAPI/uvicorn not installed. "
            "Install the viewer extra: [bold]pip install ocr-bench\\[viewer][/bold]"
        )
        sys.exit(1)
    console.print(f"Loading results from [bold]{args.results}[/bold]...")
    viewer_app = create_app(args.results, output_path=args.output)
    console.print(f"Starting viewer at [bold]http://{args.host}:{args.port}[/bold]")
    uvicorn.run(viewer_app, host=args.host, port=args.port)
def main() -> None:
    """Parse CLI arguments and dispatch to the selected subcommand."""
    parser = build_parser()
    args = parser.parse_args()
    if args.command is None:
        # No subcommand: show help and exit successfully.
        parser.print_help()
        sys.exit(0)
    dispatch = {
        "judge": cmd_judge,
        "run": cmd_run,
        "view": cmd_view,
    }
    try:
        handler = dispatch.get(args.command)
        if handler is not None:
            handler(args)
    except DatasetError as exc:
        # Dataset problems are expected user-facing errors, not tracebacks.
        console.print(f"[red]Error:[/red] {exc}")
        sys.exit(1)