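"""Benchmark texify against pix2tex and nougat.

Reads a JSON list of {"image": <base64-encoded image>, "equation": <LaTeX>}
records, runs inference with each selected model, and reports BLEU, METEOR,
and normalized edit distance plus wall-clock time.

Example invocation (the script name here is assumed; adjust to match this
file's actual name):

    python benchmark.py --max 100 --pix2tex --nougat
"""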
import argparse
import base64
import io
import json
import os.path
import random
import time
from functools import partial

import evaluate
from PIL import Image
from rapidfuzz.distance import Levenshtein
from tabulate import tabulate
from tqdm import tqdm

from texify.inference import batch_inference
from texify.model.model import load_model
from texify.model.processor import load_processor
from texify.settings import settings

def normalize_text(text):
    # Strip math delimiters ($, \[, \], \(, \)) so predictions and references
    # are compared on LaTeX content rather than fencing style. Backslashes are
    # doubled to avoid invalid-escape SyntaxWarnings in newer Python.
    text = text.replace("$", "")
    text = text.replace("\\[", "")
    text = text.replace("\\]", "")
    text = text.replace("\\(", "")
    text = text.replace("\\)", "")
    return text.strip()
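
# Score predictions against references: BLEU and METEOR via the `evaluate`
# library (higher is better), plus the mean normalized Levenshtein edit
# distance from rapidfuzz (lower is better).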
def score_text(predictions, references):
    bleu = evaluate.load("bleu")
    bleu_results = bleu.compute(predictions=predictions, references=references)

    meteor = evaluate.load("meteor")
    meteor_results = meteor.compute(predictions=predictions, references=references)

    lev_dist = []
    for p, r in zip(predictions, references):
        lev_dist.append(Levenshtein.normalized_distance(p, r))

    return {
        "bleu": bleu_results["bleu"],
        "meteor": meteor_results["meteor"],
        "edit": sum(lev_dist) / len(lev_dist),
    }
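
# The benchmark JSON stores each image as a base64-encoded string; decode them
# back into PIL images before inference.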
def image_to_pil(image):
    decoded = base64.b64decode(image)
    return Image.open(io.BytesIO(decoded))


def load_images(source_data):
    images = [sd["image"] for sd in source_data]
    return [image_to_pil(image) for image in images]
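
# Run texify over the images in batches of settings.BATCH_SIZE, pairing each
# prediction with its reference equation.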
def inference_texify(source_data, model, processor):
    images = load_images(source_data)

    write_data = []
    for i in tqdm(range(0, len(images), settings.BATCH_SIZE), desc="Texify inference"):
        batch = images[i:i + settings.BATCH_SIZE]
        text = batch_inference(batch, model, processor)
        for j, t in enumerate(text):
            eq_idx = i + j
            write_data.append({"text": t, "equation": source_data[eq_idx]["equation"]})

    return write_data
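
# pix2tex is imported lazily so the dependency is only required when --pix2tex
# is passed. Its LatexOCR interface exposes no batch method (see the note
# printed at the end of main), so images are processed one at a time.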
def inference_pix2tex(source_data):
    from pix2tex.cli import LatexOCR
    model = LatexOCR()

    images = load_images(source_data)
    write_data = []
    for i in tqdm(range(len(images)), desc="Pix2tex inference"):
        try:
            text = model(images[i])
        except ValueError:
            # Happens when resize fails
            text = ""
        write_data.append({"text": text, "equation": source_data[i]["equation"]})

    return write_data

def image_to_bmp(image):
    img_out = io.BytesIO()
    image.save(img_out, format="BMP")
    return img_out
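
# nougat is also imported lazily (only needed with --nougat). The default
# batch_size of 1 matters: larger batches explode memory usage on CPU/MPS
# (see the comment above the DataLoader below).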
def inference_nougat(source_data, batch_size=1):
    import torch
    from nougat import NougatModel
    from nougat.postprocessing import markdown_compatible
    from nougat.utils.checkpoint import get_checkpoint
    from nougat.utils.dataset import ImageDataset
    from nougat.utils.device import move_to_device

    # Load images, then convert to BMP format for nougat
    images = load_images(source_data)
    images = [image_to_bmp(image) for image in images]

    ckpt = get_checkpoint(None, model_tag="0.1.0-small")
    model = NougatModel.from_pretrained(ckpt)
    if settings.TORCH_DEVICE_MODEL != "cpu":
        move_to_device(model, bf16=settings.CUDA, cuda=settings.CUDA)
    model.eval()
    model.config.max_length = settings.MAX_TOKENS

    dataset = ImageDataset(
        images,
        partial(model.encoder.prepare_input, random_padding=False),
    )

    # Batch sizes higher than 1 explode memory usage on CPU/MPS
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
    )

    predictions = []
    for sample in tqdm(dataloader, desc="Nougat inference", total=len(dataloader)):
        model_output = model.inference(image_tensors=sample, early_stopping=False)
        output = [markdown_compatible(o) for o in model_output["predictions"]]
        predictions.extend(output)

    return predictions
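
# CLI entry point. texify always runs; pix2tex and nougat scoring are opt-in
# via flags since they pull in extra dependencies.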
def main():
    parser = argparse.ArgumentParser(description="Benchmark the performance of texify.")
    parser.add_argument("--data_path", type=str, help="Path to JSON file with source images/equations.", default=os.path.join(settings.DATA_DIR, "bench_data.json"))
    parser.add_argument("--result_path", type=str, help="Path to JSON file to save results to.", default=os.path.join(settings.DATA_DIR, "bench_results.json"))
    parser.add_argument("--max", type=int, help="Maximum number of images to benchmark.", default=None)
    parser.add_argument("--pix2tex", action="store_true", help="Run pix2tex scoring.", default=False)
    parser.add_argument("--nougat", action="store_true", help="Run nougat scoring.", default=False)
    args = parser.parse_args()

    source_path = os.path.abspath(args.data_path)
    result_path = os.path.abspath(args.result_path)
    os.makedirs(os.path.dirname(result_path), exist_ok=True)

    model = load_model()
    processor = load_processor()

    with open(source_path, "r") as f:
        source_data = json.load(f)

    if args.max:
        random.seed(1)
        source_data = random.sample(source_data, args.max)

    start = time.time()
    predictions = inference_texify(source_data, model, processor)
    times = {"texify": time.time() - start}
    text = [normalize_text(p["text"]) for p in predictions]
    references = [normalize_text(p["equation"]) for p in predictions]

    scores = score_text(text, references)

    write_data = {
        "texify": {
            "scores": scores,
            "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)]
        }
    }

    if args.pix2tex:
        start = time.time()
        predictions = inference_pix2tex(source_data)
        times["pix2tex"] = time.time() - start

        p_text = [normalize_text(p["text"]) for p in predictions]
        p_scores = score_text(p_text, references)

        write_data["pix2tex"] = {
            "scores": p_scores,
            "text": [{"prediction": p, "reference": r} for p, r in zip(p_text, references)]
        }

    if args.nougat:
        start = time.time()
        predictions = inference_nougat(source_data)
        times["nougat"] = time.time() - start

        n_text = [normalize_text(p) for p in predictions]
        n_scores = score_text(n_text, references)

        write_data["nougat"] = {
            "scores": n_scores,
            "text": [{"prediction": p, "reference": r} for p, r in zip(n_text, references)]
        }

    score_table = []
    score_headers = ["bleu", "meteor", "edit"]
    score_dirs = ["⬆", "⬆", "⬇", "⬇"]

    for method in write_data.keys():
        score_table.append([method, *[write_data[method]["scores"][h] for h in score_headers], times[method]])

    score_headers.append("time taken (s)")
    score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]

    print()
    print(tabulate(score_table, headers=["Method", *score_headers]))
    print()
    print("Higher is better for BLEU and METEOR; lower is better for edit distance and time taken.")
    print("Note that pix2tex is unbatched (I couldn't find a batch inference method in the docs), so its time taken is higher than it should be.")

    with open(result_path, "w") as f:
        json.dump(write_data, f, indent=4)


if __name__ == "__main__":
    main()