Tonic committed
Commit aaa7ccd · 1 parent: 20f496e

Delete benchmark.py

Files changed (1)
  1. benchmark.py +0 -226
benchmark.py DELETED
@@ -1,226 +0,0 @@
-import argparse
-import os.path
-import random
-import time
-from functools import partial
-
-import evaluate
-from tabulate import tabulate
-from tqdm import tqdm
-
-from texify.inference import batch_inference
-from texify.model.model import load_model
-from texify.model.processor import load_processor
-from PIL import Image
-from texify.settings import settings
-import json
-import base64
-import io
-from rapidfuzz.distance import Levenshtein
-
-
-def normalize_text(text):
-    # Strip math fences/delimiters so scoring compares bare LaTeX
-    text = text.replace("$", "")
-    text = text.replace("\\[", "")
-    text = text.replace("\\]", "")
-    text = text.replace("\\(", "")
-    text = text.replace("\\)", "")
-    text = text.strip()
-    return text
-
-
-def score_text(predictions, references):
-    bleu = evaluate.load("bleu")
-    bleu_results = bleu.compute(predictions=predictions, references=references)
-
-    meteor = evaluate.load('meteor')
-    meteor_results = meteor.compute(predictions=predictions, references=references)
-
-    lev_dist = []
-    for p, r in zip(predictions, references):
-        lev_dist.append(Levenshtein.normalized_distance(p, r))
-
-    return {
-        'bleu': bleu_results["bleu"],
-        'meteor': meteor_results['meteor'],
-        'edit': sum(lev_dist) / len(lev_dist)
-    }
-
-
-def image_to_pil(image):
-    decoded = base64.b64decode(image)
-    return Image.open(io.BytesIO(decoded))
-
-
-def load_images(source_data):
-    images = [sd["image"] for sd in source_data]
-    images = [image_to_pil(image) for image in images]
-    return images
-
-
-def inference_texify(source_data, model, processor):
-    images = load_images(source_data)
-
-    write_data = []
-    for i in tqdm(range(0, len(images), settings.BATCH_SIZE), desc="Texify inference"):
-        batch = images[i:i+settings.BATCH_SIZE]
-        text = batch_inference(batch, model, processor)
-        for j, t in enumerate(text):
-            eq_idx = i + j
-            write_data.append({"text": t, "equation": source_data[eq_idx]["equation"]})
-
-    return write_data
-
-
-def inference_pix2tex(source_data):
-    from pix2tex.cli import LatexOCR
-    model = LatexOCR()
-
-    images = load_images(source_data)
-    write_data = []
-    for i in tqdm(range(len(images)), desc="Pix2tex inference"):
-        try:
-            text = model(images[i])
-        except ValueError:
-            # Happens when resize fails
-            text = ""
-        write_data.append({"text": text, "equation": source_data[i]["equation"]})
-
-    return write_data
-
-
-def image_to_bmp(image):
-    img_out = io.BytesIO()
-    image.save(img_out, format="BMP")
-    return img_out
-
-
-def inference_nougat(source_data, batch_size=1):
-    import torch
-    from nougat.postprocessing import markdown_compatible
-    from nougat.utils.checkpoint import get_checkpoint
-    from nougat.utils.dataset import ImageDataset
-    from nougat.utils.device import move_to_device
-    from nougat import NougatModel
-
-    # Load images, then convert to BMP format for nougat
-    images = load_images(source_data)
-    images = [image_to_bmp(image) for image in images]
-    predictions = []
-
-    ckpt = get_checkpoint(None, model_tag="0.1.0-small")
-    model = NougatModel.from_pretrained(ckpt)
-    if settings.TORCH_DEVICE_MODEL != "cpu":
-        move_to_device(model, bf16=settings.CUDA, cuda=settings.CUDA)
-    model.eval()
-
-    dataset = ImageDataset(
-        images,
-        partial(model.encoder.prepare_input, random_padding=False),
-    )
-
-    # Batch sizes higher than 1 explode memory usage on CPU/MPS
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=batch_size,
-        pin_memory=True,
-        shuffle=False,
-    )
-
-    for idx, sample in tqdm(enumerate(dataloader), desc="Nougat inference", total=len(dataloader)):
-        model.config.max_length = settings.MAX_TOKENS
-        model_output = model.inference(image_tensors=sample, early_stopping=False)
-        output = [markdown_compatible(o) for o in model_output["predictions"]]
-        predictions.extend(output)
-    return predictions
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Benchmark the performance of texify.")
-    parser.add_argument("--data_path", type=str, help="Path to JSON file with source images/equations", default=os.path.join(settings.DATA_DIR, "bench_data.json"))
-    parser.add_argument("--result_path", type=str, help="Path to JSON file to save results to.", default=os.path.join(settings.DATA_DIR, "bench_results.json"))
-    parser.add_argument("--max", type=int, help="Maximum number of images to benchmark.", default=None)
-    parser.add_argument("--pix2tex", action="store_true", help="Run pix2tex scoring", default=False)
-    parser.add_argument("--nougat", action="store_true", help="Run nougat scoring", default=False)
-    args = parser.parse_args()
-
-    source_path = os.path.abspath(args.data_path)
-    result_path = os.path.abspath(args.result_path)
-    os.makedirs(os.path.dirname(result_path), exist_ok=True)
-    model = load_model()
-    processor = load_processor()
-
-    with open(source_path, "r") as f:
-        source_data = json.load(f)
-
-    if args.max:
-        random.seed(1)
-        source_data = random.sample(source_data, args.max)
-
-    start = time.time()
-    predictions = inference_texify(source_data, model, processor)
-    times = {"texify": time.time() - start}
-    text = [normalize_text(p["text"]) for p in predictions]
-    references = [normalize_text(p["equation"]) for p in predictions]
-
-    scores = score_text(text, references)
-
-    write_data = {
-        "texify": {
-            "scores": scores,
-            "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)]
-        }
-    }
-
-    if args.pix2tex:
-        start = time.time()
-        predictions = inference_pix2tex(source_data)
-        times["pix2tex"] = time.time() - start
-
-        p_text = [normalize_text(p["text"]) for p in predictions]
-
-        p_scores = score_text(p_text, references)
-
-        write_data["pix2tex"] = {
-            "scores": p_scores,
-            "text": [{"prediction": p, "reference": r} for p, r in zip(p_text, references)]
-        }
-
-    if args.nougat:
-        start = time.time()
-        predictions = inference_nougat(source_data)
-        times["nougat"] = time.time() - start
-        n_text = [normalize_text(p) for p in predictions]
-
-        n_scores = score_text(n_text, references)
-
-        write_data["nougat"] = {
-            "scores": n_scores,
-            "text": [{"prediction": p, "reference": r} for p, r in zip(n_text, references)]
-        }
-
-    score_table = []
-    score_headers = ["bleu", "meteor", "edit"]
-    score_dirs = ["⬆", "⬆", "⬇", "⬇"]
-
-    for method in write_data.keys():
-        score_table.append([method, *[write_data[method]["scores"][h] for h in score_headers], times[method]])
-
-    score_headers.append("time taken (s)")
-    score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
-    print()
-    print(tabulate(score_table, headers=["Method", *score_headers]))
-    print()
-    print("Higher is better for BLEU and METEOR, lower is better for edit distance and time taken.")
-    print("Note that pix2tex is unbatched (I couldn't find a batch inference method in the docs), so time taken is higher than it should be.")
-
-    with open(result_path, "w") as f:
-        json.dump(write_data, f, indent=4)
-
-
-if __name__ == "__main__":
-    main()
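
For reference, the deleted script was a command-line tool. Judging from its argparse setup, a full three-way comparison run would have looked something like this (the --max value is illustrative; data and result paths default to files under settings.DATA_DIR):

python benchmark.py --max 100 --pix2tex --nougat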
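
The bench_data.json format is implied by load_images() and inference_texify(): a list of records, each holding a base64-encoded image and its ground-truth LaTeX. A minimal sketch of decoding one record and running texify on it, using only the APIs the deleted script itself imported (the file path is illustrative):

import base64
import io
import json

from PIL import Image
from texify.inference import batch_inference
from texify.model.model import load_model
from texify.model.processor import load_processor

# Expected shape: [{"image": <base64 string>, "equation": <LaTeX string>}, ...]
with open("bench_data.json") as f:  # illustrative path
    source_data = json.load(f)

img = Image.open(io.BytesIO(base64.b64decode(source_data[0]["image"])))

model = load_model()
processor = load_processor()
print(batch_inference([img], model, processor)[0])  # predicted LaTeX
print(source_data[0]["equation"])                   # reference LaTeX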
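
The scoring half of the script can also be exercised on its own: BLEU and METEOR come from the evaluate package, and "edit" is a normalized Levenshtein distance from rapidfuzz averaged over prediction/reference pairs. A minimal sketch with illustrative strings (evaluate.load fetches metric code on first use):

import evaluate
from rapidfuzz.distance import Levenshtein

predictions = ["\\frac{a}{b}", "x^{2} + y^{2}"]  # illustrative model output
references = ["\\frac{a}{b}", "x^{2} - y^{2}"]   # illustrative ground truth

bleu = evaluate.load("bleu").compute(predictions=predictions, references=references)
meteor = evaluate.load("meteor").compute(predictions=predictions, references=references)
edit = sum(Levenshtein.normalized_distance(p, r)
           for p, r in zip(predictions, references)) / len(predictions)

print(bleu["bleu"], meteor["meteor"], edit)  # higher, higher, lower is better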