# Provenance: uploaded to the Hugging Face Hub by davanstrien (HF Staff)
# via huggingface_hub (commit 1118181, verified).
"""CLI entrypoint for ocr-bench."""
from __future__ import annotations
import argparse
import sys
import structlog
from rich.console import Console
from rich.table import Table
from ocr_bench.backends import (
DEFAULT_JUDGE,
DEFAULT_MAX_TOKENS,
aggregate_jury_votes,
parse_judge_spec,
)
from ocr_bench.dataset import (
DatasetError,
discover_configs,
discover_pr_configs,
load_config_dataset,
load_flat_dataset,
)
from ocr_bench.elo import ComparisonResult, Leaderboard, compute_elo, rankings_resolved
from ocr_bench.judge import Comparison, _normalize_pair, build_comparisons, sample_indices
from ocr_bench.publish import (
EvalMetadata,
load_existing_comparisons,
load_existing_metadata,
publish_results,
)
# Module-wide structured logger and Rich console, shared by every subcommand.
logger = structlog.get_logger()
console = Console()
def build_parser() -> argparse.ArgumentParser:
    """Build the root CLI parser with the judge/run/view subcommands.

    Returns:
        A fully configured ``argparse.ArgumentParser``; the chosen
        subcommand is stored in ``args.command``.
    """
    root = argparse.ArgumentParser(
        prog="ocr-bench",
        description="OCR model evaluation toolkit — VLM-as-judge with per-dataset leaderboards",
    )
    subparsers = root.add_subparsers(dest="command")

    # --- judge subcommand: pairwise VLM judging ---
    judge_cmd = subparsers.add_parser("judge", help="Run pairwise VLM judge on OCR outputs")
    # Dataset selection
    judge_cmd.add_argument("dataset", help="HF dataset repo id")
    judge_cmd.add_argument("--split", default="train", help="Dataset split (default: train)")
    judge_cmd.add_argument("--columns", nargs="+", default=None, help="Explicit OCR column names")
    judge_cmd.add_argument(
        "--configs", nargs="+", default=None, help="Config-per-model: list of config names"
    )
    judge_cmd.add_argument("--from-prs", action="store_true", help="Force PR-based config discovery")
    judge_cmd.add_argument(
        "--merge",
        action="store_true",
        help="Merge PRs to main after discovery (default: load via revision)",
    )
    # Judge model selection (repeatable --model builds a jury)
    judge_cmd.add_argument(
        "--model",
        action="append",
        dest="models",
        help=f"Judge model spec (repeatable for jury). Default: {DEFAULT_JUDGE}",
    )
    # Evaluation parameters
    judge_cmd.add_argument("--max-samples", type=int, default=None, help="Max samples to evaluate")
    judge_cmd.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    judge_cmd.add_argument(
        "--max-tokens",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help=f"Max tokens for judge response (default: {DEFAULT_MAX_TOKENS})",
    )
    # Output / publishing behavior
    judge_cmd.add_argument(
        "--save-results",
        default=None,
        help="HF repo id to publish results to (default: {dataset}-results)",
    )
    judge_cmd.add_argument(
        "--no-publish",
        action="store_true",
        help="Don't publish results (default: publish to {dataset}-results)",
    )
    judge_cmd.add_argument(
        "--full-rejudge",
        action="store_true",
        help="Re-judge all pairs, ignoring existing comparisons in --save-results repo",
    )
    judge_cmd.add_argument(
        "--no-adaptive",
        action="store_true",
        help="Disable adaptive stopping (default: adaptive is on)",
    )
    judge_cmd.add_argument(
        "--concurrency",
        type=int,
        default=1,
        help="Number of concurrent judge API calls (default: 1)",
    )

    # --- run subcommand: launch OCR jobs ---
    run_cmd = subparsers.add_parser("run", help="Launch OCR models on a dataset via HF Jobs")
    run_cmd.add_argument("input_dataset", help="HF dataset repo id with images")
    run_cmd.add_argument("output_repo", help="Output dataset repo (all models push here)")
    run_cmd.add_argument(
        "--models", nargs="+", default=None, help="Model slugs to run (default: all 4 core)"
    )
    run_cmd.add_argument("--max-samples", type=int, default=None, help="Per-model sample limit")
    run_cmd.add_argument("--split", default="train", help="Dataset split (default: train)")
    run_cmd.add_argument("--flavor", default=None, help="Override GPU flavor for all models")
    run_cmd.add_argument("--timeout", default="4h", help="Per-job timeout (default: 4h)")
    run_cmd.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    run_cmd.add_argument("--shuffle", action="store_true", help="Shuffle source dataset")
    run_cmd.add_argument("--list-models", action="store_true", help="Print available models and exit")
    run_cmd.add_argument(
        "--dry-run", action="store_true", help="Show what would launch without launching"
    )
    run_cmd.add_argument(
        "--no-wait", action="store_true", help="Launch and exit without polling (default: wait)"
    )

    # --- view subcommand: results browser ---
    view_cmd = subparsers.add_parser("view", help="Browse and validate results in a web UI")
    view_cmd.add_argument("results", help="HF dataset repo id with published results")
    view_cmd.add_argument("--port", type=int, default=7860, help="Port (default: 7860)")
    view_cmd.add_argument("--host", default="127.0.0.1", help="Host (default: 127.0.0.1)")
    view_cmd.add_argument("--output", default=None, help="Path to save annotations JSON")
    return root
def print_leaderboard(board: Leaderboard) -> None:
    """Render the leaderboard as a Rich table on the shared console.

    When the board carries ELO confidence intervals, the ELO column
    shows ``elo (lo–hi)``; otherwise just the rounded ELO.
    """
    has_ci = bool(board.elo_ci)
    table = Table(title="OCR Model Leaderboard")
    table.add_column("Rank", style="bold")
    table.add_column("Model")
    table.add_column("ELO (95% CI)" if has_ci else "ELO", justify="right")
    for header in ("Wins", "Losses", "Ties", "Win%"):
        table.add_column(header, justify="right")

    for rank, (model, elo) in enumerate(board.ranked, start=1):
        ci = board.elo_ci.get(model) if has_ci else None
        if ci is not None:
            lo, hi = ci
            elo_str = f"{round(elo)} ({round(lo)}\u2013{round(hi)})"
        else:
            elo_str = str(round(elo))
        pct = board.win_pct(model)
        table.add_row(
            str(rank),
            model,
            elo_str,
            str(board.wins[model]),
            str(board.losses[model]),
            str(board.ties[model]),
            "-" if pct is None else f"{pct:.0f}%",
        )
    console.print(table)
def _convert_results(
    comparisons: list[Comparison], aggregated: list[dict]
) -> list[ComparisonResult]:
    """Convert judged comparisons + aggregated outputs into ComparisonResult list.

    Pairs each comparison with its judge verdict positionally; entries
    whose verdict is empty/falsy are dropped.
    """
    return [
        ComparisonResult(
            sample_idx=comp.sample_idx,
            model_a=comp.model_a,
            model_b=comp.model_b,
            winner=verdict.get("winner", "tie"),
            reason=verdict.get("reason", ""),
            agreement=verdict.get("agreement", "1/1"),
            swapped=comp.swapped,
            text_a=comp.text_a,
            text_b=comp.text_b,
            col_a=comp.col_a,
            col_b=comp.col_b,
        )
        for comp, verdict in zip(comparisons, aggregated)
        if verdict
    ]
def _resolve_results_repo(dataset: str, save_results: str | None, no_publish: bool) -> str | None:
"""Derive the results repo id. Returns None if publishing is disabled."""
if no_publish:
return None
if save_results:
return save_results
return f"{dataset}-results"
def cmd_judge(args: argparse.Namespace) -> None:
    """Orchestrate: load → compare → judge → elo → print → publish.

    Handles both the adaptive (batch-by-batch with convergence checks)
    and standard (single-pass) judging flows, with incremental support
    for skipping model pairs already judged in a published results repo.
    Raises DatasetError when forced PR discovery finds no configs.
    """
    # --- Resolve flags ---
    adaptive = not args.no_adaptive
    merge = args.merge
    results_repo = _resolve_results_repo(args.dataset, args.save_results, args.no_publish)
    from_prs = False  # track for metadata
    if results_repo:
        console.print(f"Results will be published to [bold]{results_repo}[/bold]")
    # --- Load dataset (cascading auto-detection) ---
    # Priority: explicit --configs > explicit --columns > forced --from-prs
    # > auto-detect (PR configs + main configs, falling back to flat).
    if args.configs:
        # Explicit configs — use them directly
        config_names = args.configs
        ds, ocr_columns = load_config_dataset(args.dataset, config_names, split=args.split)
    elif args.columns:
        # Explicit columns — flat loading
        ds, ocr_columns = load_flat_dataset(args.dataset, split=args.split, columns=args.columns)
    elif args.from_prs:
        # Forced PR discovery
        config_names, pr_revisions = discover_pr_configs(args.dataset, merge=merge)
        if not config_names:
            raise DatasetError("No configs found in open PRs")
        from_prs = True
        console.print(f"Discovered {len(config_names)} configs from PRs: {config_names}")
        ds, ocr_columns = load_config_dataset(
            args.dataset,
            config_names,
            split=args.split,
            # After a merge the PR content lives on main, so no revisions needed.
            pr_revisions=pr_revisions if not merge else None,
        )
    else:
        # Auto-detect: PRs + main branch configs combined, fall back to flat
        pr_configs, pr_revisions = discover_pr_configs(args.dataset, merge=merge)
        main_configs = discover_configs(args.dataset)
        # Combine: PR configs + main configs not already in PRs
        config_names = list(pr_configs)
        for mc in main_configs:
            if mc not in pr_configs:
                config_names.append(mc)
        if config_names:
            if pr_configs:
                from_prs = True
                console.print(f"Auto-detected {len(pr_configs)} configs from PRs: {pr_configs}")
            if main_configs:
                main_only = [c for c in main_configs if c not in pr_configs]
                if main_only:
                    console.print(f"Auto-detected {len(main_only)} configs on main: {main_only}")
            ds, ocr_columns = load_config_dataset(
                args.dataset,
                config_names,
                split=args.split,
                pr_revisions=pr_revisions if pr_configs else None,
            )
        else:
            # No configs anywhere — fall back to flat loading
            ds, ocr_columns = load_flat_dataset(args.dataset, split=args.split)
    console.print(f"Loaded {len(ds)} samples with {len(ocr_columns)} models:")
    for col, model in ocr_columns.items():
        console.print(f" {col}{model}")
    # --- Incremental: load existing comparisons ---
    # Previously-judged model pairs are skipped unless --full-rejudge is set.
    existing_results: list[ComparisonResult] = []
    existing_meta_rows: list[dict] = []
    skip_pairs: set[tuple[str, str]] | None = None
    if results_repo and not args.full_rejudge:
        existing_results = load_existing_comparisons(results_repo)
        if existing_results:
            # Normalize (a, b) ordering so pair membership is order-insensitive.
            judged_pairs = {_normalize_pair(r.model_a, r.model_b) for r in existing_results}
            skip_pairs = judged_pairs
            console.print(
                f"\nIncremental mode: {len(existing_results)} existing comparisons "
                f"across {len(judged_pairs)} model pairs — skipping those."
            )
            existing_meta_rows = load_existing_metadata(results_repo)
        else:
            console.print("\nNo existing comparisons found — full judge run.")
    model_names = list(set(ocr_columns.values()))
    # --- Judge setup (shared by both paths) ---
    model_specs = args.models or [DEFAULT_JUDGE]
    judges = [
        parse_judge_spec(spec, max_tokens=args.max_tokens, concurrency=args.concurrency)
        for spec in model_specs
    ]
    is_jury = len(judges) > 1

    def _judge_batch(batch_comps: list[Comparison]) -> list[ComparisonResult]:
        """Run judge(s) on a batch of comparisons and return ComparisonResults."""
        all_judge_outputs: list[list[dict]] = []
        for judge in judges:
            results = judge.judge(batch_comps)
            all_judge_outputs.append(results)
        if is_jury:
            # Multiple judges: combine per-comparison votes into one verdict.
            judge_names = [j.name for j in judges]
            aggregated = aggregate_jury_votes(all_judge_outputs, judge_names)
        else:
            aggregated = all_judge_outputs[0]
        return _convert_results(batch_comps, aggregated)

    if adaptive:
        # --- Adaptive stopping: batch-by-batch with convergence check ---
        from itertools import combinations as _combs

        all_indices = sample_indices(len(ds), args.max_samples, args.seed)
        n_pairs = len(list(_combs(model_names, 2)))
        batch_samples = 5
        # Require a minimum number of comparisons before testing convergence,
        # so the early ELO estimates aren't judged on tiny samples.
        min_before_check = max(3 * n_pairs, 20)
        if is_jury:
            console.print(f"\nJury mode: {len(judges)} judges")
        console.print(
            f"\n[bold]Adaptive mode[/bold]: {len(all_indices)} samples, "
            f"{n_pairs} pairs, batch size {batch_samples}, "
            f"checking after {min_before_check} comparisons"
        )
        new_results: list[ComparisonResult] = []
        total_comparisons = 0
        for batch_num, batch_start in enumerate(range(0, len(all_indices), batch_samples)):
            batch_indices = all_indices[batch_start : batch_start + batch_samples]
            batch_comps = build_comparisons(
                ds,
                ocr_columns,
                skip_pairs=skip_pairs,
                indices=batch_indices,
                seed=args.seed,
            )
            if not batch_comps:
                continue
            batch_results = _judge_batch(batch_comps)
            new_results.extend(batch_results)
            total_comparisons += len(batch_comps)
            # batch_comps goes out of scope → GC can free images
            total = len(existing_results) + len(new_results)
            console.print(f" Batch {batch_num + 1}: {len(batch_results)} new, {total} total")
            if total >= min_before_check:
                board = compute_elo(existing_results + new_results, model_names)
                # Show CI gaps for each adjacent pair
                ranked = board.ranked
                if board.elo_ci:
                    gaps: list[str] = []
                    for i in range(len(ranked) - 1):
                        hi_model, _ = ranked[i]
                        lo_model, _ = ranked[i + 1]
                        hi_ci = board.elo_ci.get(hi_model)
                        lo_ci = board.elo_ci.get(lo_model)
                        if hi_ci and lo_ci:
                            gap = hi_ci[0] - lo_ci[1]  # positive = resolved
                            if gap > 0:
                                status = "[green]ok[/green]"
                            else:
                                status = f"[yellow]overlap {-gap:.0f}[/yellow]"
                            gaps.append(f" {hi_model} vs {lo_model}: gap={gap:+.0f} {status}")
                    if gaps:
                        console.print(" CI gaps:")
                        for g in gaps:
                            console.print(g)
                if rankings_resolved(board):
                    # Converged: stop early and report roughly how much work was saved.
                    remaining = len(all_indices) - batch_start - len(batch_indices)
                    console.print(
                        f"[green]Rankings converged after {total} comparisons! "
                        f"Skipped ~{remaining * n_pairs} remaining.[/green]"
                    )
                    break
        console.print(f"\n{len(new_results)}/{total_comparisons} valid comparisons")
    else:
        # --- Standard single-pass flow ---
        comparisons = build_comparisons(
            ds,
            ocr_columns,
            max_samples=args.max_samples,
            seed=args.seed,
            skip_pairs=skip_pairs,
        )
        console.print(f"\nBuilt {len(comparisons)} new pairwise comparisons")
        if not comparisons and not existing_results:
            console.print(
                "[yellow]No valid comparisons — check that OCR columns have text.[/yellow]"
            )
            return
        if not comparisons:
            # Nothing new to judge: recompute the leaderboard from existing
            # results only, publish (with empty judge metadata), and exit.
            console.print("[green]All pairs already judged — refitting leaderboard.[/green]")
            board = compute_elo(existing_results, model_names)
            console.print()
            print_leaderboard(board)
            if results_repo:
                metadata = EvalMetadata(
                    source_dataset=args.dataset,
                    judge_models=[],
                    seed=args.seed,
                    max_samples=args.max_samples or len(ds),
                    total_comparisons=0,
                    valid_comparisons=0,
                    from_prs=from_prs,
                )
                publish_results(
                    results_repo,
                    board,
                    metadata,
                    existing_metadata=existing_meta_rows,
                )
                console.print(f"\nResults published to [bold]{results_repo}[/bold]")
            return
        if is_jury:
            console.print(f"\nJury mode: {len(judges)} judges")
            for judge in judges:
                console.print(f"\nRunning judge: {judge.name}")
        new_results = _judge_batch(comparisons)
        total_comparisons = len(comparisons)
        console.print(f"\n{len(new_results)}/{total_comparisons} valid comparisons")
    # --- Merge existing + new, compute ELO ---
    all_results = existing_results + new_results
    board = compute_elo(all_results, model_names)
    console.print()
    print_leaderboard(board)
    # --- Publish ---
    if results_repo:
        metadata = EvalMetadata(
            source_dataset=args.dataset,
            judge_models=[j.name for j in judges],
            seed=args.seed,
            max_samples=args.max_samples or len(ds),
            total_comparisons=total_comparisons,
            valid_comparisons=len(new_results),
            from_prs=from_prs,
        )
        publish_results(results_repo, board, metadata, existing_metadata=existing_meta_rows)
        console.print(f"\nResults published to [bold]{results_repo}[/bold]")
def cmd_run(args: argparse.Namespace) -> None:
    """Launch OCR models on a dataset via HF Jobs.

    Supports three modes: --list-models (print registry and exit),
    --dry-run (show planned jobs without launching), and the real launch
    (optionally polling until completion unless --no-wait). Exits with
    status 1 on an unknown model slug.
    """
    # Imported lazily so `judge`/`view` don't pay for the run-only deps.
    from ocr_bench.run import (
        DEFAULT_MODELS,
        MODEL_REGISTRY,
        build_script_args,
        launch_ocr_jobs,
        poll_jobs,
    )
    # --list-models
    if args.list_models:
        table = Table(title="Available OCR Models", show_lines=True)
        table.add_column("Slug", style="cyan bold")
        table.add_column("Model ID")
        table.add_column("Size", justify="right")
        table.add_column("Default GPU", justify="center")
        for slug in sorted(MODEL_REGISTRY):
            cfg = MODEL_REGISTRY[slug]
            default = " (default)" if slug in DEFAULT_MODELS else ""
            table.add_row(slug + default, cfg.model_id, cfg.size, cfg.default_flavor)
        console.print(table)
        console.print(f"\nDefault set: {', '.join(DEFAULT_MODELS)}")
        return
    # Validate every requested slug before doing any work.
    selected = args.models or DEFAULT_MODELS
    for slug in selected:
        if slug not in MODEL_REGISTRY:
            console.print(f"[red]Unknown model: {slug}[/red]")
            console.print(f"Available: {', '.join(MODEL_REGISTRY.keys())}")
            sys.exit(1)
    console.print("\n[bold]OCR Benchmark Run[/bold]")
    console.print(f" Source: {args.input_dataset}")
    console.print(f" Output: {args.output_repo}")
    console.print(f" Models: {', '.join(selected)}")
    if args.max_samples:
        console.print(f" Samples: {args.max_samples} per model")
    console.print()
    # Dry run: show per-model flavor/timeout/script/args without launching.
    if args.dry_run:
        console.print("[bold yellow]DRY RUN[/bold yellow] — no jobs will be launched\n")
        for slug in selected:
            cfg = MODEL_REGISTRY[slug]
            flavor = args.flavor or cfg.default_flavor
            script_args = build_script_args(
                args.input_dataset,
                args.output_repo,
                slug,
                max_samples=args.max_samples,
                shuffle=args.shuffle,
                seed=args.seed,
                extra_args=cfg.default_args or None,
            )
            console.print(f"[cyan]{slug}[/cyan] ({cfg.model_id})")
            console.print(f" Flavor: {flavor}")
            console.print(f" Timeout: {args.timeout}")
            console.print(f" Script: {cfg.script}")
            console.print(f" Args: {' '.join(script_args)}")
            console.print()
        console.print("Remove --dry-run to launch these jobs.")
        return
    # Launch
    jobs = launch_ocr_jobs(
        args.input_dataset,
        args.output_repo,
        models=selected,
        max_samples=args.max_samples,
        split=args.split,
        shuffle=args.shuffle,
        seed=args.seed,
        flavor_override=args.flavor,
        timeout=args.timeout,
    )
    console.print(f"\n[green]{len(jobs)} jobs launched.[/green]")
    for job in jobs:
        console.print(f" [cyan]{job.model_slug}[/cyan]: {job.job_url}")
    if not args.no_wait:
        # Block until every launched job finishes, then suggest the next step.
        console.print("\n[bold]Waiting for jobs to complete...[/bold]")
        poll_jobs(jobs)
        console.print("\n[bold green]All jobs finished![/bold green]")
        console.print("\nEvaluate:")
        console.print(f" ocr-bench judge {args.output_repo}")
    else:
        console.print("\nJobs running in background.")
        console.print("Check status at: https://huggingface.co/settings/jobs")
        console.print(f"When complete: ocr-bench judge {args.output_repo}")
def cmd_view(args: argparse.Namespace) -> None:
    """Launch the FastAPI + HTMX results viewer.

    Exits with status 1 (and an install hint) when the optional viewer
    dependencies are not installed.
    """
    # The viewer stack is an optional extra — probe for it up front.
    try:
        import uvicorn
        from ocr_bench.web import create_app
    except ImportError:
        console.print(
            "[red]Error:[/red] FastAPI/uvicorn not installed. "
            "Install the viewer extra: [bold]pip install ocr-bench\\[viewer][/bold]"
        )
        sys.exit(1)
    results_repo = args.results
    console.print(f"Loading results from [bold]{results_repo}[/bold]...")
    viewer_app = create_app(results_repo, output_path=args.output)
    bind = f"{args.host}:{args.port}"
    console.print(f"Starting viewer at [bold]http://{bind}[/bold]")
    uvicorn.run(viewer_app, host=args.host, port=args.port)
def main() -> None:
    """CLI entry point: parse arguments and dispatch to a subcommand.

    Prints help and exits 0 when no subcommand is given; converts
    DatasetError into a red console message and exit status 1.
    """
    parser = build_parser()
    args = parser.parse_args()
    # Subparsers restrict args.command to these names (or None).
    dispatch = {"judge": cmd_judge, "run": cmd_run, "view": cmd_view}
    handler = dispatch.get(args.command)
    if handler is None:
        parser.print_help()
        sys.exit(0)
    try:
        handler(args)
    except DatasetError as exc:
        console.print(f"[red]Error:[/red] {exc}")
        sys.exit(1)