| | """CLI entrypoint for ocr-bench.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import sys |
| |
|
| | import structlog |
| | from rich.console import Console |
| | from rich.table import Table |
| |
|
| | from ocr_bench.backends import ( |
| | DEFAULT_JUDGE, |
| | DEFAULT_MAX_TOKENS, |
| | aggregate_jury_votes, |
| | parse_judge_spec, |
| | ) |
| | from ocr_bench.dataset import ( |
| | DatasetError, |
| | discover_configs, |
| | discover_pr_configs, |
| | load_config_dataset, |
| | load_flat_dataset, |
| | ) |
| | from ocr_bench.elo import ComparisonResult, Leaderboard, compute_elo, rankings_resolved |
| | from ocr_bench.judge import Comparison, _normalize_pair, build_comparisons, sample_indices |
| | from ocr_bench.publish import ( |
| | EvalMetadata, |
| | load_existing_comparisons, |
| | load_existing_metadata, |
| | publish_results, |
| | ) |
| |
|
# Shared Rich console: every command writes its user-facing output through it.
console = Console()
# Module-level structlog logger.
logger = structlog.get_logger()
| |
|
| |
|
def _add_split_option(parser: argparse.ArgumentParser) -> None:
    """Add the shared ``--split`` option (default ``train``) to *parser*."""
    parser.add_argument("--split", default="train", help="Dataset split (default: train)")


def _add_seed_option(parser: argparse.ArgumentParser) -> None:
    """Add the shared ``--seed`` option (int, default 42) to *parser*."""
    parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")


def _add_judge_command(subparsers) -> None:
    """Register the ``judge`` sub-command on *subparsers* (an argparse subparsers action)."""
    judge = subparsers.add_parser("judge", help="Run pairwise VLM judge on OCR outputs")
    add = judge.add_argument

    # Where the OCR outputs come from.
    add("dataset", help="HF dataset repo id")
    _add_split_option(judge)
    add("--columns", nargs="+", default=None, help="Explicit OCR column names")
    add("--configs", nargs="+", default=None, help="Config-per-model: list of config names")
    add("--from-prs", action="store_true", help="Force PR-based config discovery")
    add(
        "--merge",
        action="store_true",
        help="Merge PRs to main after discovery (default: load via revision)",
    )

    # Which judge model(s) to use; repeating --model collects into args.models.
    add(
        "--model",
        action="append",
        dest="models",
        help=f"Judge model spec (repeatable for jury). Default: {DEFAULT_JUDGE}",
    )

    # Sampling and generation limits.
    add("--max-samples", type=int, default=None, help="Max samples to evaluate")
    _add_seed_option(judge)
    add(
        "--max-tokens",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help=f"Max tokens for judge response (default: {DEFAULT_MAX_TOKENS})",
    )

    # Publishing, incremental re-use and execution control.
    add(
        "--save-results",
        default=None,
        help="HF repo id to publish results to (default: {dataset}-results)",
    )
    add(
        "--no-publish",
        action="store_true",
        help="Don't publish results (default: publish to {dataset}-results)",
    )
    add(
        "--full-rejudge",
        action="store_true",
        help="Re-judge all pairs, ignoring existing comparisons in --save-results repo",
    )
    add(
        "--no-adaptive",
        action="store_true",
        help="Disable adaptive stopping (default: adaptive is on)",
    )
    add(
        "--concurrency",
        type=int,
        default=1,
        help="Number of concurrent judge API calls (default: 1)",
    )


def _add_run_command(subparsers) -> None:
    """Register the ``run`` sub-command on *subparsers* (an argparse subparsers action)."""
    run = subparsers.add_parser("run", help="Launch OCR models on a dataset via HF Jobs")
    add = run.add_argument

    add("input_dataset", help="HF dataset repo id with images")
    add("output_repo", help="Output dataset repo (all models push here)")
    add("--models", nargs="+", default=None, help="Model slugs to run (default: all 4 core)")
    add("--max-samples", type=int, default=None, help="Per-model sample limit")
    _add_split_option(run)
    add("--flavor", default=None, help="Override GPU flavor for all models")
    add("--timeout", default="4h", help="Per-job timeout (default: 4h)")
    _add_seed_option(run)
    add("--shuffle", action="store_true", help="Shuffle source dataset")
    add("--list-models", action="store_true", help="Print available models and exit")
    add("--dry-run", action="store_true", help="Show what would launch without launching")
    add(
        "--no-wait",
        action="store_true",
        help="Launch and exit without polling (default: wait)",
    )


def _add_view_command(subparsers) -> None:
    """Register the ``view`` sub-command on *subparsers* (an argparse subparsers action)."""
    view = subparsers.add_parser("view", help="Browse and validate results in a web UI")
    add = view.add_argument

    add("results", help="HF dataset repo id with published results")
    add("--port", type=int, default=7860, help="Port (default: 7860)")
    add("--host", default="127.0.0.1", help="Host (default: 127.0.0.1)")
    add("--output", default=None, help="Path to save annotations JSON")


def build_parser() -> argparse.ArgumentParser:
    """Build the ``ocr-bench`` argument parser.

    The chosen sub-command name is stored in ``args.command``; it stays ``None``
    when no sub-command is given. Sub-commands are registered in the order
    judge, run, view, which is the order they appear in ``--help``.
    """
    parser = argparse.ArgumentParser(
        prog="ocr-bench",
        description="OCR model evaluation toolkit — VLM-as-judge with per-dataset leaderboards",
    )
    subparsers = parser.add_subparsers(dest="command")
    for register in (_add_judge_command, _add_run_command, _add_view_command):
        register(subparsers)
    return parser
| |
|
| |
|
def print_leaderboard(board: Leaderboard) -> None:
    """Render *board* on the console as a Rich table with one row per ranked model.

    When ``board.elo_ci`` is non-empty, the ELO column header mentions the 95% CI.
    Models that have a CI entry are shown as ``elo (lo–hi)``. A missing win
    percentage is rendered as ``-``.
    """
    show_ci = bool(board.elo_ci)

    table = Table(title="OCR Model Leaderboard")
    table.add_column("Rank", style="bold")
    table.add_column("Model")
    table.add_column("ELO (95% CI)" if show_ci else "ELO", justify="right")
    for header in ("Wins", "Losses", "Ties", "Win%"):
        table.add_column(header, justify="right")

    for position, (name, rating) in enumerate(board.ranked, start=1):
        win_pct = board.win_pct(name)
        pct_cell = "-" if win_pct is None else f"{win_pct:.0f}%"

        if show_ci and name in board.elo_ci:
            low, high = board.elo_ci[name]
            elo_cell = f"{round(rating)} ({round(low)}\u2013{round(high)})"
        else:
            elo_cell = str(round(rating))

        row = (
            str(position),
            name,
            elo_cell,
            str(board.wins[name]),
            str(board.losses[name]),
            str(board.ties[name]),
            pct_cell,
        )
        table.add_row(*row)

    console.print(table)
| |
|
| |
|
| | def _convert_results( |
| | comparisons: list[Comparison], aggregated: list[dict] |
| | ) -> list[ComparisonResult]: |
| | """Convert judged comparisons + aggregated outputs into ComparisonResult list.""" |
| | results: list[ComparisonResult] = [] |
| | for comp, result in zip(comparisons, aggregated): |
| | if not result: |
| | continue |
| | results.append( |
| | ComparisonResult( |
| | sample_idx=comp.sample_idx, |
| | model_a=comp.model_a, |
| | model_b=comp.model_b, |
| | winner=result.get("winner", "tie"), |
| | reason=result.get("reason", ""), |
| | agreement=result.get("agreement", "1/1"), |
| | swapped=comp.swapped, |
| | text_a=comp.text_a, |
| | text_b=comp.text_b, |
| | col_a=comp.col_a, |
| | col_b=comp.col_b, |
| | ) |
| | ) |
| | return results |
| |
|
| |
|
| | def _resolve_results_repo(dataset: str, save_results: str | None, no_publish: bool) -> str | None: |
| | """Derive the results repo id. Returns None if publishing is disabled.""" |
| | if no_publish: |
| | return None |
| | if save_results: |
| | return save_results |
| | return f"{dataset}-results" |
| |
|
| |
|
def _load_judge_dataset(args: argparse.Namespace) -> tuple[object, dict, bool]:
    """Load the OCR outputs to judge, choosing the source from the CLI flags.

    Precedence is ``--configs`` > ``--columns`` > ``--from-prs`` > auto-detection.
    Auto-detection combines configs found in open PRs with configs on main. A PR
    config shadows a main config with the same name. When no configs exist at
    all, it falls back to a flat dataset.

    ``args.merge`` is forwarded to ``discover_pr_configs``. When it is set, configs
    are loaded without PR revisions in both the ``--from-prs`` and auto-detect
    paths. The ``--merge`` help text says merged PRs are loaded from main,
    not via revision.

    Returns:
        ``(ds, ocr_columns, from_prs)``: the loaded dataset, the mapping of OCR
        column name -> model name, and whether any config came from an open PR.

    Raises:
        DatasetError: ``--from-prs`` was given but no PR configs were found.
    """
    merge = args.merge

    if args.configs:
        ds, ocr_columns = load_config_dataset(args.dataset, args.configs, split=args.split)
        return ds, ocr_columns, False

    if args.columns:
        ds, ocr_columns = load_flat_dataset(args.dataset, split=args.split, columns=args.columns)
        return ds, ocr_columns, False

    if args.from_prs:
        config_names, pr_revisions = discover_pr_configs(args.dataset, merge=merge)
        if not config_names:
            raise DatasetError("No configs found in open PRs")
        console.print(f"Discovered {len(config_names)} configs from PRs: {config_names}")
        ds, ocr_columns = load_config_dataset(
            args.dataset,
            config_names,
            split=args.split,
            pr_revisions=None if merge else pr_revisions,
        )
        return ds, ocr_columns, True

    # Auto-detect: PR configs first, then main-only configs (PR versions win).
    pr_configs, pr_revisions = discover_pr_configs(args.dataset, merge=merge)
    main_configs = discover_configs(args.dataset)
    main_only = [c for c in main_configs if c not in pr_configs]
    config_names = list(pr_configs) + main_only

    if not config_names:
        # No configs anywhere: treat the dataset as flat (one column per model).
        ds, ocr_columns = load_flat_dataset(args.dataset, split=args.split)
        return ds, ocr_columns, False

    if pr_configs:
        console.print(f"Auto-detected {len(pr_configs)} configs from PRs: {pr_configs}")
    if main_only:
        console.print(f"Auto-detected {len(main_only)} configs on main: {main_only}")
    ds, ocr_columns = load_config_dataset(
        args.dataset,
        config_names,
        split=args.split,
        # Same rule as --from-prs: once merged, load from main instead of PR refs.
        pr_revisions=pr_revisions if pr_configs and not merge else None,
    )
    return ds, ocr_columns, bool(pr_configs)


def _load_incremental_state(
    results_repo: str | None, full_rejudge: bool
) -> tuple[list[ComparisonResult], list[dict], set[tuple[str, str]] | None]:
    """Load previously published comparisons so already-judged model pairs are skipped.

    Returns:
        ``(existing_results, existing_meta_rows, skip_pairs)``. Nothing is loaded
        (``[], [], None``) when publishing is disabled or *full_rejudge* is set.
        The metadata rows are fetched only when prior comparisons exist.
    """
    if not results_repo or full_rejudge:
        return [], [], None

    existing_results = load_existing_comparisons(results_repo)
    if not existing_results:
        console.print("\nNo existing comparisons found — full judge run.")
        return [], [], None

    judged_pairs = {_normalize_pair(r.model_a, r.model_b) for r in existing_results}
    console.print(
        f"\nIncremental mode: {len(existing_results)} existing comparisons "
        f"across {len(judged_pairs)} model pairs — skipping those."
    )
    return existing_results, load_existing_metadata(results_repo), judged_pairs


def _print_ci_gaps(board: Leaderboard) -> None:
    """Print the CI separation between each pair of adjacently ranked models.

    ``gap`` is the higher model's lower CI bound minus the lower model's upper
    bound. A positive gap means the intervals are disjoint ("ok"); otherwise the
    overlap size is shown. Prints nothing when no adjacent pair has CIs.
    """
    if not board.elo_ci:
        return
    ranked = board.ranked
    gaps: list[str] = []
    for (hi_model, _), (lo_model, _) in zip(ranked, ranked[1:]):
        hi_ci = board.elo_ci.get(hi_model)
        lo_ci = board.elo_ci.get(lo_model)
        if not (hi_ci and lo_ci):
            continue
        gap = hi_ci[0] - lo_ci[1]
        if gap > 0:
            status = "[green]ok[/green]"
        else:
            status = f"[yellow]overlap {-gap:.0f}[/yellow]"
        gaps.append(f" {hi_model} vs {lo_model}: gap={gap:+.0f} {status}")
    if gaps:
        console.print(" CI gaps:")
        for line in gaps:
            console.print(line)


def cmd_judge(args: argparse.Namespace) -> None:
    """Orchestrate: load → compare → judge → elo → print → publish.

    Steps:
      1. Resolve the results repo (``None`` with ``--no-publish``).
      2. Load the dataset and its OCR columns (``_load_judge_dataset``).
      3. Unless ``--full-rejudge``, load comparisons already in the results repo
         and skip model pairs judged before (``_load_incremental_state``).
      4. Judge the remaining pairs. In adaptive mode this runs in batches of 5
         samples and stops once ``rankings_resolved`` reports convergence;
         otherwise everything is judged in one pass.
      5. Fit ELO over old + new results, print the leaderboard and publish.

    The two modes end the same way. With no comparisons and no prior results,
    or when every judge call failed and there are no prior results, it returns
    without fitting or publishing. When every pair was already judged, the
    stored results are refit and published with an empty judge list.
    """
    adaptive = not args.no_adaptive
    results_repo = _resolve_results_repo(args.dataset, args.save_results, args.no_publish)

    if results_repo:
        console.print(f"Results will be published to [bold]{results_repo}[/bold]")

    ds, ocr_columns, from_prs = _load_judge_dataset(args)

    console.print(f"Loaded {len(ds)} samples with {len(ocr_columns)} models:")
    for col, model in ocr_columns.items():
        console.print(f" {col} → {model}")

    existing_results, existing_meta_rows, skip_pairs = _load_incremental_state(
        results_repo, args.full_rejudge
    )

    # sorted() rather than list(set(...)): set order of str varies between runs
    # (hash randomisation), which would make model order, and anything fitted
    # from it, non-reproducible even with a fixed --seed.
    model_names = sorted(set(ocr_columns.values()))

    model_specs = args.models or [DEFAULT_JUDGE]
    judges = [
        parse_judge_spec(spec, max_tokens=args.max_tokens, concurrency=args.concurrency)
        for spec in model_specs
    ]
    is_jury = len(judges) > 1

    def _judge_batch(batch_comps: list[Comparison]) -> list[ComparisonResult]:
        """Run every judge on *batch_comps*, aggregate jury votes, convert results."""
        all_judge_outputs: list[list[dict]] = [judge.judge(batch_comps) for judge in judges]
        if is_jury:
            aggregated = aggregate_jury_votes(all_judge_outputs, [j.name for j in judges])
        else:
            aggregated = all_judge_outputs[0]
        return _convert_results(batch_comps, aggregated)

    new_results: list[ComparisonResult] = []
    total_comparisons = 0

    if adaptive:
        all_indices = sample_indices(len(ds), args.max_samples, args.seed)
        n_models = len(model_names)
        n_pairs = n_models * (n_models - 1) // 2  # == len(combinations(model_names, 2))
        batch_samples = 5
        # Don't test for convergence until there are roughly 3 comparisons per pair
        # (existing results included), and never before 20.
        min_before_check = max(3 * n_pairs, 20)

        if is_jury:
            console.print(f"\nJury mode: {len(judges)} judges")
        console.print(
            f"\n[bold]Adaptive mode[/bold]: {len(all_indices)} samples, "
            f"{n_pairs} pairs, batch size {batch_samples}, "
            f"checking after {min_before_check} comparisons"
        )

        for batch_num, batch_start in enumerate(range(0, len(all_indices), batch_samples)):
            batch_indices = all_indices[batch_start : batch_start + batch_samples]
            # NOTE(review): the same seed is passed for every batch — confirm
            # build_comparisons does not repeat one swap pattern per batch.
            batch_comps = build_comparisons(
                ds,
                ocr_columns,
                skip_pairs=skip_pairs,
                indices=batch_indices,
                seed=args.seed,
            )
            if not batch_comps:
                continue

            batch_results = _judge_batch(batch_comps)
            new_results.extend(batch_results)
            total_comparisons += len(batch_comps)

            total = len(existing_results) + len(new_results)
            console.print(f" Batch {batch_num + 1}: {len(batch_results)} new, {total} total")

            if total < min_before_check:
                continue
            board = compute_elo(existing_results + new_results, model_names)
            _print_ci_gaps(board)
            if rankings_resolved(board):
                remaining = len(all_indices) - batch_start - len(batch_indices)
                console.print(
                    f"[green]Rankings converged after {total} comparisons! "
                    f"Skipped ~{remaining * n_pairs} remaining.[/green]"
                )
                break

        console.print(f"\n{len(new_results)}/{total_comparisons} valid comparisons")
    else:
        comparisons = build_comparisons(
            ds,
            ocr_columns,
            max_samples=args.max_samples,
            seed=args.seed,
            skip_pairs=skip_pairs,
        )
        console.print(f"\nBuilt {len(comparisons)} new pairwise comparisons")
        total_comparisons = len(comparisons)

        if comparisons:
            if is_jury:
                console.print(f"\nJury mode: {len(judges)} judges")
            for judge in judges:
                console.print(f"\nRunning judge: {judge.name}")
            new_results = _judge_batch(comparisons)
            console.print(f"\n{len(new_results)}/{total_comparisons} valid comparisons")

    # Shared end-game for both modes (adaptive previously had no such guards).
    judge_models = [j.name for j in judges]
    if total_comparisons == 0:
        if not existing_results:
            console.print(
                "[yellow]No valid comparisons — check that OCR columns have text.[/yellow]"
            )
            return
        # Nothing new was judged: refit stored results and record no judges for this run.
        console.print("[green]All pairs already judged — refitting leaderboard.[/green]")
        judge_models = []
    elif not new_results and not existing_results:
        # Comparisons were sent but every judge output was empty: don't fit or
        # publish an empty leaderboard.
        console.print("[yellow]No valid judge results — nothing to rank or publish.[/yellow]")
        return

    board = compute_elo(existing_results + new_results, model_names)
    console.print()
    print_leaderboard(board)

    if results_repo:
        metadata = EvalMetadata(
            source_dataset=args.dataset,
            judge_models=judge_models,
            seed=args.seed,
            max_samples=args.max_samples or len(ds),
            total_comparisons=total_comparisons,
            valid_comparisons=len(new_results),
            from_prs=from_prs,
        )
        publish_results(results_repo, board, metadata, existing_metadata=existing_meta_rows)
        console.print(f"\nResults published to [bold]{results_repo}[/bold]")
| |
|
| |
|
def cmd_run(args: argparse.Namespace) -> None:
    """Launch the selected OCR models on ``args.input_dataset`` as HF Jobs.

    Modes, checked in this order:
      * ``--list-models``: print the model registry and return.
      * any unknown slug in ``--models``: report the first one and exit with status 1.
      * ``--dry-run``: print the per-model launch plan and return without launching.
      * otherwise launch the jobs and, unless ``--no-wait``, poll until they finish.
    """
    from ocr_bench.run import (
        DEFAULT_MODELS,
        MODEL_REGISTRY,
        build_script_args,
        launch_ocr_jobs,
        poll_jobs,
    )

    if args.list_models:
        registry_table = Table(title="Available OCR Models", show_lines=True)
        registry_table.add_column("Slug", style="cyan bold")
        registry_table.add_column("Model ID")
        registry_table.add_column("Size", justify="right")
        registry_table.add_column("Default GPU", justify="center")
        for model_slug in sorted(MODEL_REGISTRY):
            spec = MODEL_REGISTRY[model_slug]
            marker = " (default)" if model_slug in DEFAULT_MODELS else ""
            registry_table.add_row(
                model_slug + marker, spec.model_id, spec.size, spec.default_flavor
            )
        console.print(registry_table)
        console.print(f"\nDefault set: {', '.join(DEFAULT_MODELS)}")
        return

    chosen = args.models or DEFAULT_MODELS
    unknown = [s for s in chosen if s not in MODEL_REGISTRY]
    if unknown:
        # Only the first unknown slug is reported.
        console.print(f"[red]Unknown model: {unknown[0]}[/red]")
        console.print(f"Available: {', '.join(MODEL_REGISTRY.keys())}")
        sys.exit(1)

    summary = [
        "\n[bold]OCR Benchmark Run[/bold]",
        f" Source: {args.input_dataset}",
        f" Output: {args.output_repo}",
        f" Models: {', '.join(chosen)}",
    ]
    if args.max_samples:
        summary.append(f" Samples: {args.max_samples} per model")
    for line in summary:
        console.print(line)
    console.print()

    if args.dry_run:
        console.print("[bold yellow]DRY RUN[/bold yellow] — no jobs will be launched\n")
        for model_slug in chosen:
            spec = MODEL_REGISTRY[model_slug]
            gpu = args.flavor or spec.default_flavor
            # NOTE(review): --split is passed to launch_ocr_jobs below but not
            # here, so the previewed args may not reflect it — confirm.
            script_args = build_script_args(
                args.input_dataset,
                args.output_repo,
                model_slug,
                max_samples=args.max_samples,
                shuffle=args.shuffle,
                seed=args.seed,
                extra_args=cfg_args if (cfg_args := spec.default_args) else None,
            )
            console.print(f"[cyan]{model_slug}[/cyan] ({spec.model_id})")
            console.print(f" Flavor: {gpu}")
            console.print(f" Timeout: {args.timeout}")
            console.print(f" Script: {spec.script}")
            console.print(f" Args: {' '.join(script_args)}")
            console.print()
        console.print("Remove --dry-run to launch these jobs.")
        return

    jobs = launch_ocr_jobs(
        args.input_dataset,
        args.output_repo,
        models=chosen,
        max_samples=args.max_samples,
        split=args.split,
        shuffle=args.shuffle,
        seed=args.seed,
        flavor_override=args.flavor,
        timeout=args.timeout,
    )

    console.print(f"\n[green]{len(jobs)} jobs launched.[/green]")
    for job in jobs:
        console.print(f" [cyan]{job.model_slug}[/cyan]: {job.job_url}")

    if args.no_wait:
        console.print("\nJobs running in background.")
        console.print("Check status at: https://huggingface.co/settings/jobs")
        console.print(f"When complete: ocr-bench judge {args.output_repo}")
        return

    console.print("\n[bold]Waiting for jobs to complete...[/bold]")
    poll_jobs(jobs)
    console.print("\n[bold green]All jobs finished![/bold green]")
    console.print("\nEvaluate:")
    console.print(f" ocr-bench judge {args.output_repo}")
| |
|
| |
|
def cmd_view(args: argparse.Namespace) -> None:
    """Serve the results viewer for the ``args.results`` repo with uvicorn.

    The viewer imports (uvicorn and ``ocr_bench.web``) are optional. If either
    fails to import, an install hint is printed and the process exits with
    status 1.
    """
    try:
        import uvicorn

        from ocr_bench.web import create_app
    except ImportError:
        console.print(
            "[red]Error:[/red] FastAPI/uvicorn not installed. "
            "Install the viewer extra: [bold]pip install ocr-bench\\[viewer][/bold]"
        )
        sys.exit(1)

    url = f"http://{args.host}:{args.port}"
    console.print(f"Loading results from [bold]{args.results}[/bold]...")
    viewer_app = create_app(args.results, output_path=args.output)
    console.print(f"Starting viewer at [bold]{url}[/bold]")
    uvicorn.run(viewer_app, host=args.host, port=args.port)
| |
|
| |
|
def main() -> None:
    """Console entry point: parse ``sys.argv`` and dispatch to the chosen sub-command.

    With no sub-command, prints help and exits with status 0. A ``DatasetError``
    raised by a sub-command is printed and turned into exit status 1.
    """
    parser = build_parser()
    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(0)

    handlers = {"judge": cmd_judge, "run": cmd_run, "view": cmd_view}
    handler = handlers.get(args.command)
    try:
        if handler is not None:
            handler(args)
    except DatasetError as exc:
        console.print(f"[red]Error:[/red] {exc}")
        sys.exit(1)
| |
|