import argparse
import base64
import io
import json
import os.path
import random
import time
from functools import partial

import evaluate
from PIL import Image
from rapidfuzz.distance import Levenshtein
from tabulate import tabulate
from tqdm import tqdm

from texify.inference import batch_inference
from texify.model.model import load_model
from texify.model.processor import load_processor
from texify.settings import settings


def normalize_text(text):
    # Strip math fences/delimiters so all methods are scored on bare LaTeX.
    # Raw strings: these are literal backslash-bracket delimiters, not escapes.
    text = text.replace("$", "")
    text = text.replace(r"\[", "")
    text = text.replace(r"\]", "")
    text = text.replace(r"\(", "")
    text = text.replace(r"\)", "")
    text = text.strip()
    return text


def score_text(predictions, references):
    # BLEU and METEOR via HF evaluate; edit = mean normalized Levenshtein distance
    bleu = evaluate.load("bleu")
    bleu_results = bleu.compute(predictions=predictions, references=references)

    meteor = evaluate.load("meteor")
    meteor_results = meteor.compute(predictions=predictions, references=references)

    lev_dist = []
    for p, r in zip(predictions, references):
        lev_dist.append(Levenshtein.normalized_distance(p, r))

    return {
        "bleu": bleu_results["bleu"],
        "meteor": meteor_results["meteor"],
        "edit": sum(lev_dist) / len(lev_dist),
    }


def image_to_pil(image):
    # Benchmark images are stored as base64 strings in the source JSON
    decoded = base64.b64decode(image)
    return Image.open(io.BytesIO(decoded))


def load_images(source_data):
    images = [sd["image"] for sd in source_data]
    images = [image_to_pil(image) for image in images]
    return images


def inference_texify(source_data, model, processor):
    images = load_images(source_data)

    write_data = []
    for i in tqdm(range(0, len(images), settings.BATCH_SIZE), desc="Texify inference"):
        batch = images[i:i + settings.BATCH_SIZE]
        text = batch_inference(batch, model, processor)
        for j, t in enumerate(text):
            eq_idx = i + j
            write_data.append({"text": t, "equation": source_data[eq_idx]["equation"]})

    return write_data


def inference_pix2tex(source_data):
    from pix2tex.cli import LatexOCR
    model = LatexOCR()

    images = load_images(source_data)
    write_data = []
    for i in tqdm(range(len(images)), desc="Pix2tex inference"):
        try:
            text = model(images[i])
        except ValueError:
            # Happens when resize fails
            text = ""
        write_data.append({"text": text, "equation": source_data[i]["equation"]})

    return write_data


def image_to_bmp(image):
    img_out = io.BytesIO()
    image.save(img_out, format="BMP")
    return img_out


def inference_nougat(source_data, batch_size=1):
    import torch
    from nougat import NougatModel
    from nougat.postprocessing import markdown_compatible
    from nougat.utils.checkpoint import get_checkpoint
    from nougat.utils.dataset import ImageDataset
    from nougat.utils.device import move_to_device

    # Load images, then convert to bmp format for nougat
    images = load_images(source_data)
    images = [image_to_bmp(image) for image in images]
    predictions = []

    ckpt = get_checkpoint(None, model_tag="0.1.0-small")
    model = NougatModel.from_pretrained(ckpt)
    if settings.TORCH_DEVICE_MODEL != "cpu":
        move_to_device(model, bf16=settings.CUDA, cuda=settings.CUDA)
    model.eval()

    dataset = ImageDataset(
        images,
        partial(model.encoder.prepare_input, random_padding=False),
    )

    # Batch sizes higher than 1 explode memory usage on CPU/MPS
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
    )

    for sample in tqdm(dataloader, desc="Nougat inference", total=len(dataloader)):
        model.config.max_length = settings.MAX_TOKENS
        model_output = model.inference(image_tensors=sample, early_stopping=False)
        output = [markdown_compatible(o) for o in model_output["predictions"]]
        predictions.extend(output)

    return predictions


def main():
    parser = argparse.ArgumentParser(description="Benchmark the performance of texify.")
    parser.add_argument("--data_path", type=str, help="Path to JSON file with source images/equations.", default=os.path.join(settings.DATA_DIR, "bench_data.json"))
    parser.add_argument("--result_path", type=str, help="Path to JSON file to save results to.", default=os.path.join(settings.DATA_DIR, "bench_results.json"))
    parser.add_argument("--max", type=int, help="Maximum number of images to benchmark.", default=None)
    parser.add_argument("--pix2tex", action="store_true", help="Run pix2tex scoring.", default=False)
    parser.add_argument("--nougat", action="store_true", help="Run nougat scoring.", default=False)
    args = parser.parse_args()

    source_path = os.path.abspath(args.data_path)
    result_path = os.path.abspath(args.result_path)
    os.makedirs(os.path.dirname(result_path), exist_ok=True)

    model = load_model()
    processor = load_processor()

    with open(source_path, "r") as f:
        source_data = json.load(f)

    if args.max:
        # Fixed seed so the subsample is reproducible across runs
        random.seed(1)
        source_data = random.sample(source_data, args.max)

    start = time.time()
    predictions = inference_texify(source_data, model, processor)
    times = {"texify": time.time() - start}
    text = [normalize_text(p["text"]) for p in predictions]
    references = [normalize_text(p["equation"]) for p in predictions]

    scores = score_text(text, references)

    write_data = {
        "texify": {
            "scores": scores,
            "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)],
        }
    }

    if args.pix2tex:
        start = time.time()
        predictions = inference_pix2tex(source_data)
        times["pix2tex"] = time.time() - start

        p_text = [normalize_text(p["text"]) for p in predictions]
        p_scores = score_text(p_text, references)

        write_data["pix2tex"] = {
            "scores": p_scores,
            "text": [{"prediction": p, "reference": r} for p, r in zip(p_text, references)],
        }

    if args.nougat:
        start = time.time()
        predictions = inference_nougat(source_data)
        times["nougat"] = time.time() - start

        n_text = [normalize_text(p) for p in predictions]
        n_scores = score_text(n_text, references)

        write_data["nougat"] = {
            "scores": n_scores,
            "text": [{"prediction": p, "reference": r} for p, r in zip(n_text, references)],
        }

    score_table = []
    score_headers = ["bleu", "meteor", "edit"]
    score_dirs = ["⬆", "⬆", "⬇", "⬇"]

    for method in write_data.keys():
        score_table.append([method, *[write_data[method]["scores"][h] for h in score_headers], times[method]])

    score_headers.append("time taken (s)")
    score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
    print()
    print(tabulate(score_table, headers=["Method", *score_headers]))
    print()
    print("Higher is better for BLEU and METEOR, lower is better for edit distance and time taken.")
    print("Note that pix2tex is unbatched (I couldn't find a batch inference method in the docs), so time taken is higher than it should be.")

    with open(result_path, "w") as f:
        json.dump(write_data, f, indent=4)


if __name__ == "__main__":
    main()
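
# Example invocation (a sketch; the script filename is an assumption, but the
# flags are the ones defined above):
#
#   python benchmark.py --max 100 --pix2tex --nougat
#
# This scores texify on a reproducible random sample of 100 benchmark images,
# then runs pix2tex and nougat on the same sample so the metrics are comparable.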