Spaces:
Paused
Paused
Delete benchmark.py
Browse files- benchmark.py +0 -226
benchmark.py
DELETED
@@ -1,226 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import os.path
|
3 |
-
import random
|
4 |
-
import time
|
5 |
-
from functools import partial
|
6 |
-
|
7 |
-
import evaluate
|
8 |
-
from tabulate import tabulate
|
9 |
-
from tqdm import tqdm
|
10 |
-
|
11 |
-
from texify.inference import batch_inference
|
12 |
-
from texify.model.model import load_model
|
13 |
-
from texify.model.processor import load_processor
|
14 |
-
from PIL import Image
|
15 |
-
from texify.settings import settings
|
16 |
-
import json
|
17 |
-
import base64
|
18 |
-
import io
|
19 |
-
from rapidfuzz.distance import Levenshtein
|
20 |
-
|
21 |
-
|
22 |
-
def normalize_text(text):
    r"""Remove math fences from *text* and trim surrounding whitespace.

    Every occurrence of ``$``, ``\[``, ``\]``, ``\(`` and ``\)`` is deleted
    (anywhere in the string, not only at the ends), then leading and
    trailing whitespace is stripped.

    Args:
        text: The string to normalize.

    Returns:
        The string with all fences removed and outer whitespace stripped.
    """
    # Raw strings are required here: "\[" and friends are invalid escape
    # sequences in a normal literal and emit a SyntaxWarning on Python 3.12+.
    # The fences are removed one after another, in this order.
    for fence in ("$", r"\[", r"\]", r"\(", r"\)"):
        text = text.replace(fence, "")
    return text.strip()
|
31 |
-
|
32 |
-
|
33 |
-
def score_text(predictions, references):
    """Score predicted strings against their reference strings.

    Loads the ``bleu`` and ``meteor`` metrics through ``evaluate.load`` and
    computes both over the full lists, then averages the normalized
    Levenshtein distance over the prediction/reference pairs.

    Args:
        predictions: Predicted strings.
        references: Reference strings, paired positionally with predictions.

    Returns:
        A dict with the keys ``bleu``, ``meteor`` and ``edit``.
    """
    bleu_metric = evaluate.load("bleu")
    bleu_output = bleu_metric.compute(predictions=predictions, references=references)

    meteor_metric = evaluate.load('meteor')
    meteor_output = meteor_metric.compute(predictions=predictions, references=references)

    # One normalized edit distance per (prediction, reference) pair.
    distances = [
        Levenshtein.normalized_distance(predicted, expected)
        for predicted, expected in zip(predictions, references)
    ]

    return {
        'bleu': bleu_output["bleu"],
        'meteor': meteor_output['meteor'],
        'edit': sum(distances) / len(distances)
    }
|
49 |
-
|
50 |
-
|
51 |
-
def image_to_pil(image):
    """Decode a base64-encoded image and open it as a PIL image.

    Args:
        image: Base64-encoded image data.

    Returns:
        The object returned by ``Image.open`` on the decoded bytes.
    """
    raw_bytes = base64.b64decode(image)
    byte_stream = io.BytesIO(raw_bytes)
    return Image.open(byte_stream)
|
54 |
-
|
55 |
-
|
56 |
-
def load_images(source_data):
    """Open the ``"image"`` field of every entry in *source_data*.

    Args:
        source_data: Iterable of mappings, each with an ``"image"`` key
            holding data accepted by ``image_to_pil``.

    Returns:
        A list of opened images, in the same order as *source_data*.
    """
    # Collect every encoded payload first, then decode them all.
    encoded_images = [entry["image"] for entry in source_data]
    return list(map(image_to_pil, encoded_images))
|
60 |
-
|
61 |
-
|
62 |
-
def inference_texify(source_data, model, processor):
    """Run texify batch inference over every image in *source_data*.

    Images are processed in slices of ``settings.BATCH_SIZE``. Each
    prediction is paired with the ``"equation"`` field of the source entry
    at the same position.

    Args:
        source_data: Sequence of mappings with ``"image"`` and
            ``"equation"`` keys.
        model: Model passed through to ``batch_inference``.
        processor: Processor passed through to ``batch_inference``.

    Returns:
        A list of ``{"text": ..., "equation": ...}`` dicts.
    """
    images = load_images(source_data)

    results = []
    batch_starts = range(0, len(images), settings.BATCH_SIZE)
    for batch_start in tqdm(batch_starts, desc="Texify inference"):
        chunk = images[batch_start:batch_start + settings.BATCH_SIZE]
        outputs = batch_inference(chunk, model, processor)
        # Number the outputs from the batch offset so each index points
        # back at the matching source entry.
        for source_idx, predicted in enumerate(outputs, start=batch_start):
            results.append({
                "text": predicted,
                "equation": source_data[source_idx]["equation"],
            })

    return results
|
74 |
-
|
75 |
-
|
76 |
-
def inference_pix2tex(source_data):
    """Run pix2tex's ``LatexOCR`` over every image, one image at a time.

    Args:
        source_data: Sequence of mappings with ``"image"`` and
            ``"equation"`` keys.

    Returns:
        A list of ``{"text": ..., "equation": ...}`` dicts. ``"text"`` is
        the empty string for any image on which the model raised
        ``ValueError``.
    """
    from pix2tex.cli import LatexOCR
    ocr_model = LatexOCR()

    images = load_images(source_data)
    results = []
    for idx, img in enumerate(tqdm(images, desc="Pix2tex inference")):
        try:
            predicted = ocr_model(img)
        except ValueError:
            # Happens when resize fails
            predicted = ""
        results.append({"text": predicted, "equation": source_data[idx]["equation"]})

    return results
|
91 |
-
|
92 |
-
|
93 |
-
def image_to_bmp(image):
    """Write *image* in BMP format to an in-memory buffer.

    Args:
        image: Object exposing ``save(fp, format=...)``.

    Returns:
        The ``io.BytesIO`` buffer that was written to. It is not rewound.
    """
    bmp_buffer = io.BytesIO()
    image.save(bmp_buffer, format="BMP")
    return bmp_buffer
|
97 |
-
|
98 |
-
|
99 |
-
def inference_nougat(source_data, batch_size=1):
    """Run the nougat ``0.1.0-small`` checkpoint over every source image.

    Args:
        source_data: Sequence of mappings with an ``"image"`` key.
        batch_size: Batch size handed to the ``DataLoader``.

    Returns:
        A flat list of predictions, each passed through
        ``markdown_compatible``, in dataloader order (no shuffling).
    """
    import torch
    from nougat.postprocessing import markdown_compatible
    from nougat.utils.checkpoint import get_checkpoint
    from nougat.utils.dataset import ImageDataset
    from nougat.utils.device import move_to_device
    from nougat import NougatModel

    # Load images, then convert to bmp format for nougat
    bmp_buffers = [image_to_bmp(img) for img in load_images(source_data)]
    results = []

    checkpoint = get_checkpoint(None, model_tag="0.1.0-small")
    model = NougatModel.from_pretrained(checkpoint)
    if settings.TORCH_DEVICE_MODEL != "cpu":
        move_to_device(model, bf16=settings.CUDA, cuda=settings.CUDA)
    model.eval()

    prepare_input = partial(model.encoder.prepare_input, random_padding=False)
    dataset = ImageDataset(bmp_buffers, prepare_input)

    # Batch sizes higher than 1 explode memory usage on CPU/MPS
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
    )

    for batch in tqdm(loader, desc="Nougat inference", total=len(loader)):
        model.config.max_length = settings.MAX_TOKENS
        generated = model.inference(image_tensors=batch, early_stopping=False)
        results.extend(markdown_compatible(pred) for pred in generated["predictions"])
    return results
|
137 |
-
|
138 |
-
|
139 |
-
def _score_method(predicted_text, references):
    """Score one method's normalized predictions and pair them with references."""
    return {
        "scores": score_text(predicted_text, references),
        "text": [{"prediction": p, "reference": r} for p, r in zip(predicted_text, references)]
    }


def main():
    """Benchmark texify (and optionally pix2tex / nougat) and save the results.

    Parses command-line arguments, loads the model and the source JSON,
    optionally subsamples the data, runs inference for each requested
    method, prints a score table, and writes the per-method scores and
    prediction/reference pairs to the result JSON file.
    """
    parser = argparse.ArgumentParser(description="Benchmark the performance of texify.")
    parser.add_argument("--data_path", type=str, help="Path to JSON file with source images/equations", default=os.path.join(settings.DATA_DIR, "bench_data.json"))
    parser.add_argument("--result_path", type=str, help="Path to JSON file to save results to.", default=os.path.join(settings.DATA_DIR, "bench_results.json"))
    parser.add_argument("--max", type=int, help="Maximum number of images to benchmark.", default=None)
    parser.add_argument("--pix2tex", action="store_true", help="Run pix2tex scoring", default=False)
    parser.add_argument("--nougat", action="store_true", help="Run nougat scoring", default=False)
    args = parser.parse_args()

    source_path = os.path.abspath(args.data_path)
    result_path = os.path.abspath(args.result_path)
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    model = load_model()
    processor = load_processor()

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(source_path, "r", encoding="utf-8") as f:
        source_data = json.load(f)

    if args.max:
        # Fixed seed so repeated runs benchmark the same subset.
        random.seed(1)
        # random.sample raises ValueError when asked for more items than the
        # population holds, so cap the request at the dataset size.
        sample_size = min(args.max, len(source_data))
        source_data = random.sample(source_data, sample_size)

    start = time.time()
    predictions = inference_texify(source_data, model, processor)
    times = {"texify": time.time() - start}
    text = [normalize_text(p["text"]) for p in predictions]
    # The same normalized references are reused to score every method.
    references = [normalize_text(p["equation"]) for p in predictions]

    write_data = {"texify": _score_method(text, references)}

    if args.pix2tex:
        start = time.time()
        predictions = inference_pix2tex(source_data)
        times["pix2tex"] = time.time() - start

        p_text = [normalize_text(p["text"]) for p in predictions]
        write_data["pix2tex"] = _score_method(p_text, references)

    if args.nougat:
        start = time.time()
        predictions = inference_nougat(source_data)
        times["nougat"] = time.time() - start

        # inference_nougat returns plain strings rather than dicts.
        n_text = [normalize_text(p) for p in predictions]
        write_data["nougat"] = _score_method(n_text, references)

    score_headers = ["bleu", "meteor", "edit"]
    # One direction marker per column, including the time column added below.
    score_dirs = ["⬆", "⬆", "⬇", "⬇"]

    score_table = [
        [method, *[result["scores"][h] for h in score_headers], times[method]]
        for method, result in write_data.items()
    ]

    score_headers.append("time taken (s)")
    score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
    print()
    print(tabulate(score_table, headers=["Method", *score_headers]))
    print()
    print("Higher is better for BLEU and METEOR, lower is better for edit distance and time taken.")
    print("Note that pix2tex is unbatched (I couldn't find a batch inference method in the docs), so time taken is higher than it should be.")

    with open(result_path, "w", encoding="utf-8") as f:
        json.dump(write_data, f, indent=4)


if __name__ == "__main__":
    main()
|
224 |
-
|
225 |
-
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|