Spaces · Runtime error
Commit 9d1fa0d · zphilip committed · 1 Parent: 876fac2
adding part 1
Browse files
- .gitattributes +2 -0
- app.py +297 -0
- nougat/__init__.py +15 -0
- nougat/__pycache__/__init__.cpython-310.pyc +0 -0
- nougat/__pycache__/_version.cpython-310.pyc +0 -0
- nougat/__pycache__/model.cpython-310.pyc +0 -0
- nougat/__pycache__/postprocessing.cpython-310.pyc +0 -0
- nougat/__pycache__/transforms.cpython-310.pyc +0 -0
- nougat/_version.py +8 -0
- nougat/dataset/__init__.py +0 -0
- nougat/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
- nougat/dataset/__pycache__/rasterize.cpython-310.pyc +0 -0
- nougat/dataset/create_index.py +173 -0
- nougat/dataset/gen_seek.py +36 -0
- nougat/dataset/parser/__init__.py +0 -0
- nougat/dataset/parser/document.py +703 -0
- nougat/dataset/parser/html2md.py +67 -0
- nougat/dataset/parser/latexml_parser.py +441 -0
- nougat/dataset/parser/markdown.py +396 -0
- nougat/dataset/pdffigures.py +71 -0
- nougat/dataset/rasterize.py +81 -0
- nougat/dataset/split_htmls_to_pages.py +219 -0
- nougat/dataset/split_md_to_pages.py +477 -0
- nougat/dataset/splitter.py +393 -0
- nougat/dataset/staircase.py +314 -0
- nougat/dataset/tokenizer.json +0 -0
- nougat/dataset/utils/__init__.py +8 -0
- nougat/dataset/utils/latex_conversion.py +146 -0
- nougat/dataset/utils/pdf_text_extract.py +86 -0
- nougat/dataset/utils/utils.py +20 -0
- nougat/metrics.py +117 -0
- nougat/model.py +702 -0
- nougat/postprocessing.py +508 -0
- nougat/transforms.py +173 -0
- nougat/utils/__init__.py +0 -0
- nougat/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- nougat/utils/__pycache__/checkpoint.cpython-310.pyc +0 -0
- nougat/utils/__pycache__/dataset.cpython-310.pyc +0 -0
- nougat/utils/checkpoint.py +119 -0
- nougat/utils/dataset.py +280 -0
- nougat/utils/device.py +38 -0
- predict.py +172 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,297 @@
import gradio as gr
import subprocess
import uuid
import os
import requests
import re

os.environ['http_proxy'] = ""
os.environ['https_proxy'] = ""

"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import sys
from pathlib import Path
import logging
import argparse
from functools import partial
import torch
from torch.utils.data import ConcatDataset
from tqdm import tqdm
from nougat import NougatModel
from nougat.utils.dataset import LazyDataset
from nougat.utils.checkpoint import get_checkpoint
from nougat.postprocessing import markdown_compatible
import fitz

logging.basicConfig(level=logging.INFO)
if torch.cuda.is_available():
    BATCH_SIZE = int(
        torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1000 * 0.3
    )
    if BATCH_SIZE == 0:
        logging.warning("GPU VRAM is too small. Computing on CPU.")
else:
    # don't know what a good value is here. Would not recommend to run on CPU
    BATCH_SIZE = 1
    logging.warning("No GPU found. Conversion on CPU is very slow.")


def nougat_predict(input_files, output_path, checkpoint, batchsize, markdown, recompute):
    print(f'*** nougat predict with input :{input_files} ***')
    model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
    if batchsize > 0:
        if torch.cuda.is_available():
            model.to("cuda")
    else:
        # set batch size to 1. Need to check if there are benefits for CPU conversion for >1
        batchsize = 1
    model.eval()
    datasets = []
    out_path = None  # path of the most recently written .mmd file
    for pdf in input_files:
        if not os.path.exists(pdf):
            continue
        if output_path:
            out_path = output_path / pdf.with_suffix(".mmd").name
            print(out_path)
            if out_path.exists() and not recompute:
                logging.info(
                    f"Skipping {pdf.name}, already computed. Run with --recompute to convert again."
                )
                continue
        try:
            dataset = LazyDataset(
                pdf, partial(model.encoder.prepare_input, random_padding=False)
            )
        except fitz.fitz.FileDataError:
            logging.info(f"Could not load file {str(pdf)}.")
            continue
        datasets.append(dataset)
    if len(datasets) == 0:
        print(f'*** nougat out files :{out_path} ***')
        return out_path
    dataloader = torch.utils.data.DataLoader(
        ConcatDataset(datasets),
        batch_size=batchsize,
        shuffle=False,
        collate_fn=LazyDataset.ignore_none_collate,
    )

    predictions = []
    file_index = 0
    page_num = 0
    for i, (sample, is_last_page) in enumerate(tqdm(dataloader)):
        model_output = model.inference(image_tensors=sample)
        # check if model output is faulty
        for j, output in enumerate(model_output["predictions"]):
            if page_num == 0:
                logging.info(
                    "Processing file %s with %i pages"
                    % (datasets[file_index].name, datasets[file_index].size)
                )
            page_num += 1
            if output.strip() == "[MISSING_PAGE_POST]":
                # uncaught repetitions -- most likely empty page
                predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n")
            elif model_output["repeats"][j] is not None:
                if model_output["repeats"][j] > 0:
                    # If we end up here, it means the output is most likely not complete and was truncated.
                    logging.warning(f"Skipping page {page_num} due to repetitions.")
                    predictions.append(f"\n\n[MISSING_PAGE_FAIL:{page_num}]\n\n")
                else:
                    # If we end up here, it means the document page is too different from the training domain.
                    # This can happen e.g. for cover pages.
                    predictions.append(
                        f"\n\n[MISSING_PAGE_EMPTY:{i*batchsize+j+1}]\n\n"
                    )
            else:
                if markdown:
                    output = markdown_compatible(output)
                predictions.append(output)
            if is_last_page[j]:
                out = "".join(predictions).strip()
                out = re.sub(r"\n{3,}", "\n\n", out).strip()
                if output_path:
                    out_path = output_path / Path(is_last_page[j]).with_suffix(".mmd").name
                    out_path.parent.mkdir(parents=True, exist_ok=True)
                    out_path.write_text(out, encoding="utf-8")
                else:
                    print(out, "\n\n")
                predictions = []
                page_num = 0
                file_index += 1
    print(f'the generated markdown file is : {out_path}')
    return out_path


def get_pdf(pdf_link):
    # Generate a unique filename
    unique_filename = f"input/downloaded_paper_{uuid.uuid4().hex}.pdf"

    # Send a GET request to the PDF link
    response = requests.get(pdf_link)

    if response.status_code == 200:
        # Save the PDF content to a local file
        with open(unique_filename, 'wb') as pdf_file:
            pdf_file.write(response.content)
        print("PDF downloaded successfully.")
    else:
        print("Failed to download the PDF.")
    return unique_filename


def nougat_ocr(file_name):
    # Command to run
    cli_command = [
        'nougat',
        '--out', 'output',
        'pdf', f'{file_name}',
        '--checkpoint', 'nougat',
        '--markdown'
    ]

    # Run the command and capture its output
    subprocess.run(cli_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    return


import pathlib


def predict(pdf_file, pdf_link):
    print("*************** inference ******************")
    if pdf_file is None:
        if pdf_link == '':
            print("No file is uploaded and No link is provided")
            return "No data provided. Upload a pdf file or provide a pdf link and try again!"
        else:
            print(f'pdf_link is - {pdf_link}')
            file_name = get_pdf(pdf_link)
            print(f'file_name is - {file_name}')
    else:
        print(pdf_file)
        file_name = pdf_file.name
        print(file_name)
        pdf_name = pdf_file.name.split('/')[-1].split('.')[0]
        print(pdf_name)

    # Call nougat; nougat_predict expects an iterable of Paths
    input_files = [file_name if isinstance(file_name, os.PathLike) else pathlib.Path(file_name)]
    output_path = pathlib.Path("./output")
    checkpoint = pathlib.Path("./config1/")
    config = pathlib.Path("./config1/config.json")
    markdown = True
    batchsize = BATCH_SIZE
    output_files = nougat_predict(input_files=input_files, output_path=output_path, checkpoint=checkpoint, batchsize=batchsize, markdown=markdown, recompute=False)
    print(f'the generated markdown file is : {output_files}')

    # Open the file for reading
    file_name = file_name.split('/')[-1][:-4]
    with open(output_files, 'r+') as file:
        content = file.read()
        # switch math delimiters
        content = content.replace(r"\(", "\$").replace(r'\)', '\$').replace(r'\[', '\$\$').replace(r'\]', '\$\$')
    print("***********************************")
    print("convert successfully")
    print("***********************************")

    return content


def nougat_ocr1(file_name):
    print('******* inside nougat_ocr *******')
    # CLI Command to run
    cli_command = [
        'python', 'predict.py',
        '--out', 'output',
        'pdf', f'{file_name}',
        '--checkpoint', '../config1/',
        '--markdown'
    ]

    # Run the command and get .mmd file in an output folder
    subprocess.run(cli_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    return


def predict1(pdf_file):
    print('******* inside predict *******')
    print(f"temporary file - {pdf_file.name}")
    pdf_name = pdf_file.name.split('/')[-1].split('.')[0]
    print(f"pdf file name - {pdf_name}")

    #! Get prediction for a PDF using nougat
    nougat_ocr(pdf_file.name)
    print("BAACCKKK")

    # Open the multimarkdown (.mmd) file for reading
    with open(f'output/{pdf_name}.mmd', 'r') as file:
        content = file.read()

    return content


def process_example(pdf_file, pdf_link):
    ocr_content = predict(pdf_file, pdf_link)
    return gr.update(value=ocr_content)


css = """
  #mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""

with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>Nougat: Neural Optical Understanding for Academic Documents<center><h1>")
    gr.HTML("<h3><center>Lukas Blecher et al. <a href='https://arxiv.org/pdf/2308.13418.pdf' target='_blank'>Paper</a>, <a href='https://facebookresearch.github.io/nougat/'>Project</a><center></h3>")

    with gr.Row():
        mkd = gr.Markdown('<h4><center>Upload a PDF</center></h4>', scale=1)
        mkd = gr.Markdown('<h4><center><i>OR</i></center></h4>', scale=1)
        mkd = gr.Markdown('<h4><center>Provide a PDF link</center></h4>', scale=1)

    with gr.Row(equal_height=True):
        pdf_file = gr.File(label='PDF📃', file_count='single', scale=1)
        pdf_link = gr.Textbox(placeholder='Enter an Arxiv link here', label='PDF link🔗🌐', scale=1)

    with gr.Row():
        btn = gr.Button('Run NOUGAT🍫')
        clr = gr.Button('Clear🚿')

    output_headline = gr.Markdown("<h3>PDF converted to markup language through Nougat-OCR👇:</h3>")
    parsed_output = gr.Markdown(elem_id='mkd', value='📃🔤OCR Output')

    btn.click(predict, [pdf_file, pdf_link], parsed_output)
    print('******* 1 *******')
    clr.click(lambda: (gr.update(value=None),
                       gr.update(value=None),
                       gr.update(value=None)),
              [],
              [pdf_file, pdf_link, parsed_output]
              )

    gr.Examples(
        [["./input/test.pdf", ""], [None, "https://arxiv.org/pdf/2308.08316.pdf"]],
        inputs=[pdf_file, pdf_link],
        outputs=parsed_output,
        fn=process_example,
        cache_examples=True,
        label='Click on any Examples below to get Nougat OCR results quickly:'
    )

demo.queue()
demo.launch(debug=True, share=True, server_name="0.0.0.0")
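
For local debugging it can help to exercise the conversion path without launching the Gradio UI. The snippet below is a minimal sketch that calls nougat_predict directly with the same defaults predict() uses; the paths input/test.pdf and config1/ mirror the ones hard-coded above and are assumptions about this Space's file layout, not guarantees.

    # Sketch only: run the conversion with the functions defined in app.py.
    # Assumes a PDF at input/test.pdf and a checkpoint directory at config1/.
    import pathlib

    out_path = nougat_predict(
        input_files=[pathlib.Path("input/test.pdf")],
        output_path=pathlib.Path("output"),
        checkpoint=pathlib.Path("config1/"),
        batchsize=BATCH_SIZE,
        markdown=True,
        recompute=True,
    )
    print(out_path.read_text(encoding="utf-8")[:500])  # peek at the .mmd output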
nougat/__init__.py
ADDED
@@ -0,0 +1,15 @@
"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
Copyright (c) Meta Platforms, Inc. and affiliates.
"""
from .model import NougatConfig, NougatModel
from .utils.dataset import NougatDataset
from ._version import __version__

__all__ = [
    "NougatConfig",
    "NougatModel",
    "NougatDataset",
]
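
For orientation, a minimal sketch of consuming this public API; the checkpoint path "./config1/" is a placeholder, mirroring the checkpoint directory that app.py loads the same way.

    # Sketch only: load a Nougat checkpoint through the package's exported API.
    from nougat import NougatModel

    model = NougatModel.from_pretrained("./config1/")  # placeholder path
    model.eval()  # inference mode, as done in nougat_predict above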
nougat/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (464 Bytes)

nougat/__pycache__/_version.cpython-310.pyc
ADDED
Binary file (355 Bytes)

nougat/__pycache__/model.cpython-310.pyc
ADDED
Binary file (19.9 kB)

nougat/__pycache__/postprocessing.cpython-310.pyc
ADDED
Binary file (12.6 kB)

nougat/__pycache__/transforms.cpython-310.pyc
ADDED
Binary file (5.66 kB)
nougat/_version.py
ADDED
@@ -0,0 +1,8 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

__version__ = "0.1.17"
nougat/dataset/__init__.py
ADDED
File without changes
nougat/dataset/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (151 Bytes)

nougat/dataset/__pycache__/rasterize.cpython-310.pyc
ADDED
Binary file (2.82 kB)
nougat/dataset/create_index.py
ADDED
@@ -0,0 +1,173 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
"""
This script creates an index of all available pages and parses the meta data for all pages into a separate file.
Optionally TesseractOCR is called for each image.
"""
import argparse
import json
from typing import Dict, List
import numpy as np
from pathlib import Path
import multiprocessing
from pebble import ProcessPool
from PIL import Image
import pytesseract
import re
import logging
from tqdm import tqdm


logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def convert_pt2px(pt, dpi=96):
    if isinstance(pt, list):
        return [round(dpi / 72 * p) for p in pt]
    elif isinstance(pt, dict):
        for k in pt:
            pt[k] = round(dpi / 72 * pt[k])
        return pt


def read_metadata(data: Dict) -> List[List[Dict]]:
    N = data["num_pages"]
    out = [[] for _ in range(N)]
    # pdffigures2 meta data
    if "pdffigures" in data and data["pdffigures"]:
        for item in data["pdffigures"]:
            p = item.pop("page", None)
            if p is None or p >= N:
                continue
            item["source"] = "fig"
            if "regionBoundary" in item:
                item["regionBoundary"] = convert_pt2px(item["regionBoundary"])
            if "captionBoundary" in item:
                item["captionBoundary"] = convert_pt2px(item["captionBoundary"])
            out[p].append(item)

    return out


def index_paper(directory: Path, args: argparse.Namespace):
    """
    Pack all image-text pairs into a single h5 file and save it at `args.out`
    """
    paper = directory.name
    markdowns = directory.glob("*.mmd")
    meta_file = directory / "meta.json"
    data_samples = []
    if not meta_file.exists():
        return
    # load meta info
    try:
        meta = read_metadata(json.load(meta_file.open("r", encoding="utf-8")))
    except json.JSONDecodeError:
        return

    for md_path in markdowns:
        image = md_path.parent / (md_path.stem + ".png")
        i = int(image.stem) - 1
        if not image.exists():
            continue
        if i >= len(meta):
            continue
        data_sample = {}
        ocr_path = image.parent / (image.stem + "_OCR.txt")
        if args.tesseract and not ocr_path.exists():
            try:
                pil = Image.open(image)
                ocr = pytesseract.image_to_string(pil, lang="eng", timeout=2)
                ocr = re.sub(r"\n+\s+?([^\s])", r"\n\n\1", ocr).strip()
                with ocr_path.open("w", encoding="utf-8") as f_ocr:
                    f_ocr.write(ocr)
            except RuntimeError:
                logger.info("Page %s of paper %s timed out", image.stem, paper)
                pass
        if ocr_path.exists():
            data_sample["ocr"] = str(ocr_path.relative_to(args.root))
        data_sample["image"] = str(image.relative_to(args.root))
        data_sample["markdown"] = md_path.read_text(encoding="utf8").strip()
        data_sample["meta"] = meta[i]
        data_samples.append(data_sample)
    return data_samples


def create_index(args):
    if not args.dir.exists() and not args.dir.is_dir():
        logger.error("%s does not exist or is no dir.", args.dir)
        return
    papers = []
    depth = 0
    p = args.dir
    while True:
        p = next(p.iterdir())
        if p.is_file():
            break
        else:
            depth += 1
    papers = args.dir.glob("*/" * depth)
    index = []
    with ProcessPool(max_workers=args.workers) as pool:
        tasks = {}
        for j, paper in enumerate(papers):
            fname = paper.name
            tasks[fname] = pool.schedule(
                index_paper,
                args=[paper, args],
                timeout=args.timeout,
            )

        for fname in tqdm(tasks):
            try:
                res = tasks[fname].result()
                if res is None:
                    logger.info("%s is faulty", fname)
                    continue
                index.append(res)
            except TimeoutError:
                logger.info("%s timed out", fname)

    with args.out.open("w", encoding="utf-8") as f:
        for item in index:
            for page in item:
                if len(page) == 0:
                    continue
                f.write(json.dumps(page) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", type=Path, required=True, help="Index file")
    parser.add_argument(
        "--dir", type=Path, required=True, help="Parent directory for input dirs"
    )
    parser.add_argument("--root", type=Path, default=None)
    parser.add_argument(
        "--tesseract",
        action="store_true",
        help="Tesseract OCR prediction for each page",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=multiprocessing.cpu_count(),
        help="How many processes to use",
    )
    parser.add_argument(
        "--dpi", type=int, default=96, help="DPI the images were saved with"
    )
    parser.add_argument("--timeout", type=int, default=240, help="Max time per paper")
    args = parser.parse_args()
    if args.root is None:
        args.root = args.dir
    else:
        # check if dir is subdir of root
        args.dir.relative_to(args.root)
    create_index(args)
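
Judging from index_paper above, each line of the resulting index file is one JSON object per page. An illustrative record (all values invented for the sketch) might look like:

    # Illustrative shape of one JSONL record written by create_index;
    # "ocr" is only present when --tesseract was used.
    example_record = {
        "image": "2301.00001/3.png",    # page image, relative to --root
        "markdown": "## 3 Method\n...",  # ground-truth markdown for the page
        "ocr": "2301.00001/3_OCR.txt",   # optional Tesseract text file
        "meta": [{"source": "fig", "regionBoundary": {"x1": 96, "y1": 120, "x2": 480, "y2": 360}}],
    }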
nougat/dataset/gen_seek.py
ADDED
@@ -0,0 +1,36 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from tqdm import tqdm
import json
from pathlib import Path
import argparse


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("src_file", nargs="+", type=Path, help="JSONL file in question")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    for file in args.src_file:
        seek_map = []
        seek_pos = 0
        with open(file) as f:
            with tqdm(smoothing=0.0) as pbar:
                line = f.readline()
                while line:
                    seek_map.append(seek_pos)
                    seek_pos = f.tell()
                    line = f.readline()
                    pbar.update(1)

        out_file = file.parent / (file.stem + ".seek.map")
        with open(out_file, "w") as f:
            f.write(json.dumps(seek_map))
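
The .seek.map file is just a JSON list of byte offsets, one per JSONL line, so a reader can jump straight to line i without scanning the whole file. A sketch of how it could be consumed (read_jsonl_line is a hypothetical helper, not part of this commit):

    import json

    def read_jsonl_line(jsonl_path, seek_map_path, index):
        # hypothetical helper: O(1) random access into a JSONL file
        with open(seek_map_path) as f:
            offsets = json.load(f)
        with open(jsonl_path) as f:
            f.seek(offsets[index])  # jump to the start of line `index`
            return json.loads(f.readline())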
nougat/dataset/parser/__init__.py
ADDED
File without changes
nougat/dataset/parser/document.py
ADDED
@@ -0,0 +1,703 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from collections import defaultdict
from copy import copy
import itertools
import re
from dataclasses import dataclass, field, asdict
from typing import (
    Any,
    List,
    Dict,
    Optional,
    TypeVar,
    Type,
    Generic,
)
import numpy as np

import logging

logger = logging.getLogger()

from dataclasses import dataclass, field, asdict
from typing import List, Dict, TypeVar, Type, Generic

T = TypeVar("T")
EL = TypeVar("EL")


@dataclass
class Element(Generic[EL]):
    """
    Generic class representing an element with children in a tree-like structure.

    Attributes:
        parent (Element): The parent element.
        children (List[Element]): List of child elements.
    """

    parent: "Element" = None
    children: List["Element"] = field(default_factory=list)

    @property
    def plaintext(self):
        return "".join([child.plaintext for child in self.children])

    def append(self, child: EL) -> EL:
        self.children.append(child)
        child.parent = self
        return child

    def find_parent(self, class_or_tuple: Type[T]) -> T:
        elem = self
        while elem:
            if isinstance(elem, class_or_tuple):
                return elem
            elem = elem.parent
        return None


@dataclass
class UnknownElement(Element):
    pass


@dataclass
class TextElement(Element):
    content: str = ""

    @property
    def plaintext(self):
        return self.content

    def append(self, child: "Element"):
        raise Exception(f"Cannot append elements to {self.__class__.__name__}")


@dataclass
class Math(Element):
    pass


@dataclass
class PlaintextMath(Math):
    pass


@dataclass
class LatexMath(Math):
    inline: bool = True
    code: str = ""

    @property
    def plaintext(self):
        return self.code


@dataclass
class Author:
    fullname: str = None
    lastname: str = None
    affiliation: str = None


@dataclass
class Link(Element):
    target: str = None


@dataclass
class InlineRef(Element):
    target: str = None

    def as_dict(self):
        return {
            "target": self.target,
        }


@dataclass
class Reference:
    """
    Data class representing a reference with various attributes.

    Attributes:
        title (Element): The title of the reference.
        authors (List[Author]): List of authors of the reference.
        ids (Dict[str, str]): Dictionary of identification information.
        date (str): The publication date of the reference.
        url (str): The URL link to the reference.
        journal (str): The journal where the reference is published.
        full_text (str): The full text content of the reference.

    Methods:
        as_dict(): Convert the reference object to a dictionary.
    """

    title: Element = None
    authors: List[Author] = field(default_factory=list)
    ids: Dict[str, str] = field(default_factory=dict)
    date: str = None
    url: str = None
    journal: str = None
    full_text: str = None

    def as_dict(self):
        return {
            "title": self.title.plaintext,
            "authors": [asdict(auth) for auth in self.authors],
            "ids": self.ids,
            "date": self.date,
            "url": self.url,
            "journal": self.journal,
            "full_text": self.full_text,
        }


@dataclass
class SpanElement(Element):
    pass


@dataclass
class Italic(SpanElement):
    pass


@dataclass
class Bold(SpanElement):
    pass


@dataclass
class Superscript(SpanElement):
    pass


@dataclass
class Subscript(SpanElement):
    pass


@dataclass
class Paragraph(Element):
    pass


@dataclass
class TableRow(Element):
    cells: List[Element] = field(default_factory=list)

    def add_cell(self, cell: Element):
        self.cells.append(cell)
        cell.parent = self
        return cell

    @property
    def plaintext(self):
        return "\t".join([cell.plaintext for cell in self.cells])


@dataclass
class TableHead(TableRow):
    pass


@dataclass
class Table(Element):
    id: str = None
    header: Element = None
    caption: Element = None
    rows: List[TableRow] = field(default_factory=list)
    keep_table: bool = False

    def add_row(self, row: TableRow) -> TableRow:
        self.rows.append(row)
        row.parent = self
        return row

    @property
    def plaintext(self):
        return "\n".join([row.plaintext for row in self.rows])


@dataclass
class Equation(Element):
    pass


@dataclass
class EquationList(Element):
    equations: List[Equation] = field(default_factory=list)

    def add_equation(self, eqn: Equation) -> Equation:
        self.equations.append(eqn)
        eqn.parent = self
        return eqn

    @property
    def plaintext(self):
        return "\n".join([eqn.plaintext for eqn in self.equations])


@dataclass
class Algorithm(Element):
    caption: Element = None
    lines: List[Element] = field(default_factory=list)
    inline: bool = False

    def add_line(self, line: Element) -> Element:
        self.lines.append(line)
        line.parent = self
        return line

    @property
    def plaintext(self):
        return "\n".join([line.plaintext for line in self.lines])


@dataclass
class Definition(Element):
    term: Element = None
    definition: Element = None

    @property
    def plaintext(self):
        parts = []
        if self.term:
            parts.append(f"{self.term.plaintext}:")
        if self.definition:
            parts.append(self.definition.plaintext)
        return " ".join(parts)


@dataclass
class DefinitionList(Element):
    """
    Data class representing a list of definitions with an optional header.

    Attributes:
        header (Element): The header element for the definition list.
        items (List[Definition]): List of Definition elements.

    Methods:
        add_item(item: Definition) -> Definition: Add a definition item to the list.
    """

    header: Element = None
    items: List[Element] = field(default_factory=list)

    def add_item(self, item: Definition) -> Definition:
        self.items.append(item)
        item.parent = self
        return item

    @property
    def plaintext(self):
        parts = []
        if self.header:
            parts.append(self.header.plaintext)
        parts.extend([df.plaintext for df in self.items])
        return "\n".join(parts)


@dataclass
class Figure(Element):
    id: str = None
    header: Element = None
    caption: Element = None


@dataclass
class Section(Element):
    id: str = None
    header: Element = None
    level: int = 0
    hnum: int = 1


@dataclass
class SectionHeader(Element):
    id: str = None
    header: Element = None
    level: int = 0


@dataclass
class ListItem(Element):
    label: str = ""


@dataclass
class ListContainer(Element):
    level: int = 0
    ordered: bool = False
    items: List[Element] = field(default_factory=list)

    def add_item(self, item: ListItem) -> ListItem:
        self.items.append(item)
        item.parent = self
        return item

    @property
    def plaintext(self):
        return "\n".join([item.plaintext for item in self.items])


@dataclass
class Footnote(Element):
    id: str = None


@dataclass
class Document(Element, Reference):
    abstract: Element = None
    language: str = None
    keywords: List[Element] = field(default_factory=list)
    references: List[Reference] = field(default_factory=list)
    inline_refs: List[InlineRef] = field(default_factory=list)
    bib: Reference = None

    def add_reference(self, reference):
        self.references.append(reference)

    def add_inline_ref(self, in_ref):
        self.inline_refs.append(in_ref)

    def set_bib(self, reference):
        self.bib = reference


@dataclass
class Spec:
    """
    Data class representing specifications for table cells.

    Attributes:
        t (int): The top border size.
        b (int): The bottom border size.
        l (int): The left border size.
        r (int): The right border size.
        align (str): The alignment of the cell content ('c' for center, 'l' for left, 'r' for right,
            or 'p{width}' for justified with a specified width).

    Methods:
        __hash__() -> int: Compute the hash of the specification.
        __eq__(__o: object) -> bool: Check if two specifications are equal.
        set_align(classes: List[str], style: Optional[str] = None) -> None:
            Extract alignment information from HTML classes.
        set_border(classes: List[str]) -> None: Automatically set border specifications.
        set_attrs(attrs: Dict[str, Any]) -> None: Automatically set all attributes from HTML class attributes.
        __str__() -> str: Get the string representation of the specification.
    """

    t: int = field(default=0, repr=False)
    b: int = field(default=0, repr=False)
    l: int = field(default=0)
    r: int = field(default=0)
    align: str = field(default="")

    def __hash__(self) -> int:
        return hash(repr(self))

    def __eq__(self, __o: object) -> bool:
        return repr(self) == repr(__o)

    def set_align(self, classes: List[str], style: Optional[str] = None) -> None:
        """extract alignment information from available classes (html)"""
        aligns = [s for s in classes if "align" in s]
        if len(aligns) == 0:
            return
        elif len(aligns) > 1:
            logger.warn("Found multiple aligns in classes: %s", ", ".join(classes))
        align = aligns[0]
        if "center" in align or align == "c":
            self.align = "c"
        elif "left" in align or align == "l":
            self.align = "l"
        elif "right" in align or align == "r":
            self.align = "r"
        elif "justify" in align or align == "p":
            # assert style is not None, "justify without style information"
            if style is None:
                self.align = "c"
            else:
                width = style.partition("width:")[2].partition(";")[0]
                self.align = "p{%s}" % width
        else:
            logger.warn(
                "only center, left, right, justify supported at the moment. Found %s",
                align,
            )
            self.align = "c"

    def set_border(self, classes: List[str]) -> None:
        """automatically set spec with border classes e.g 'ltx_border_t'"""
        for border in classes:
            orientation = border.partition("border_")[2]
            if len(orientation) > 0 and orientation[0] in "tbrl":
                setattr(self, orientation[0], len(orientation))

    def set_attrs(self, attrs: Dict[str, Any]) -> None:
        """automatically set all attr from html class attributes"""
        classes = attrs["class"]
        style = attrs["style"] if "style" in attrs else None

        self.set_align(classes, style=style)
        self.set_border(classes)

    def __str__(self) -> str:
        if self.align:
            return "|" * self.l + self.align + "|" * self.r
        else:
            # default center
            return "|" * self.l + "c" + "|" * self.r


@dataclass
class TableCell(Element):
    """
    Represents a cell in an HTML table.

    Attributes:
        multicolumn (Optional[int]): The number of columns spanned by the cell.
        multirow (Optional[int]): The number of rows spanned by the cell.
        spec (Spec): The specification for the cell's formatting.
        content (Element): The content of the cell.

    Methods:
        __post_init__(*args, **kwargs) -> None: Initialize the cell, ensuring that the spec property is not None.
        __hash__() -> int: Compute the hash of the cell.
        __eq__(__o: object) -> bool: Check if two cells are equal.
        set_attrs(attrs: Dict[str, Any]) -> None: Set attributes for the cell from HTML attributes.
        plaintext() -> str: Get the plaintext content of the cell.
    """

    multicolumn: Optional[int] = None
    multirow: Optional[int] = None
    spec: Spec = None
    content: Element = None

    def __post_init__(self, *args, **kwargs) -> None:
        # spec property cannot be None
        if self.spec is None:
            self.spec = Spec()

    def __hash__(self) -> int:
        return hash(repr(self))

    def __eq__(self, __o: object) -> bool:
        return repr(self) == repr(__o)

    def set_attrs(self, attrs: Dict[str, Any]) -> None:
        if "colspan" in attrs:
            self.multicolumn = int(attrs["colspan"])
        if "rowspan" in attrs:
            self.multirow = int(attrs["rowspan"])
        self.spec.set_attrs(attrs)

    @property
    def plaintext(self):
        if self.content is None:
            return ""
        return self.content.plaintext


@dataclass
class TableRow(Element):
    """
    Represents a row in an HTML table.

    Attributes:
        cells (List[TableCell]): The list of cells in the row.

    Methods:
        add_cell(cell: TableCell) -> TableCell: Add a cell to the row.
        __iter__() -> Iterator: Iterate through the cells in the row.
        __len__() -> int: Get the number of cells in the row.
        __bool__() -> bool: Check if the row is not empty.
        cum_cell_widths() -> List[int]: Get the cumulative cell widths.
        cell_widths() -> List[int]: Get the widths of individual cells.
        width() -> int: Get the total width of the row.
        _hline(orientation: str) -> str: Determine horizontal lines to be inserted.
        hline_above() -> str: Get the horizontal line description for the top of the row.
        hline_below() -> str: Get the horizontal line description for the bottom of the row.
        plaintext() -> str: Get the plaintext content of the row.
    """

    cells: List[TableCell] = field(default_factory=list)

    def add_cell(self, cell: TableCell):
        self.cells.append(cell)
        cell.parent = self
        return cell

    def __iter__(self):
        return iter(self.cells)

    def __len__(self) -> int:
        return len(self.cells)

    def __bool__(self) -> bool:
        return True

    @property
    def cum_cell_widths(self) -> List[int]:
        return np.cumsum(self.cell_widths)

    @property
    def cell_widths(self) -> List[int]:
        return [(cell.multicolumn or 1) for cell in self.cells]

    @property
    def width(self) -> int:
        return sum(self.cell_widths)

    def _hline(self, orientation: str) -> str:
        """Figure out if and where horizontal lines need to be inserted.

        Args:
            orientation (str): Either 't' (top) or 'b' (bottom)

        Returns:
            str: Correct vertical line description for latex tables.
        """
        assert orientation == "t" or orientation == "b"
        lines = []
        for cell in self.cells:
            lines.extend([getattr(cell.spec, orientation)] * (cell.multicolumn or 1))
        lines.append(0)
        indices = []
        start = None
        for i, v in enumerate(lines):
            if v and start is None:
                start = i
            elif start is not None and not v:
                indices.append((start, i - 1))
                start = None
        s = ""
        for a, b in indices:
            if b - a + 1 == self.width:
                s += "\\hline " * lines[0]
            else:
                s += "\\cline{%i-%i} " % (a + 1, b + 1)
        return s.strip()

    @property
    def hline_above(self) -> str:
        return self._hline("t")

    @property
    def hline_below(self) -> str:
        return self._hline("b")

    @property
    def plaintext(self) -> str:
        return "\t".join([cell.plaintext for cell in self.cells])


@dataclass
class Tabular(Element):
    rows: List[TableRow] = field(default_factory=list)
    """
    Represents a tabular structure, such as an HTML table.

    Attributes:
        rows (List[TableRow]): The list of rows in the tabular structure.

    Methods:
        add_row(row: TableRow) -> TableRow: Add a row to the tabular structure.
        width() -> int: Get the maximum width of the tabular structure.
        cols() -> List[List[TableCell]]: Get a list of columns in the tabular structure.
        _square_table() -> None: Ensure the table has an equal number of columns in each row.
        get_table_spec() -> str: Generate a LaTeX table specification based on cell alignments.
        plaintext() -> str: Get the plaintext content of the tabular structure.
    """

    def add_row(self, row: TableRow) -> TableRow:
        self.rows.append(row)
        row.parent = self
        return row

    @property
    def width(self) -> int:
        if len(self.rows) > 0:
            return max([r.width for r in self.rows])
        else:
            return 0

    @property
    def cols(self) -> List[List[TableCell]]:
        return list(
            map(
                list,
                itertools.zip_longest(*[r.cells for r in self.rows], fillvalue=None),
            )
        )

    def _square_table(self) -> None:
        """check if number of columns is equal for every row. Add placeholders for `\multirow` instances"""
        for i, row in enumerate(self.rows):
            for j, cell in enumerate(row.cells):
                if cell.multirow is not None and cell.multirow > 1:
                    spec = copy(cell.spec)
                    # assume no hlines in multi cells: disable bottom lines for top and top lines for lower cells.
                    spec.t = 0
                    cell.spec.b = 0
                    for k in range(i + 1, i + cell.multirow):
                        if k < len(self.rows):
                            for _ in range(row.cell_widths[j]):
                                # add empty cell
                                self.rows[k].cells.insert(
                                    j, TableCell(parent=self.rows[k], spec=spec)
                                )

    def get_table_spec(self) -> str:
        """Generates a LaTeX table spec."""
        # First make table square
        self._square_table()
        # Find the most used spec in regular cells (no multi-col/row)
        specs = [Spec() for _ in range(self.width)]
        for i, col in enumerate(self.cols):
            counts = defaultdict(int)
            for cell in col:
                if cell is None or cell.spec.align == "":
                    continue
                if cell.multicolumn is None and cell.multirow is None:
                    counts[cell.spec] += 1
            if len(counts) > 0:
                specs[i] = max(counts, key=counts.get)
        # convert all cells that don't match the column style into a multicol{1}{custom_spec}
        for i, col in enumerate(self.cols):
            for cell in col:
                if cell is not None and cell.spec != specs[i]:
                    # check if there is text in the cell. If not alignment doesn't matter
                    if (
                        len(cell.children) == 0
                        and cell.spec.l == specs[i].l
                        and cell.spec.r == specs[i].r
                    ):
                        continue
                    # convert any standard cell into a multicol cell of width 1
                    if cell.multicolumn is None:
                        cell.multicolumn = 1
        # generate final latex table spec
        out = " ".join([str(spec) for spec in specs])
        out = re.sub(r"(\|) +(\w)", r"\1\2", out)
        out = re.sub(r"(\w) +(\|)", r"\1\2", out)
        return out

    @property
    def plaintext(self):
        return "\n".join([row.plaintext for row in self.rows])


@dataclass
class Table(Element):
    id: str = None
    caption: Element = None
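
A tiny illustration of the element tree defined above: plaintext recurses through children, so nested span elements flatten back into a single string.

    # Uses only classes defined in document.py; output is deterministic.
    doc = Document()
    para = doc.append(Paragraph())
    para.append(TextElement(content="Hello "))
    para.append(Bold()).append(TextElement(content="world"))
    assert doc.plaintext == "Hello world"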
nougat/dataset/parser/html2md.py
ADDED
@@ -0,0 +1,67 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import argparse
from pathlib import Path
from typing import List, Optional
from bs4 import BeautifulSoup
from tqdm import tqdm
import htmlmin
from nougat.dataset.parser.latexml_parser import parse_latexml, _clean_html_whitespace
from nougat.dataset.parser.markdown import format_document


def check_file_path(paths: List[Path], wdir: Optional[Path] = None) -> List[str]:
    """
    Checks if the given file paths exist.

    Args:
        paths: A list of file paths.
        wdir: The working directory. If None, the current working directory is used.

    Returns:
        A list of file paths that exist.
    """
    files = []
    for path in paths:
        if type(path) == str:
            if path == "":
                continue
            path = Path(path)
        pathsi = [path] if wdir is None else [path, wdir / path]
        for p in pathsi:
            if p.exists():
                files.append((p.resolve()))
            elif "*" in path.name:
                files.extend([(pi.resolve()) for pi in p.parent.glob(p.name)])
    return list(set(files))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--html", type=Path, nargs="+", help="HTML file", required=True)
    parser.add_argument("--out", type=Path, help="Output file", required=True)
    args = parser.parse_args()
    args.html = check_file_path(args.html)
    for f in tqdm(args.html):
        html = BeautifulSoup(
            htmlmin.minify(
                open(f, "r", encoding="utf-8").read().replace("\xa0", " "),
                remove_all_empty_space=1,
            ),
            features="html.parser",
        )
        try:
            doc = parse_latexml(html)
        except ValueError as e:
            print(e)
            continue
        if doc is None:
            continue
        out, fig = format_document(doc, keep_refs=True)
        outp = (args.out if args.out.is_dir() else args.out.parent) / (f.stem + ".mmd")
        with open(outp, "w", encoding="utf-8") as f:
            f.write(out)
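
The same pipeline can also be driven in-process; a sketch, where the HTML string is a stand-in for a real LaTeXML-generated page:

    from bs4 import BeautifulSoup
    from nougat.dataset.parser.latexml_parser import parse_latexml
    from nougat.dataset.parser.markdown import format_document

    # stand-in snippet; real inputs are full LaTeXML HTML documents
    html = BeautifulSoup("<div class='ltx_para'><p class='ltx_p'>Hi</p></div>", features="html.parser")
    doc = parse_latexml(html)
    if doc is not None:
        out, fig = format_document(doc, keep_refs=True)
        print(out)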
nougat/dataset/parser/latexml_parser.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import re
import sys
import requests
from typing import Optional, Set
from bs4 import BeautifulSoup, NavigableString
import soupsieve as sv

from nougat.dataset.parser.document import *


def printerr(*args, **kwargs):
    # uncomment for debugging
    # print(*args, **kwargs)
    pass


latexml_wrapper_selector = sv.compile(
    ", ".join(
        [
            ".ltx_engrafo_equation_container",
            "tbody",
            ".ltx_note_content",
            ".ltx_role_footnote",
            ".ltx_note_type",
            ".ltx_theorem",
            ".ltx_proof",
            ".ltx_quote",
            "blockquote",
            ".ltx_inline-para",
            ".ltx_inline-block",
        ]
    )
)
latexml_ignore_selector = sv.compile(".ltx_rule, .ltx_pagination.ltx_role_newpage")


def is_wrapper_element(element: BeautifulSoup) -> bool:
    return latexml_wrapper_selector.match(element)


def ignore_element(element: BeautifulSoup) -> bool:
    return latexml_ignore_selector.match(element)


def _get_classes(el: BeautifulSoup) -> Set[str]:
    if not hasattr(el, "attrs"):
        return set()
    classes = el.attrs.get("class")
    if classes is None:
        return set()
    return set(classes)


def _detach_selected(element: BeautifulSoup, selector: str) -> None:
    for elem in element.select(selector):
        elem.extract()


def parse_latexml_authors(ltx_authors: BeautifulSoup) -> List[Author]:
    authors = Paragraph()
    parse_latexml_children(ltx_authors, authors)
    return authors


def parse_latexml_citations(cite: BeautifulSoup, parent: Element) -> None:
    """
    Parses LaTeXML citations and appends them as children to the given parent element.

    Args:
        cite (BeautifulSoup): The BeautifulSoup object containing the citation data.
        parent (Element): The parent element to which the citations will be added as children.
    """
    parse_latexml_children(cite, parent)
    if ("[" in parent.plaintext and "]" in parent.plaintext) or re.search(
        r"[A-Za-z]", parent.plaintext
    ):
        return

    parent.children.insert(0, TextElement(content="["))
    parent.children.append(TextElement(content="]"))


def _clean_html_whitespace(text: str) -> str:
    if text.strip():
        text = re.sub(r"(^\n+|\n+$)", "\n", text)
    else:
        text = text.strip("\n")
    text = re.sub(r"[ \t]+", " ", text)
    return text


def parse_latexml_children(html: BeautifulSoup, parent: Element) -> None:
    """
    Parses LaTeXML children and appends them as appropriate elements to the given parent element.

    Args:
        html (BeautifulSoup): The BeautifulSoup object containing the HTML data.
        parent (Element): The parent element to which the parsed children will be added.
    """
    if html is None:
        return
    for child in html.children:
        classes = _get_classes(child)
        if isinstance(child, NavigableString):
            parent.append(TextElement(content=_clean_html_whitespace(str(child))))
        elif sv.match(
            "p, .ltx_p, div.ltx_para, span.ltx_para, section.ltx_paragraph", child
        ):
            paragraph = parent.append(Paragraph())
            parse_latexml_children(child, paragraph)
        elif sv.match(".ltx_tag", child):
            if "ltx_tag_note" not in classes:
                if sv.match(".ltx_tag_section", child):
                    child.string = child.string.upper()
                elif sv.match(".ltx_tag_subsection", child):
                    child.string = ""
                parse_latexml_children(child, parent)
            elif "ltx_tag_bibitem" in classes:
                parse_latexml_children(child, parent.append(SpanElement()))
        elif sv.match(".ltx_note_outer", child):
            # try to place the footnote outside the current paragraph
            paragraph = parent.find_parent(Paragraph)
            if paragraph is not None and paragraph.parent is not None:
                footnote = paragraph.parent.append(Footnote())
            else:
                footnote = parent.append(Footnote())
            parse_latexml_children(child, footnote)
        elif sv.match(".ltx_note_content > .ltx_note_mark", child):
            footnote = parent.find_parent(Footnote)
            if footnote is not None:
                footnote.id = child.get_text(strip=True)
            else:
                printerr("Unable to find footnote to set its id", file=sys.stderr)
                parse_latexml_children(child, parent)
        elif sv.match("sup", child):
            sup = parent.append(Superscript())
            parse_latexml_children(child, sup)
        elif sv.match("sub", child):
            sub = parent.append(Subscript())
            parse_latexml_children(child, sub)
        elif sv.match("span.ltx_Math, span.ltx_DisplayMath", child):
            inline = "ltx_DisplayMath" not in classes
            math_elem = child.select_one(".mjx-math")
            if math_elem:
                tex = math_elem.attrs["aria-label"]
                if inline:
                    tex = rf"\({tex}\)"
                else:
                    tex = rf"\[{tex}\]"
                parent.append(LatexMath(code=tex, inline=inline))
        elif sv.match("math.ltx_Math", child):
            # not sure if the math tag is LaTeXML version specific, but this seems to work
            inline = True
            if "display" in child.attrs:
                inline = child.attrs["display"] == "inline"
            tex = child.attrs["alttext"]
            if inline:
                tex = rf"\({tex}\)"
            else:
                tex = rf"\[{tex}\]"
            parent.append(LatexMath(code=tex, inline=inline))
        elif sv.match("a.ref", child):
            link = parent.append(Link())
            link.target = child.attrs.get("href")
            parse_latexml_children(child, link)
        elif sv.match(
            ".ltx_ref.ltx_missing_citation, .ltx_ref.ltx_missing_label", child
        ):
            placeholder = child.get_text().strip()
            resolved = False
            if placeholder.isnumeric():
                parent.append(TextElement(content=placeholder))
                resolved = True
            else:
                target = child.attrs.get("href")
                if target is not None:
                    potential_num = target.partition(".bib")[2]
                    if potential_num.isnumeric():
                        parent.append(TextElement(content=potential_num))
                        resolved = True
            if not resolved:
                raise ValueError("missing reference detected")
        elif sv.match(
            ".ltx_bibblock, .ltx_role_author, .ltx_contact, .ltx_role_email, .ltx_role_affiliation",
            child,
        ):
            parse_latexml_children(child, parent.append(SpanElement()))
            parent.append(TextElement(content="\n"))
        elif sv.match(
            ".ltx_authors, .ltx_personname, .ltx_role_creation.ltx_date, .ltx_engrafo_author_notes, .ltx_author_notes, .ltx_date.ltx_role_creation",
            child,
        ):
            parse_latexml_children(child, parent.append(Paragraph()))
            parent.append(TextElement(content="\n"))
        elif sv.match(
            ".ltx_author_before, .ltx_role_pubyear, .ltx_role_pagerange", child
        ):
            pass
        elif sv.match("h1.ltx_title_document", child):
            doc = parent.find_parent(Document)
            if doc is not None:
                if doc.title is None:
                    doc.title = SectionHeader(parent=doc)
                    doc.title.hnum = int(child.name[1])
                    parse_latexml_children(child, doc.title)
                else:
                    printerr("Document title is already set", file=sys.stderr)
            else:
                printerr("Unable to find document to set title", file=sys.stderr)
        elif sv.match("section", child):
            if "ltx_bibliography" not in classes:
                section = parent.append(Section())
                parse_latexml_children(child, section)
        elif sv.match("h1, h2, h3, h4, h5, h6", child) and "ltx_title" in classes:
            if {"ltx_title_theorem", "ltx_title_proof"} & classes:
                parse_latexml_children(child, parent)
                parent.append(TextElement(content=": "))
            elif isinstance(parent, Section):
                parent.hnum = int(child.name[1])
                if parent.header is None:
                    parent.header = SpanElement()
                parse_latexml_children(child, parent.header)
            else:
                printerr("Dangling title element", file=sys.stderr)
                parse_latexml_children(child, parent)
        elif sv.match(".ltx_TOC.ltx_toc_toc", child):
            s = parent.append(Section(hnum=6, header=TextElement(content="Contents")))
            parse_latexml_children(child, s.append(Paragraph()))
        elif sv.match(
            "ul.ltx_itemize, ul.ltx_toclist, ul.ltx_biblist, ol.ltx_enumerate", child
        ):
            lst = parent.append(ListContainer())
            lst.ordered = child.name == "ol"
            parent_list = parent.find_parent(ListContainer)
            lst.level = parent_list.level + 1 if parent_list is not None else 1
            parse_latexml_children(child, lst)
        elif sv.match("li.ltx_item, li.ltx_tocentry, li.ltx_bibitem", child):
            lst = parent.find_parent(ListContainer)
            if lst is not None:
                item = lst.add_item(ListItem())
                parse_latexml_children(child, item)
            else:
                printerr("List item outside list", file=sys.stderr)
        elif sv.match("cite", child):
            span = parent.append(SpanElement())
            parse_latexml_citations(child, span)
        elif sv.match("a.ltx_ref", child):
            target = child.attrs.get("href")
            if target.startswith("#bib"):  # citation link
                in_ref = parent.append(InlineRef())
                in_ref.target = target
                text = child.get_text()
                if text.strip().isnumeric():
                    in_ref.append(TextElement(content=text))
                elif re.search(r"[A-Za-z][:;.,_]?\d", text):
                    # probably a broken citation, go with the link number instead
                    in_ref.append(
                        TextElement(
                            content=re.sub(r"\D", "", target.partition(".bib")[2])
                        )
                    )
                else:
                    raise ValueError('unusable reference "%s"' % text)
                doc = parent.find_parent(Document)
                if doc:
                    doc.add_inline_ref(in_ref)
            else:
                link = parent.append(Link())
                link.target = target
                parse_latexml_children(child, link)
        elif sv.match("a", child) and len(classes) == 0:
            target = child.attrs.get("href")
            parse_latexml_children(child, parent.append(Link(target=target)))
        elif sv.match(".ltx_eqn_table", child):
            eqn_list = parent.append(EquationList())
            parse_latexml_children(child, eqn_list)
        elif sv.match(".ltx_eqn_row", child):
            eqn_list = parent.find_parent(EquationList)
            if eqn_list is not None:
                eqn = eqn_list.add_equation(Equation())
                parse_latexml_children(child, eqn)
            else:
                printerr("Dangling equation row", file=sys.stderr)
                parse_latexml_children(child, parent)
        elif sv.match(".ltx_eqn_cell", child):
            parse_latexml_children(child, parent)
        elif sv.match("table, span.ltx_tabular, div.ltx_tabular", child):
            tabular = parent.append(Tabular())
            parse_latexml_children(child, tabular)
        elif sv.match("thead.ltx_thead", child):
            table = parent.find_parent(Tabular)
            if table is not None:
                parse_latexml_children(child, table)
            else:
                printerr("Table header element outside table", file=sys.stderr)
        elif sv.match("tbody.ltx_tbody", child):
            parse_latexml_children(child, parent)
        elif sv.match("tr.ltx_tr", child):
            table = parent.find_parent(Tabular)
            if table is not None:
                row = table.add_row(TableRow())
                parse_latexml_children(child, row)
            else:
                printerr("TableRow element outside table", file=sys.stderr)
        elif sv.match("td.ltx_td, th.ltx_th", child):
            row = parent.find_parent(TableRow)
            if row is not None:
                cell = TableCell()
                cell.set_attrs(child.attrs)
                row.add_cell(cell)
                parse_latexml_children(child, cell)
            else:
                printerr("TableData element outside table row", file=sys.stderr)
        elif sv.match("span.ltx_text, em.ltx_emph", child):
            # drop stray enumitem options (e.g. "[label=0)]") that leak into list items
            if child.find_parent(ListItem) is None or child.get_text() not in (
                "[label=0)]",
                "[leftmargin=*] ",
            ):
                if "ltx_font_italic" in classes:
                    elem = Italic()
                elif "ltx_font_bold" in classes:
                    elem = Bold()
                else:
                    elem = SpanElement()
                parent.append(elem)
                parse_latexml_children(child, elem)
            else:
                parent.find_parent(ListContainer).items.pop()
        elif sv.match("figure.ltx_table", child):
            figure = parent.append(Table())
            if "id" in child.attrs:
                figure.id = child.attrs["id"]
            parse_latexml_children(child, figure)
        elif sv.match("figure.ltx_figure", child):
            figure = parent.append(Figure())
            if "id" in child.attrs:
                figure.id = child.attrs["id"]
            parse_latexml_children(child, figure)
        elif sv.match("figure.ltx_float", child):
            parse_latexml_children(child, parent)
        elif sv.match(".ltx_listing", child):
            alg = parent.append(Algorithm())
            parse_latexml_children(child, alg)
        elif sv.match(".ltx_listingline", child):
            alg = parent.find_parent(Algorithm)
            if alg is not None:
                line = alg.add_line(Element())
                parse_latexml_children(child, line)
            else:
                printerr("Listing line outside algorithm environment", file=sys.stderr)
        elif sv.match("dl.ltx_description", child):
            def_list = parent.append(DefinitionList())
            parse_latexml_children(child, def_list)
        elif sv.match("dt.ltx_item", child):
            def_list = parent.find_parent(DefinitionList)
            if def_list is not None:
                item = def_list.add_item(Definition())
                item.term = SpanElement(parent=item)
                parse_latexml_children(child, item.term)
            else:
                printerr("Found dangling definition term", file=sys.stderr)
        elif sv.match("dd.ltx_item", child):
            def_list = parent.find_parent(DefinitionList)
            if def_list is not None:
                if def_list.items and def_list.items[-1].definition is None:
                    item = def_list.items[-1]
                else:
                    printerr("Found definition without term", file=sys.stderr)
                    item = def_list.add_item(Definition())
                item.definition = SpanElement(parent=item)
                parse_latexml_children(child, item.definition)
            else:
                printerr("Found dangling definition", file=sys.stderr)
                parse_latexml_children(child, parent)
        elif sv.match("figcaption", child):
            fig = parent.find_parent((Figure, Table))
            if fig is not None:
                if fig.caption is None:
                    fig.caption = Paragraph(parent=fig)
                parse_latexml_children(child, fig.caption)
                fig.caption.append(TextElement(content="\n"))
            else:
                printerr("Figure caption outside figure element", file=sys.stderr)
                para = parent.append(Paragraph())
                parse_latexml_children(child, para)
        elif sv.match(".ltx_break", child):
            parent.append(TextElement(content="\n\n"))
        elif sv.match(".ltx_abstract, .ltx_acknowledgements", child):
            abstract = parent.append(Section())
            parse_latexml_children(child, abstract)
        elif sv.match(".ltx_ERROR", child):
            printerr(
                f"LaTeX error element: {child.get_text(strip=True)}", file=sys.stderr
            )
        elif is_wrapper_element(child):
            parse_latexml_children(child, parent)
        elif ignore_element(child):
            continue
        else:
            printerr(
                f"Unknown LaTeXML element <{child.name}> with classes {', '.join(classes)}",
                file=sys.stderr,
            )
            elem = parent.append(UnknownElement())
            parse_latexml_children(child, elem)


# TODO: move this somewhere else, so I can use it with plaintext too
sess = requests.Session()


def parse_latexml_references(html: BeautifulSoup, doc: Document) -> None:
    for child in html.select("li.ltx_bibitem"):
        ref_text = child.get_text(strip=False).replace("\n", " ")
        reference = Reference()
        reference.title = TextElement(content=child.get_text(strip=True))
        doc.add_reference(reference)


def parse_latexml(
    html: BeautifulSoup,
) -> Optional[Document]:
    if html.article is None:
        printerr("Missing article element", file=sys.stderr)
        return None
    doc = Document()
    parse_latexml_children(html.article, doc)
    parse_latexml_references(
        html.article,
        doc,
    )
    return doc
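A minimal usage sketch of the parser above, assuming a LaTeXML-generated HTML file ("paper.html" is a placeholder name):

    from bs4 import BeautifulSoup
    from nougat.dataset.parser.latexml_parser import parse_latexml

    # "paper.html" is a placeholder for LaTeXML output
    with open("paper.html", "r", encoding="utf-8") as f:
        html = BeautifulSoup(f.read(), features="html.parser")

    doc = parse_latexml(html)  # a Document, or None when there is no <article> tag

The selectors only fire on LaTeXML's ltx_* classes, so HTML from other converters mostly falls through to UnknownElement.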
nougat/dataset/parser/markdown.py
ADDED
@@ -0,0 +1,396 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Dict, Iterable, List, Optional, Tuple
import re
from uuid import uuid4
from nougat.dataset.utils import normalize_tex
from nougat.dataset.parser.document import *
from nougat.dataset.parser.latexml_parser import _clean_html_whitespace
from unidecode import unidecode

SUPERSCRIPT_MAP = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")
SUBSCRIPT_MAP = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉")
figure_regex = re.compile(r"\[(FOOTNOTE|FIGURE|TABLE)(.*?)\](.*?)\[END\1\]", re.S)
conv = {
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\^{}",
    "\\": r"\textbackslash{}",
    "<": r"\textless{}",
    ">": r"\textgreater{}",
}
regex = re.compile(
    "|".join(
        re.escape(str(key)) for key in sorted(conv.keys(), key=lambda item: -len(item))
    )
)


def remove_trailing_whitespace(parts: List[str]) -> None:
    """Removes trailing whitespace elements from the list in place."""
    for s in reversed(parts):
        if s.rstrip() == "":
            del parts[-1]
        else:
            break


def remove_line_breaks(parts: List[str]):
    out = []
    for s in parts:
        out.append(s.replace("\n", " "))
    return out


def leading_trailing_whitespace(
    parts: List[str],
) -> Tuple[List[str], List[str], List[str]]:
    """Splits the list into three parts. The first and last return elements are made up only of whitespace.

    Args:
        parts (List[str]): List to split.

    Returns:
        Tuple[List[str], List[str], List[str]]: The split list.
    """
    lead = []
    trail = []
    out_slice = [None, None]
    for i, s in enumerate(parts):
        if s.strip() == "":
            lead.append(s)
            out_slice[0] = i + 1
        else:
            break
    for i, s in enumerate(reversed(parts)):
        if s.strip() == "":
            trail.append(s)
            out_slice[1] = -1 - i
        else:
            break
    return lead, parts[slice(*out_slice)], trail[::-1]


def latex_escape(string: str) -> str:
    return regex.sub(lambda match: conv[match.group()], string)


def is_empty(content: List) -> bool:
    """Used to determine if a Section is empty."""
    empty = True
    for part in content:
        if len(part.strip()):
            empty = False
            break
    return empty


def format_element(
    element: Element, keep_refs: bool = False, latex_env: bool = False
) -> List[str]:
    """
    Formats a given Element into a list of formatted strings.

    Args:
        element (Element): The element to be formatted.
        keep_refs (bool, optional): Whether to keep references in the formatting. Default is False.
        latex_env (bool, optional): Whether to use LaTeX environment formatting. Default is False.

    Returns:
        List[str]: A list of formatted strings representing the formatted element.
    """
    if isinstance(element, TextElement):
        if latex_env:
            return [latex_escape(element.content)]
        else:
            return [element.content]
    if isinstance(element, Bold):
        parts = format_children(element, keep_refs, latex_env)
        if element.find_parent(Algorithm) is not None:
            return parts
        lead, text, tail = leading_trailing_whitespace("".join(parts))
        return [*lead, "**", *remove_line_breaks(text), "**", *tail]
    if isinstance(element, Italic):
        parts = format_children(element, keep_refs, latex_env)
        if element.find_parent(Algorithm) is not None:
            return parts
        lead, text, tail = leading_trailing_whitespace("".join(parts))
        return [*lead, "_", *remove_line_breaks(text), "_", *tail]
    if isinstance(element, PlaintextMath):
        return format_children(element, keep_refs) + ["\n"]
    if isinstance(element, Paragraph):
        return format_children(element, keep_refs, latex_env) + ["\n\n"]
    if isinstance(element, TableCell):
        parts = format_children(element, keep_refs, latex_env)
        remove_trailing_whitespace(parts)
        if element.multirow is not None:
            parts.insert(0, "\\multirow{%i}{*}{" % (element.multirow))
            parts.append("}")
        if element.multicolumn is not None:
            parts.insert(
                0, "\\multicolumn{%i}{%s}{" % (element.multicolumn, element.spec)
            )
            parts.append("}")
        return parts
    if isinstance(element, TableRow):
        parts = []
        if element.hline_above:
            parts.append(element.hline_above + "\n")
        parts.extend(
            remove_line_breaks(
                format_iterator(element.cells, keep_refs, latex_env, join=" & ")
            )
        )
        parts.append(r" \\")
        parts.append((" " + element.hline_below).rstrip())
        return parts
    if isinstance(element, Tabular):
        parts = [
            "\\begin{tabular}",
            "{%s}\n" % element.get_table_spec(),
        ]
        parts.extend(format_iterator(element.rows, keep_refs, True, join="\n"))
        parts.append("\n\\end{tabular}\n")
        return parts
    if isinstance(element, Table):
        parts = [
            "[TABLE%s]\n\\begin{table}\n"
            % (str(uuid4())[:5] if element.id is None else ":" + str(element.id))
        ]
        parts.extend(format_children(element, keep_refs, latex_env))
        caption_parts = format_element(element.caption, keep_refs, latex_env)
        remove_trailing_whitespace(caption_parts)
        parts.append("\\end{table}\n")
        if len(caption_parts) > 0:
            parts.extend(caption_parts + ["\n"])
        parts.append("[ENDTABLE]\n\n")
        return parts
    if isinstance(element, Figure):
        parts = format_element(element.caption, keep_refs)
        remove_trailing_whitespace(parts)
        return (
            [
                "[FIGURE%s]\n"
                % (str(uuid4())[:5] if element.id is None else ":" + str(element.id))
            ]
            + parts
            + ["\n[ENDFIGURE]\n\n"]
        )
    if isinstance(element, SectionHeader):
        parts = ["# "]
        if element.id:
            parts.append(f"{element.id.upper()} ")
        if element.header:
            header = format_element(element.header, keep_refs)
        else:
            header = format_iterator(element.children, keep_refs)
        _, title, _ = leading_trailing_whitespace("".join(header))
        parts.append(title)
        parts.append("\n\n")
        return parts
    if isinstance(element, Section):
        children_parts = format_children(element, keep_refs)
        if is_empty(children_parts):
            return []
        if element.header:
            parts = [f"\n\n{'#'*element.hnum} "]
            _, title, _ = leading_trailing_whitespace(
                "".join(format_element(element.header, keep_refs))
            )
            parts.append(title)
            parts.append("\n\n")
        else:
            parts = []
        return parts + children_parts
    if isinstance(element, Footnote):
        if element.id is not None:
            foot = f"\n[FOOTNOTE:{element.id}]Footnote {element.id}: "
        else:
            foot = "\n[FOOTNOTE:%s]Footnote: " % (str(uuid4())[:5])
        return [foot] + format_children(element, keep_refs) + ["[ENDFOOTNOTE]\n\n"]
    if isinstance(element, ListContainer):
        items = [
            (
                item.label,
                "".join(format_element(item, keep_refs)).strip().replace("\n", " "),
            )
            for item in element.items
        ]
        parts = ["\n"]
        indent = " " * max(element.level - 1, 0)
        for i, (label, item) in enumerate(items, 1):
            if label:
                bullet = label
            else:
                bullet = f"{i}." if element.ordered else "*"
            parts.append(f"{indent}{bullet} {item}\n")
        parts.append("\n")
        return parts
    if isinstance(element, Equation):
        # an equation comprises multiple displaystyle TeX formulas and an optional equation label
        parts = []
        for child in element.children:
            if isinstance(child, LatexMath):
                tex = normalize_tex(
                    "".join(format_element(child, keep_refs)).strip(" \n"), inline=False
                )
                parts.append(tex)
            else:
                text = "".join(format_element(child, keep_refs))
                if text:
                    parts.append(text)
        lead, eqs, tail = leading_trailing_whitespace(parts)
        s = " ".join(eqs).replace(r"\] \[", " ")
        return [*lead, s, *tail]
    if isinstance(element, EquationList):
        parts = ["\n"]
        items = element.equations
        items = ["".join(format_element(item, keep_refs)).rstrip() for item in items]
        items = [item + "\n" for item in items if item]
        if items:
            parts.extend(items)
            parts.append("\n")
        return parts
    if isinstance(element, Algorithm):
        parts = []
        items = element.lines
        items = ["".join(format_element(item, keep_refs)).rstrip() for item in items]
        if element.inline:
            items = [item for item in items if item]
        else:
            items = [item + "\n" for item in items if item]
        if items:
            prepend = "`" if element.inline else "\n```\n"
            parts.append(prepend)
            parts.extend(items)
            append = "`" if element.inline else "```\n\n"
            parts.append(append)
        return parts
    if isinstance(element, DefinitionList):
        parts = ["\n"]
        if element.header is not None:
            parts.extend(format_element(element.header, keep_refs))
            parts.append("\n")
        items = [
            "".join(format_element(item, keep_refs)).rstrip() for item in element.items
        ]
        items = [item + "\n" for item in items if item]
        if items:
            parts.extend(items)
            parts.append("\n")
        return parts
    if isinstance(element, Definition):
        parts = []
        if element.term is not None:
            term = (
                "".join(format_element(element.term, keep_refs)).rstrip(" \n\t:") + ": "
            )
            # maths in wiki might be inside a definition without a term
            if term.strip() != ":":
                parts.append(term)
        if element.definition is not None:
            definition = "".join(format_element(element.definition, keep_refs)).rstrip()
            parts.append(definition)
        if parts:
            parts.append("\n")
        return parts
    if isinstance(element, LatexMath):
        parts = []
        if not element.inline:
            parts.append("\n\n")
        parts.append(normalize_tex(element.code, element.inline).strip())
        if not element.inline:
            parts.append("\n\n")
        return parts
    if isinstance(element, (Superscript, Subscript)):
        content = element.plaintext
        if content.strip().isdigit():
            script_map = (
                SUBSCRIPT_MAP if isinstance(element, Subscript) else SUPERSCRIPT_MAP
            )
            return [content.translate(script_map)]
        else:
            return format_children(element, keep_refs)
    if isinstance(element, InlineRef):
        parts = format_children(element, keep_refs)
        return parts
    return format_children(element, keep_refs, latex_env)


def format_iterator(
    iterator: Iterable,
    keep_refs: bool = False,
    latex_env: bool = False,
    join: Optional[str] = None,
) -> List[str]:
    """
    Formats the elements of an iterable, optionally joining them with a separator string.

    Args:
        iterator (Iterable): The elements to be formatted.
        keep_refs (bool, optional): Whether to keep references in the output. Defaults to False.
        latex_env (bool, optional): Whether to format the output using LaTeX syntax. Defaults to False.
        join (Optional[str]): Separator inserted between the formatted elements. If None,
            the formatted elements are returned without a separator.

    Returns:
        List[str]: A list of formatted strings.
    """
    parts = []
    for child in iterator:
        parts.extend(format_element(child, keep_refs, latex_env))
        if join is not None:
            parts.append(join)
    if join is not None:
        parts = parts[:-1]
    return parts


def format_children(
    element: Element, keep_refs: bool = False, latex_env: bool = False
) -> List[str]:
    if element is None:
        return []
    return format_iterator(element.children, keep_refs, latex_env)


def format_document(
    doc: Document, keep_refs: bool = False
) -> Tuple[str, Dict[str, str]]:
    """
    Formats a `Document` into Markdown text.

    Args:
        doc (Document): The document to format.
        keep_refs (bool, optional): Whether to keep inline references in the output. Defaults to False.

    Returns:
        Tuple[str, Dict[str, str]]: The formatted text of the document and a dictionary
            of the figures, tables, and footnotes found in the document.
    """
    parts = []

    if doc.title:
        parts.extend([*format_element(doc.title), "\n"])
    parts.append("\n")
    parts.extend(format_children(doc, keep_refs))
    text = "".join(parts)
    text = text.replace("\xa0", " ")  # replace non-breakable spaces
    text = re.sub(r" $", "", text, flags=re.MULTILINE)
    text = re.sub(r"\n[\t ]*$", "\n", text, flags=re.MULTILINE)
    text = re.sub(r"(?<!\n) {2,}", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text).lstrip()
    figures = {unidecode(m[0] + m[1]): m[2].strip() for m in figure_regex.findall(text)}
    text = figure_regex.sub(
        r"[\1\2][END\1]",
        text,
    )
    return text, figures
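Continuing the sketch from the parser above, format_document turns a parsed Document into Markdown plus a placeholder map (`doc` as returned by parse_latexml):

    from nougat.dataset.parser.markdown import format_document

    text, figures = format_document(doc, keep_refs=True)
    # `text` is Markdown in which figures, tables and footnotes are reduced to
    # [FIGURE...]/[TABLE...]/[FOOTNOTE...] marker pairs; `figures` maps each
    # marker to the caption or footnote text that was cut out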
nougat/dataset/pdffigures.py
ADDED
@@ -0,0 +1,71 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import subprocess
import logging

PDFFIGURES2_JAR_PATH = os.environ.get("PDFFIGURES_PATH", None)
logger = logging.getLogger()
if PDFFIGURES2_JAR_PATH is None:
    logger.warning(
        "You need to configure the path to the pdffigures2 executable in this file (nougat/dataset/pdffigures.py) or set the environment variable 'PDFFIGURES_PATH'."
    )


def call_pdffigures(
    pdf_path: str, figures_dir: str, timeout: int = 30, verbose: bool = False
):
    """
    Extract figures from a PDF file using pdffigures2.

    Args:
        pdf_path (str): The path to the PDF file.
        figures_dir (str): The directory where the figures will be extracted.
        timeout (int, optional): The timeout in seconds for the pdffigures2 command. Defaults to 30.
        verbose (bool, optional): Whether to print the output of the pdffigures2 command. Defaults to False.

    Returns:
        str: The path to the JSON file containing the extracted figures, False if
            extraction failed or timed out, or None if pdffigures2 is not configured.
    """
    os.makedirs(figures_dir, exist_ok=True)
    kwargs = (
        {} if verbose else {"stderr": subprocess.DEVNULL, "stdout": subprocess.DEVNULL}
    )
    if PDFFIGURES2_JAR_PATH is None:
        return
    process = subprocess.Popen(
        "java"
        " -jar {pdffigures_jar_path}"
        " -d {figures_dir}/"
        " -c"
        " -q"
        " {pdf_path}".format(
            pdffigures_jar_path=PDFFIGURES2_JAR_PATH,
            pdf_path=pdf_path,
            figures_dir=figures_dir,
        ),
        shell=True,
        **kwargs
    )

    try:
        exit_code = process.wait(timeout=timeout)
        if exit_code != 0:
            logger.error("Extracting figures from file %s failed.", pdf_path)
            return False
    except subprocess.TimeoutExpired as e:
        logger.error(
            "pdffigures2 command did not terminate in %s seconds, "
            "terminating. Error: %s",
            timeout,
            e,
        )
        process.terminate()  # give up
        return False
    pdf_name = os.path.basename(pdf_path).partition(".pdf")[0]
    dest_file = os.path.join(figures_dir, (pdf_name + ".json"))

    return dest_file
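A usage sketch for the wrapper above; the jar location is a placeholder and must be set before the module is first imported, since PDFFIGURES2_JAR_PATH is read at import time:

    import os

    os.environ["PDFFIGURES_PATH"] = "/path/to/pdffigures2.jar"  # placeholder path

    from nougat.dataset.pdffigures import call_pdffigures

    json_file = call_pdffigures("paper.pdf", "figures", timeout=60)
    # -> "figures/paper.json" on success, False on failure or timeout,
    #    None when no jar path is configured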
nougat/dataset/rasterize.py
ADDED
@@ -0,0 +1,81 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import argparse
import logging
import pypdfium2
from pathlib import Path
from tqdm import tqdm
import io
from typing import Optional, List, Union

logging.getLogger("pypdfium2").setLevel(logging.WARNING)


def rasterize_paper(
    pdf: Union[Path, bytes],
    outpath: Optional[Path] = None,
    dpi: int = 96,
    return_pil=False,
    pages=None,
) -> Optional[List[io.BytesIO]]:
    """
    Rasterize a PDF file to PNG images.

    Args:
        pdf (Path): The path to the PDF file.
        outpath (Optional[Path], optional): The output directory. If None, the PIL images will be returned instead. Defaults to None.
        dpi (int, optional): The output DPI. Defaults to 96.
        return_pil (bool, optional): Whether to return the PIL images instead of writing them to disk. Defaults to False.
        pages (Optional[List[int]], optional): The pages to rasterize. If None, all pages will be rasterized. Defaults to None.

    Returns:
        Optional[List[io.BytesIO]]: The PIL images if `return_pil` is True, otherwise None.
    """
    pils = []
    if outpath is None:
        return_pil = True
    try:
        if isinstance(pdf, (str, Path)):
            pdf = pypdfium2.PdfDocument(pdf)
        if pages is None:
            pages = range(len(pdf))
        renderer = pdf.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=pages,
            scale=dpi / 72,
        )
        for i, image in zip(pages, renderer):
            if return_pil:
                page_bytes = io.BytesIO()
                image.save(page_bytes, "bmp")
                pils.append(page_bytes)
            else:
                image.save((outpath / ("%02d.png" % (i + 1))), "png")
    except Exception as e:
        logging.error(e)
    if return_pil:
        return pils


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pdfs", nargs="+", type=Path, help="PDF files", required=True)
    parser.add_argument("--out", type=Path, help="Output dir", default=None)
    parser.add_argument(
        "--dpi", type=int, default=96, help="What resolution the pages will be saved at"
    )
    parser.add_argument(
        "--pages", type=int, nargs="+", default=None, help="List of page numbers"
    )
    args = parser.parse_args()
    if args.pages:
        args.pages = [p - 1 for p in args.pages]
    for pdf_file in tqdm(args.pdfs):
        assert pdf_file.exists() and pdf_file.is_file()
        outpath: Path = args.out or (pdf_file.parent / pdf_file.stem)
        outpath.mkdir(exist_ok=True)
        rasterize_paper(pdf_file, outpath, pages=args.pages, dpi=args.dpi)
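Both call styles of rasterize_paper above, with placeholder file names (the __main__ block exposes the same functionality as a CLI, e.g. python -m nougat.dataset.rasterize --pdfs paper.pdf --dpi 96):

    from pathlib import Path
    from nougat.dataset.rasterize import rasterize_paper

    # on-disk rendering: one PNG per page, named 01.png, 02.png, ...
    outdir = Path("pages")  # placeholder output directory
    outdir.mkdir(exist_ok=True)
    rasterize_paper(Path("paper.pdf"), outdir, dpi=96)

    # in-memory rendering: with no outpath, pages come back as BMP byte buffers
    buffers = rasterize_paper(Path("paper.pdf"), dpi=96)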
nougat/dataset/split_htmls_to_pages.py
ADDED
@@ -0,0 +1,219 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import argparse
from io import BytesIO
import multiprocessing
from pebble import ProcessPool
from concurrent.futures import TimeoutError
from tqdm import tqdm
from typing import Tuple
import os
from pathlib import Path
import logging
import pypdf
from PIL import Image
import pytesseract
from nougat.dataset.split_md_to_pages import *
from nougat.dataset.parser.html2md import *
from nougat.dataset.pdffigures import call_pdffigures

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def process_paper(
    fname: str,
    pdf_file: Path,
    html_file: Path,
    json_file: Path,
    args: argparse.Namespace,
) -> Tuple[int, int]:
    """
    Process a single paper.

    Args:
        fname (str): The paper's filename.
        pdf_file (Path): The path to the PDF file.
        html_file (Path): The path to the HTML file.
        json_file (Path): The path to the JSON file containing the extracted figures.
        args (argparse.Namespace): The command-line arguments.

    Returns:
        Tuple[int, int]: The number of total pages and the number of recognized pages.
    """
    total_pages = 0
    num_recognized_pages = 0
    try:
        pdf = pypdf.PdfReader(pdf_file)
        total_pages = len(pdf.pages)
        outpath: Path = args.out / fname
        # skip this paper if already processed
        dirs_with_same_stem = list(args.out.glob(fname.partition("v")[0] + "*"))
        if (
            len(dirs_with_same_stem) > 0
            and len(list(dirs_with_same_stem[0].iterdir())) > 0
            and not args.recompute
        ):
            logger.info(
                "%s (or another version thereof) already processed. Skipping paper",
                fname,
            )
            return total_pages, len(list(outpath.glob("*.mmd")))
        html = BeautifulSoup(
            htmlmin.minify(
                open(html_file, "r", encoding="utf-8").read().replace("\xa0", " "),
                remove_all_empty_space=True,
            ),
            features="html.parser",
        )
        doc = parse_latexml(html)
        if doc is None:
            return
        out, fig = format_document(doc, keep_refs=True)

        if args.markdown:
            md_out = args.markdown / (fname + ".mmd")
            with open(md_out, "w", encoding="utf-8") as f:
                f.write(out)

        if json_file is None:
            json_file = call_pdffigures(pdf_file, args.figure)
        if json_file:
            figure_info = json.load(open(json_file, "r", encoding="utf-8"))
        else:
            figure_info = None
        split = split_markdown(
            out, pdf_file, figure_info=figure_info, doc_fig=fig, min_score=0.9
        )
        if split is None:
            return
        pages, meta = split
        num_recognized_pages = sum([len(p) > 0 for p in pages])
        if all([len(p) == 0 for p in pages]):
            return
        os.makedirs(outpath, exist_ok=True)
        recognized_indices = []
        for i, content in enumerate(pages):
            with (outpath / "meta.json").open("w", encoding="utf-8") as f:
                f.write(json.dumps(meta))
            if content:
                if re.search(r"\[(?:\?\?(?:. )?)+\]", content):
                    # there are wrongly parsed references in the page, e.g. [??]
                    continue
                with (outpath / ("%02d.mmd" % (i + 1))).open(
                    "w", encoding="utf-8"
                ) as f:
                    f.write(content)
                recognized_indices.append(i)
        rasterize_paper(pdf_file, outpath, dpi=args.dpi, pages=recognized_indices)
        if args.tesseract:
            for i in recognized_indices:
                ocr = pytesseract.image_to_string(
                    Image.open((outpath / ("%02d.png" % (i + 1)))), lang="eng"
                )
                ocr = re.sub(r"\n+\s+?([^\s])", r"\n\n\1", ocr).strip()
                with (outpath / ("%02d_OCR.txt" % (i + 1))).open(
                    "w", encoding="utf-8"
                ) as f_ocr:
                    f_ocr.write(ocr)
    except Exception as e:
        logger.error(e)

    return total_pages, num_recognized_pages


def process_htmls(args):
    for input_dir in (args.pdfs, args.html):
        if not input_dir.exists() or not input_dir.is_dir():
            logger.error("%s does not exist or is not a directory.", input_dir)
            return
    htmls: List[Path] = args.html.glob("*.html")
    args.out.mkdir(exist_ok=True)
    if args.markdown:
        args.markdown.mkdir(exist_ok=True)

    with ProcessPool(max_workers=args.workers) as pool:
        total_pages, total_pages_extracted = 0, 0
        tasks = {}
        for j, html_file in enumerate(htmls):
            fname = html_file.stem
            pdf_file = args.pdfs / (fname + ".pdf")
            if not pdf_file.exists():
                logger.info("%s pdf could not be found.", fname)
                continue
            json_file = args.figure / (fname + ".json")
            if not json_file.exists():
                logger.info("%s figure json could not be found.", fname)
                json_file = None
            tasks[fname] = pool.schedule(
                process_paper,
                args=[fname, pdf_file, html_file, json_file, args],
                timeout=args.timeout,
            )

        for fname in tqdm(tasks):
            try:
                res = tasks[fname].result()
                if res is None:
                    logger.info("%s is faulty", fname)
                    continue
                num_pages, num_recognized_pages = res
                total_pages += num_pages
                total_pages_extracted += num_recognized_pages
                logger.info(
                    "%s: %i/%i pages recognized. Percentage: %.2f%%",
                    fname,
                    num_recognized_pages,
                    num_pages,
                    (100 * num_recognized_pages / max(1, num_pages)),
                )
            except TimeoutError:
                logger.info("%s timed out", fname)
    if total_pages > 0:
        logger.info(
            "In total: %i/%i pages recognized. Percentage: %.2f%%",
            total_pages_extracted,
            total_pages,
            (100 * total_pages_extracted / max(1, total_pages)),
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--html", type=Path, help="HTML files", required=True)
    parser.add_argument("--pdfs", type=Path, help="PDF files", required=True)
    parser.add_argument("--out", type=Path, help="Output dir", required=True)
    parser.add_argument("--recompute", action="store_true", help="recompute all splits")
    parser.add_argument(
        "--markdown", type=Path, help="Markdown output dir", default=None
    )
    parser.add_argument(
        "--figure",
        type=Path,
        help="Figure info JSON dir",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=multiprocessing.cpu_count(),
        help="How many processes to use",
    )
    parser.add_argument(
        "--dpi", type=int, default=96, help="What resolution the pages will be saved at"
    )
    parser.add_argument(
        "--timeout", type=float, default=120, help="max time per paper in seconds"
    )
    parser.add_argument(
        "--tesseract",
        action="store_true",
        help="Tesseract OCR prediction for each page",
    )
    args = parser.parse_args()
    print(args)
    process_htmls(args)
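Putting the pieces together, a typical end-to-end invocation of the pipeline above would be something like python split_htmls_to_pages.py --html htmls --pdfs pdfs --figure figures --out out --markdown markdown --tesseract, where htmls, pdfs, figures, markdown and out are placeholder directories. For every paper whose pages could be aligned, this writes out/<paper>/NN.mmd next to the matching NN.png raster, a meta.json with the split metadata, and NN_OCR.txt when --tesseract is given.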
nougat/dataset/split_md_to_pages.py
ADDED
@@ -0,0 +1,477 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
|
4 |
+
This source code is licensed under the MIT license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
import argparse
|
8 |
+
from collections import Counter
|
9 |
+
from copy import deepcopy
|
10 |
+
import json
|
11 |
+
import math
|
12 |
+
from operator import itemgetter
|
13 |
+
import re
|
14 |
+
from typing import Dict, List, Tuple, Union, Optional
|
15 |
+
import os
|
16 |
+
import pypdf
|
17 |
+
from unidecode import unidecode
|
18 |
+
import Levenshtein
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from sklearn.feature_extraction.text import CountVectorizer
|
22 |
+
from sklearn.feature_extraction.text import TfidfTransformer
|
23 |
+
from sklearn.linear_model import SGDClassifier
|
24 |
+
|
25 |
+
from nougat.dataset.staircase import Staircase
|
26 |
+
from nougat.dataset.splitter import (
|
27 |
+
Splitter,
|
28 |
+
get_first_last,
|
29 |
+
get_glob_index,
|
30 |
+
)
|
31 |
+
from nougat.dataset.utils import unicode_to_latex, remove_pretty_linebreaks
|
32 |
+
from nougat.dataset.utils.pdf_text_extract import get_pages, get_paragraphs
|
33 |
+
from nougat.dataset.rasterize import rasterize_paper
|
34 |
+
|
35 |
+
|
36 |
+
class BagOfWords:
|
37 |
+
"""
|
38 |
+
A bag-of-words model for text classification.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
sentences (List[str]): The training sentences.
|
42 |
+
target (Optional[List[int]]): The target labels for the training sentences. Defaults to None.
|
43 |
+
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
sentences: List[str],
|
49 |
+
target: Optional[List[int]] = None,
|
50 |
+
) -> None:
|
51 |
+
self.sentences = sentences
|
52 |
+
self.target = target
|
53 |
+
self.train()
|
54 |
+
|
55 |
+
def train(self):
|
56 |
+
if self.target is None:
|
57 |
+
self.target = np.arange(len(self.sentences))
|
58 |
+
self.count_vect = CountVectorizer()
|
59 |
+
X_train_counts = self.count_vect.fit_transform(self.sentences)
|
60 |
+
self.tfidf_transformer = TfidfTransformer(use_idf=True)
|
61 |
+
X_train_tfidf = self.tfidf_transformer.fit_transform(X_train_counts)
|
62 |
+
self.clf = SGDClassifier(
|
63 |
+
loss="hinge",
|
64 |
+
penalty="l2",
|
65 |
+
alpha=1e-3,
|
66 |
+
random_state=42,
|
67 |
+
max_iter=5,
|
68 |
+
tol=None,
|
69 |
+
)
|
70 |
+
self.clf.fit(X_train_tfidf, self.target)
|
71 |
+
|
72 |
+
def __call__(
|
73 |
+
self, text: Union[str, List[str]], lob_probs: bool = False
|
74 |
+
) -> np.ndarray:
|
75 |
+
if type(text) == str:
|
76 |
+
text = [text]
|
77 |
+
X_new_counts = self.count_vect.transform(text)
|
78 |
+
X_new_tfidf = self.tfidf_transformer.transform(X_new_counts)
|
79 |
+
if lob_probs:
|
80 |
+
return self.clf.predict_log_proba(X_new_tfidf)
|
81 |
+
else:
|
82 |
+
return self.clf.predict(X_new_tfidf)
|
83 |
+
|
84 |
+
|
85 |
+
def remove_short_seqs(seqs: List[str], minimum: int = 10) -> List[str]:
|
86 |
+
"""Remove sequences shorter than the specified minimum length."""
|
87 |
+
out = []
|
88 |
+
for seq in seqs:
|
89 |
+
if len(seq) > minimum:
|
90 |
+
out.append(seq)
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
def find_figures(
|
95 |
+
pdf_pages: List[List[str]], figure_info: Union[Dict, List]
|
96 |
+
) -> List[Tuple[int, int]]:
|
97 |
+
""" "
|
98 |
+
Find the locations of figures in a PDF file.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
pdf_pages (List[List[str]]): The text of the PDF pages.
|
102 |
+
figure_info (Union[Dict, List]): A dictionary or list of dictionaries, where each dictionary
|
103 |
+
specifies the information about a figure, such as its caption, page number, and bounding box.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
List[Tuple[int, int]]: A list of tuples, where each tuple contains the figure index, page number,
|
107 |
+
start position, and end position of the figure in the PDF file.
|
108 |
+
"""
|
109 |
+
figure_locations = []
|
110 |
+
iterator = figure_info.values() if type(figure_info) == dict else [figure_info]
|
111 |
+
for figure_list in iterator:
|
112 |
+
for i, f in enumerate(figure_list):
|
113 |
+
if "caption" in f:
|
114 |
+
fig_string = f["caption"]
|
115 |
+
elif "text" in f:
|
116 |
+
fig_string = f["text"]
|
117 |
+
else:
|
118 |
+
continue
|
119 |
+
fig_string = unicode_to_latex(fig_string)
|
120 |
+
if f["page"] >= len(pdf_pages):
|
121 |
+
continue
|
122 |
+
block, score = Splitter.fuzzysearch(
|
123 |
+
"\n".join(pdf_pages[f["page"]]),
|
124 |
+
fig_string,
|
125 |
+
)
|
126 |
+
if score > 0.8 and block[2] > 0:
|
127 |
+
figure_locations.append((i, f["page"], block[0], block[2]))
|
128 |
+
return figure_locations
|
129 |
+
|
130 |
+
|

def flatten(l: List) -> List:
    return [item for sublist in l for item in sublist]


def get_doc_text(
    pdf: str,
    splitn: bool = True,
    split_block: bool = True,
    minlen: Optional[int] = 10,
) -> List[List[str]]:
    """
    Get the text from a PDF document.

    Args:
        pdf (str): Path to the PDF document.
        splitn (bool): Whether to split the text into lines. Defaults to True.
        split_block (bool): Whether to split the text into blocks. Defaults to True.
        minlen (Optional[int]): The minimum length of a line or block. Defaults to 10.

    Returns:
        List[List[str]]: The text of the PDF document, either as a list of lines or a list of blocks.
    """
    document_lines = []
    if split_block:
        pages = get_paragraphs(pdf)
    else:
        pages = [get_pages(pdf)]
    for blocks in pages:
        page_lines = []
        for block in blocks:
            if splitn:
                page_lines.extend(block.split("\n"))
            else:
                page_lines.append(block)
        if splitn:
            page_lines = remove_short_seqs(page_lines, minlen)
        document_lines.append(page_lines)
    return document_lines

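
# Usage sketch (hypothetical file name):
#
#     pdf_lines = get_doc_text("paper.pdf", splitn=True, split_block=True)
#     # pdf_lines[0] is the list of text lines found on the first page,
#     # with lines of 10 characters or fewer already filtered out
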

def clean_pdf_text(pages: List[List[str]], num_words: int = 10) -> List[List[str]]:
    """
    Clean the text of a PDF document by removing frequent words from the beginning and end of each page.

    Args:
        pages (List[List[str]]): The text of the PDF document, as a list of lists of strings.
        num_words (int, optional): The number of words to consider at the beginning and end of each page. Defaults to 10.

    Returns:
        List[List[str]]: The cleaned text of the PDF document.
    """
    words = []
    for page in pages:
        first = get_first_last(
            " ".join(page).lower(), num_words=num_words, first_only=True
        )
        words.extend(first.split(" "))
    word_counts = Counter(words)
    common_words = [
        "the",
        "of",
        "a",
        "and",
        "to",
        "in",
        "is",
        "that",
        "for",
        "are",
        "this",
        "we",
        "figure",
        "fig.",
        "",
    ]
    frequent_words = []
    for w, f in word_counts.items():
        if w in common_words or w.startswith("\\"):
            continue
        if f / len(pages) >= 0.4:
            frequent_words.append(w)
    if len(frequent_words) == 0:
        return pages
    # remove frequent words from page beginning/end
    for i in range(len(pages)):
        page = pages[i]
        stop = 0
        page_num_words = 0
        for p in page:
            page_num_words += len(p.split(" "))
            stop += 1
            if page_num_words >= num_words:
                break
        for w in frequent_words:
            for j in range(stop):
                if w == "-":  # probably a page number like "- 12 -"
                    pages[i][j] = re.sub(
                        r"-\s*\d{1,3}\s*-", "", pages[i][j], flags=re.IGNORECASE
                    )
                pages[i][j] = re.sub(re.escape(w), "", pages[i][j], flags=re.IGNORECASE)
    return pages


def split_markdown(
    doc: str,
    pdf_file: str,
    figure_info: Optional[List[Dict]] = None,
    doc_fig: Dict[str, str] = {},
    minlen: int = 3,
    min_num_words: int = 22,
    doc_paragraph_chars: int = 1000,
    min_score: float = 0.75,
    staircase: bool = True,
) -> Tuple[List[str], Dict]:
    """
    Split a Markdown document into pages matching the pages of the source PDF.

    Args:
        doc (str): The text of the Markdown document.
        pdf_file (str): Path to the PDF document.
        figure_info (Optional[List[Dict]]): A list of dictionaries, where each dictionary
            specifies the information about a figure, such as its caption, page number, and bounding box.
        doc_fig (Dict[str, str]): A dictionary mapping figure ids to LaTeX code.
        minlen (int): The minimum length of a Markdown paragraph.
        min_num_words (int): The minimum number of words to match at each page boundary.
        doc_paragraph_chars (int): The maximum number of characters in a Markdown paragraph.
        min_score (float): The minimum confidence score required to emit a page.
        staircase (bool): Whether to fit a monotone staircase function to the predicted page labels.

    Returns:
        Tuple[List[str], Dict]: The list of Markdown pages and the metadata.
    """
    pdf = pypdf.PdfReader(pdf_file)
    doc_paragraphs_full: List[str] = doc.split("\n")
    doc_paragraph_lengths = [len(p) for p in doc_paragraphs_full if len(p) > 1]
    num_lines = 1 + int(doc_paragraph_chars / np.mean(doc_paragraph_lengths))
    doc_paragraphs_full = [
        unidecode("\n".join(doc_paragraphs_full[i : i + num_lines]))
        for i in range(0, len(doc_paragraphs_full), num_lines)
    ]
    doc_paragraphs: List[str] = []
    doc_paragraph_indices: List[int] = []
    for i, p in enumerate(doc_paragraphs_full):
        if len(p) > 1:
            doc_paragraphs.append(
                re.sub(r"(\[(FOOTNOTE|FIGURE|TABLE).*?END\2\])", "", p)
            )
            doc_paragraph_indices.append(i)
    meta = {"pdffigures": figure_info}
    if len(pdf.pages) > 1:
        pdf_text = get_doc_text(pdf_file, True, True, minlen)
        pdf_content = [
            [unicode_to_latex(q).replace("\n", " ") for q in p if len(q) >= minlen]
            for p in pdf_text
        ]

        pdf_content = clean_pdf_text(pdf_content)
        if figure_info is not None:
            figure_locations = sorted(
                find_figures(pdf_content, figure_info), key=itemgetter(2), reverse=True
            )
            clean_pdf_content = deepcopy(pdf_content)
            for i, page_content in enumerate(pdf_content):
                len_sentences = np.cumsum([0] + [len(p) for p in page_content])
                for match in figure_locations:
                    _, page, start, len_ = match
                    if i != page:
                        continue
                    a, b = (
                        get_glob_index(len_sentences, start),
                        get_glob_index(len_sentences, start + len_) + 1,
                    )
                    for j, k in enumerate(range(a, b + 1)):
                        if len(clean_pdf_content[i]) == k:
                            break
                        if j == 0:
                            clean_pdf_content[i][k] = clean_pdf_content[i][k][
                                : start - len_sentences[k]
                            ]
                        elif k == b:
                            clean_pdf_content[i][k] = clean_pdf_content[i][k][
                                start + len_ - len_sentences[k] :
                            ]
                        else:
                            clean_pdf_content[i][k] = ""
                clean_pdf_content[i] = remove_short_seqs(clean_pdf_content[i], 0)
            pdf_content = clean_pdf_content
        paragraphs = flatten(pdf_content)
        num_paragraphs = np.cumsum([0] + [len(page) for page in pdf_content])
        if staircase:
            # train bag of words
            page_target = np.zeros(len(paragraphs))
            page_target[num_paragraphs[1:-1] - 1] = 1
            page_target = np.cumsum(page_target).astype(int)
            model = BagOfWords(paragraphs, target=page_target)
            labels = model(doc_paragraphs)

            # fit staircase function
            x = np.arange(len(labels))
            stairs = Staircase(len(labels), labels.max() + 1)
            stairs.fit(x, labels)
            boundaries = (stairs.get_boundaries().astype(int)).tolist()
            boundaries.insert(0, 0)
        else:
            boundaries = [0] * (len(pdf.pages))
        splitter = Splitter(doc_paragraphs)
        pages = [(0, 0, 1.0)]
        meta["first_words"] = []
        meta["last_words"] = []
        for i in range(1, len(boundaries)):
            delta = (
                math.ceil(stairs.uncertainty[i - 1]) + 5
                if staircase
                else len(doc_paragraphs)
            )
            words_f = []
            words_l = []
            for p in pdf_content[i]:
                words_f.extend(p.split(" "))
                if len(words_f) >= min_num_words:
                    break
            for p in pdf_content[i - 1][::-1]:
                words_l.extend(p.split(" ")[::-1])
                if len(words_l) >= min_num_words:
                    words_l = words_l[::-1]
                    break
            first_words = " ".join(words_f[:min_num_words]).strip()
            last_words = " ".join(words_l[-min_num_words:]).strip()
            meta["first_words"].append(first_words)
            meta["last_words"].append(last_words)
            if len(words_f) < 2 or (
                len(first_words) < minlen and len(last_words) < minlen
            ):
                # not enough text to refine this boundary: keep the previous
                # split (appending only once keeps `pages` aligned with the
                # boundaries)
                pages.append(pages[-1])
                continue
            pages.append(
                splitter.split_first_last(
                    boundaries[i],
                    first_words,
                    last_words,
                    delta=delta,
                )
            )
    elif len(pdf.pages) == 1:  # single page
        pages = [(0, 0, 1)]
    else:  # empty PDF
        return
    pages.append((len(doc_paragraphs), -1, 1.0))
    out = []
    page_scores = {}
    for i in range(len(pages) - 1):
        score = (pages[i][2] + pages[i + 1][2]) * 0.5
        if score >= min_score:
            end = pages[i + 1][0]
            if end >= len(doc_paragraph_indices):
                end = None
            else:
                end = doc_paragraph_indices[pages[i + 1][0]] + 1
            lines = doc_paragraphs_full[doc_paragraph_indices[pages[i][0]] : end]
            if len(lines) > 0:
                lines[0] = lines[0][pages[i][1] :]
                lines[-1] = lines[-1][: pages[i + 1][1]]
        else:
            lines = []
        page_content = "\n".join(lines)
        page_content = remove_pretty_linebreaks(page_content)
        page_scores[i] = score
        out.append(page_content)

    meta["page_splits"] = pages
    meta["page_scores"] = page_scores
    meta["num_pages"] = len(pdf.pages)

    # Reintroduce figures, tables and footnotes
    figure_tex = list(doc_fig.keys()), list(doc_fig.values())
    if len(doc_fig) > 0:
        iterator = figure_info.values() if type(figure_info) == dict else [figure_info]
        for figure_list in iterator:
            if not figure_list:
                continue
            for i, f in enumerate(figure_list):
                if "caption" in f:
                    fig_string = f["caption"]
                elif "text" in f:
                    fig_string = f["text"]
                else:
                    continue
                ratios = []
                for tex in figure_tex[1]:
                    if f["figType"] == "Table":
                        tex = tex.partition(r"\end{table}")[2]
                    ratios.append(Levenshtein.ratio(tex, fig_string))
                k = np.argmax(ratios)
                if ratios[k] < 0.8:
                    continue
                if f["page"] < len(out) and out[f["page"]] != "":
                    out[f["page"]] += "\n\n" + remove_pretty_linebreaks(
                        figure_tex[1][k].strip()
                    )

    for i in range(len(out)):
        foot_match = re.findall(r"\[FOOTNOTE(.*?)\]\[ENDFOOTNOTE\]", out[i])
        for match in foot_match:
            out[i] = out[i].replace(
                "[FOOTNOTE%s][ENDFOOTNOTE]" % match,
                doc_fig.get("FOOTNOTE%s" % match, ""),
            )

        out[i] = re.sub(r"\[(FIGURE|TABLE)(.*?)\](.*?)\[END\1\]", "", out[i])
    return out, meta

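
# Usage sketch (hypothetical inputs):
#
#     pages, meta = split_markdown(mmd_text, "paper.pdf", figure_info=fig_info)
#     # `pages` holds one Markdown string per PDF page (empty if the match
#     # score fell below `min_score`); `meta` contains "pdffigures",
#     # "page_splits", "page_scores", "num_pages" and, for multi-page PDFs,
#     # "first_words"/"last_words" diagnostics.
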

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--md", type=str, help="Markdown file", required=True)
    parser.add_argument("--pdf", type=str, help="PDF File", required=True)
    parser.add_argument("--out", type=str, help="Out dir", required=True)
    parser.add_argument(
        "--figure",
        type=str,
        help="Figure info JSON",
    )
    parser.add_argument("--dpi", type=int, default=96)
    args = parser.parse_args()
    md = open(args.md, "r", encoding="utf-8").read().replace("\xa0", " ")
    try:
        fig_info = json.load(open(args.figure, "r", encoding="utf-8"))
    except (FileNotFoundError, TypeError):  # missing file or --figure not given
        fig_info = None
    # pass the PDF path itself: split_markdown and rasterize_paper open the
    # file themselves, so handing them a pypdf.PdfReader object would fail
    pages, meta = split_markdown(md, args.pdf, fig_info)
    if args.out:
        outpath = os.path.join(args.out, os.path.basename(args.pdf).partition(".")[0])
        os.makedirs(outpath, exist_ok=True)
        found_pages = []
        for i, content in enumerate(pages):
            if content:
                with open(
                    os.path.join(
                        outpath, "%02d_s=%.2f.mmd" % (i + 1, meta["page_scores"][i])
                    ),
                    "w",
                    encoding="utf-8",
                ) as f:
                    f.write(content)
                found_pages.append(i)
        rasterize_paper(args.pdf, outpath, dpi=args.dpi, pages=found_pages)
nougat/dataset/splitter.py
ADDED
@@ -0,0 +1,393 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from difflib import SequenceMatcher
from operator import itemgetter
from typing import List, Tuple, Union
import re
import numpy as np
from Levenshtein.StringMatcher import StringMatcher
import Levenshtein
from fuzzysearch import find_near_matches

math_start_regex = re.compile(r"(?<!\\)\\[\[\(]", re.M)
math_end_regex = re.compile(r"(?<!\\)\\[\]\)]", re.M)


def reverse(lst: List[str]) -> List[str]:
    """Reverses a list and the strings inside

    Args:
        lst (List[str]): List to process

    Returns:
        List[str]: Reversed list
    """
    out = lst[::-1]
    for i in range(len(out)):
        out[i] = out[i][::-1]
    return out


def get_first_last(
    s: str,
    num_words: int = 8,
    delim: str = " ",
    first_only: bool = False,
    last_only: bool = False,
) -> Union[Tuple[str, str], str]:
    """
    Get the first and last `num_words` from a string `s`.

    Args:
        s (str): The string.
        num_words (int): The number of words.
        delim (str): The delimiter between words.
        first_only (bool): Whether to only get the first `num_words`.
        last_only (bool): Whether to only get the last `num_words`.

    Returns:
        Union[Tuple[str, str], str]: The first and last `num_words` of `s`, or
            only one of them if `first_only` or `last_only` is set.
    """
    s = s.split(delim)
    if not first_only and not last_only:
        return delim.join(s[:num_words]), delim.join(s[-num_words:])
    elif first_only:
        return delim.join(s[:num_words])
    elif last_only:
        return delim.join(s[-num_words:])


def get_glob_index(
    lengths: List[int], ind: int, return_breakpoints: bool = False
) -> int:
    """Returns the index of the largest cumulative length that does not exceed `ind`
    (and optionally the cumulative breakpoints themselves)."""
    breakpoints = np.cumsum(lengths)
    overlap = breakpoints - ind
    overlap[overlap > 0] = -int(1e5)
    indices = overlap.argmax(0)
    if return_breakpoints:
        return indices, breakpoints
    else:
        return indices

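
# Worked example: with the padded lengths [0, 5, 3, 7] the cumulative
# breakpoints are [0, 5, 8, 15]; the flat character index 6 lies in the
# second paragraph at local offset 6 - 5 = 1:
#
#     indices, breakpoints = get_glob_index([0, 5, 3, 7], 6, True)
#     # indices == 1, breakpoints == array([ 0,  5,  8, 15])
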

# table-header-figure regex
# thf_regex = re.compile(r"(\[(FOOTNOTE|FIGURE|TABLE).*?END\2\])")


class Splitter:
    _split_locs: List[Tuple[int, int]] = None

    def __init__(self, paragraphs: List[str]) -> None:
        self.paragraphs = paragraphs
        self.paragraphs_no_space = [self.remove_special_chars(h) for h in paragraphs]
        self._split_locs = [(0, 0)]
        self.paragraphs_rev = reverse(self.paragraphs)
        self.paragraphs_rev_no_space = reverse(self.paragraphs_no_space)

    @staticmethod
    def remove_special_chars(string: str) -> str:
        # string = thf_regex.sub(r"", string)
        return (
            string.replace("\\ ", "")
            .replace(" ", "")
            .replace("\n", "")
            .replace("*", "")
            .replace("_", "")
            .replace("^", "")
            .replace("\\[", "")
            .replace("\\]", "")
            .replace("\\(", "")
            .replace("\\)", "")
            .replace("\\right", "")
            .replace("\\left", "")
            .replace("\\sum", "X")  # old latex unicode encoding issue
            .replace("{", "")
            .replace("}", "")
            .replace("#", "")
            .replace("[REF]", "")
            .replace("[ENDREF]", "")
            .replace("\\varphi", "\\phi")  # https://meta.stackexchange.com/a/349360
            .replace("\\quad", "")
            .replace("\\qquad", "")
            .replace("\\hskip", "")
            .replace("\\vskip", "")
            .replace("\\frac", "")
            .replace("\\rm", "")
            .replace("\\,", "")
            .replace("-", "")
            .lower()
        )

    @staticmethod
    def count_special_chars(string: str, char_ind: int) -> int:
        if len(string) == 0:
            return 0
        # fixed-point iteration: count the markup characters that were removed
        # before `char_ind` until the count stops changing
        add_space_ind = 0
        while True:
            string_ = string[: char_ind + add_space_ind]
            add = (
                string_.count(" ")
                + string_.count("\\ ") * 2
                + string_.count("\n")
                + string_.count("*")
                + string_.count("_")
                + string_.count("^")
                + string_.count("\\[") * 2
                + string_.count("\\]") * 2
                + string_.count("\\(") * 2
                + string_.count("\\)") * 2
                + string_.count("\\right") * 6
                + string_.count("\\left") * 5
                + string_.count("\\sum") * 3  # replaced by "X", that's why not 4
                + string_.count("{")
                + string_.count("}")
                + string_.count("#")
                + string_.count("[REF]") * 5
                + string_.count("[ENDREF]") * 8
                + string_.count("\\varphi") * 3
                + string_.count("\\quad") * 5
                + string_.count("\\qquad") * 6
                + string_.count("\\hskip") * 6
                + string_.count("\\vskip") * 6
                + string_.count("\\frac") * 5
                + string_.count("\\rm") * 3
                + string_.count("\\,") * 2
                + string_.count("-")
            )
            if add == add_space_ind:
                break
            add_space_ind = add
        if len(string) <= char_ind + add_space_ind:
            add_space_ind = max(0, len(string) - 1 - char_ind)

        # check first chars of rest if they match closing expressions
        while True:
            rest = string[char_ind + add_space_ind :]
            string_ = string[: char_ind + add_space_ind]
            section_title = re.match(r"#+\s?\d*\s*$", string_)
            if rest.startswith("\\]") or rest.startswith("\\)"):
                add_space_ind += 2
            elif (rest.startswith(")") or rest.startswith("]")) and string_.endswith(
                "\\"
            ):
                add_space_ind += 1
            elif (rest.startswith("(") or rest.startswith("[")) and string_.endswith(
                "\\"
            ):
                add_space_ind -= 1
            elif rest.startswith(" "):
                add_space_ind += 1
            elif section_title:
                add_space_ind -= section_title.end() - section_title.start()
            elif (
                re.match(r"^[^\w\s]*_\s", rest)
                or re.match(r"^[^\w\s]*\*\*?\s", rest)
                or re.match(r"^.\n", rest)
            ):
                add_space_ind += 1
            else:
                break
        # check if it starts in a math env and include everything before
        end = math_end_regex.search(rest)
        if end is not None:
            start = math_start_regex.search(rest)
            if start is None or start.start() > end.start():
                inds = [
                    m.start()
                    for m in math_start_regex.finditer(string_)
                    if m.start() < end.start() + len(string_)
                ]
                if len(inds) > 0:
                    add_space_ind = inds[-1] - char_ind
        return add_space_ind

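    # Illustration of the two helpers above: `remove_special_chars` collapses
    # markup so the fuzzy search operates on bare characters, and
    # `count_special_chars` maps an index in the collapsed string back to the
    # original one. For p = "the **sum** is \(x\)":
    #
    #     Splitter.remove_special_chars(p)    # "thesumisx"
    #     Splitter.count_special_chars(p, 3)  # 1, so p[3 + 1 :] == "**sum** is \(x\)"
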
    def split_first_last(
        self, index: int, first: str, last: str, delta: int = 5
    ) -> Tuple[int, int, float]:
        """Refines a split by looking at both the first words from a new page and the last words from the previous page.

        Args:
            index (int): paragraph index
            first (str): first words
            last (str): last words
            delta (int, optional): paragraph search radius. Defaults to 5.

        Returns:
            Tuple[int, int, float]: split prediction
        """
        if first:
            first_split = glob_f, char_f, score_f = self.split(
                index, first, delta=delta
            )
        if last:
            last_split = glob_l, char_l, score_l = self.split(
                index, last, delta=delta, reverse=True
            )
        if first and not last:
            return first_split
        elif not first and last:
            return last_split
        elif not first and not last:
            return index, 0, 0.0
        if char_f == char_l and glob_f == glob_l and (score_f > 0.5 or score_l > 0.5):
            return glob_l, char_l, 1.0

        # score calculation
        first, last = self.remove_special_chars(first), self.remove_special_chars(last)
        matching = []
        for split in (first_split, last_split):
            first_source = []
            num_chars_first = len(first)
            num_chars_last = len(last)
            last_source = []
            for i, p in enumerate(self.paragraphs[split[0] :]):
                if i == 0:
                    p = p[split[1] :]
                first_source.append(self.remove_special_chars(p))
                if sum([len(s) for s in first_source]) >= num_chars_first:
                    break
            first_source = "".join(first_source)[:num_chars_first]
            for i, p in enumerate(self.paragraphs[split[0] :: -1]):
                if i == 0:
                    p = p[: split[1]]
                last_source.insert(0, self.remove_special_chars(p))
                if sum([len(s) for s in last_source]) >= num_chars_last:
                    break
            last_source = "".join(last_source)[-num_chars_last:]
            matching.append(
                [
                    Levenshtein.ratio(first, first_source)
                    * Levenshtein.ratio(first[:10], first_source[:10]),
                    Levenshtein.ratio(last, last_source)
                    * Levenshtein.ratio(last[-10:], last_source[-10:]),
                ]
            )
        scores = np.asarray(matching).max(0)
        return (
            (glob_l, char_l, scores[1])
            if scores.argmax()
            else (glob_f, char_f, scores[0])
        )

    def split(
        self, index: int, string: str, delta: int = 5, reverse: bool = False
    ) -> Tuple[int, int, float]:
        """
        Refine a split prediction. `string` contains the first words of the new page;
        `delta` acts as the search radius (an uncertainty measure).
        Returns the new paragraph index, the character index of the split, and a score.
        """
        if reverse:
            index = len(self.paragraphs) - 1 - index
            string = string[::-1]
            paragraphs = self.paragraphs_rev
            paragraphs_no_space = self.paragraphs_rev_no_space
        else:
            paragraphs = self.paragraphs
            paragraphs_no_space = self.paragraphs_no_space

        string_ = self.remove_special_chars(string)
        start_ind = max(0, index - delta)
        search_corpus = paragraphs_no_space[start_ind : index + delta + 1]
        lengths = np.asarray([0] + [len(p) for p in search_corpus])
        corp = "".join(search_corpus)
        if len(corp) == 0:
            self._split_locs.append((index, 0))
            return index, 0, 1
        ind, score = self._find_match(corp, string_)
        indices, breakpoints = get_glob_index(lengths, ind, True)
        global_ind, char_ind = int(start_ind + indices), int(ind - breakpoints[indices])
        self._split_locs.append((global_ind, char_ind))
        if reverse:
            char_ind = len(paragraphs_no_space[global_ind]) - char_ind
            global_ind = len(paragraphs) - global_ind - 1
        add_space_ind = self.count_special_chars(self.paragraphs[global_ind], char_ind)
        return global_ind, char_ind + add_space_ind, score

    def _find_match(
        self, corp: str, key: str, get_start: bool = True
    ) -> Tuple[int, float]:
        block, score = self._fuzzy(corp, key)
        index = max(0, block[0])
        if not get_start:
            index += block[2]
        return index, score

    @staticmethod
    def _fuzzy(
        corpus: str, string: str, max_error_rate: float = 0.025
    ) -> Tuple[Tuple[int, int, int], float]:
        max_dist = min(len(string) - 1, int(len(string) * min(0.9, max_error_rate)) + 5)
        matches = find_near_matches(string, corpus, max_l_dist=max_dist)
        if len(matches) > 0 and max_dist > 0:
            match = min(matches, key=lambda x: x.dist)
            block = (match.start, 0, match.end - match.start)
            score = 1 - match.dist / max_dist
            return block, score
        return (0, 0, 0), 0

    @staticmethod
    def fuzzysearch(
        corpus: str, string: str, max_error_rate: float = 0.025
    ) -> Tuple[Tuple[int, int, int], float]:
        corpus_ = Splitter.remove_special_chars(corpus)
        string_ = Splitter.remove_special_chars(string)
        (start, _, dist), score = Splitter._fuzzy(
            corpus_, string_, max_error_rate=max_error_rate
        )
        end = Splitter.count_special_chars(corpus, start + dist) + start + dist
        start = start + Splitter.count_special_chars(corpus, start)
        return (start, _, end - start), score

    @staticmethod
    def oldfuzz(corpus, string):
        res = []
        for Matcher in [StringMatcher, SequenceMatcher]:
            m = Matcher(None, corpus, string)
            blocks = m.get_matching_blocks()
            scores = []
            for i, block in enumerate(blocks):
                m2 = Matcher(
                    None,
                    corpus[block[0] : block[0] + max(block[2], len(string))],
                    string,
                )
                r = m2.ratio()
                if r > 0.995:
                    return blocks[i], r
                else:
                    scores.append(r)
            ind = np.argmax(scores)
            res.append((blocks[ind], scores[ind]))
        return max(res, key=itemgetter(1))

    def evaluate_split(self, page_num: int, page_content: str) -> float:
        if page_num > len(self._split_locs) or page_num < 1:
            return 0
        page_content = self.remove_special_chars(page_content)
        if page_num == len(self._split_locs):
            start, end = self._split_locs[-1], (-1, -1)
        else:
            start, end = self._split_locs[page_num - 1], self._split_locs[page_num]
        if (end[0] + 1) - start[0] < 0:
            return 0
        doc_content = self.paragraphs_no_space[start[0] : (end[0] + 1) or None]
        if (
            len(doc_content) < 1
            or len(doc_content[0]) < start[1]
            or len(doc_content[-1]) < end[1]
        ):
            return 0
        doc_content[0] = doc_content[0][start[1] :]
        doc_content[-1] = doc_content[-1][: end[1]]
        doc_content = "".join(doc_content)
        match = StringMatcher(None, page_content, doc_content).ratio()
        return match
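
A quick, self-contained sketch of the fuzzy matcher above (the strings are invented; it assumes the module is importable as below and that the `fuzzysearch` and `Levenshtein` dependencies are installed):

from nougat.dataset.splitter import Splitter

corpus = "Results are shown in **Figure 3** for the test set."
query = "shown in Figure 3"
(start, _, length), score = Splitter.fuzzysearch(corpus, query)
print(repr(corpus[start : start + length]), score)
# roughly: 'shown in **Figure 3** ' 1.0  (markup is ignored during matching)
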
nougat/dataset/staircase.py
ADDED
@@ -0,0 +1,314 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from collections import deque
import operator
import itertools
from typing import Optional, List, Tuple
import numpy as np
import warnings

warnings.filterwarnings("ignore", message="All-NaN slice encountered")


def stair_func(x: np.ndarray, thresholds: np.ndarray) -> np.ndarray:
    return np.heaviside(x[:, None] - np.floor(thresholds)[None, :], 0).sum(1)


def compute_gini(labels: np.ndarray) -> float:
    N = len(labels)
    if N == 0:
        return 0
    G = N - np.square(np.bincount(labels)).sum() / N
    return G

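
# Example: stair_func counts how many thresholds each point has passed,
# e.g. stair_func(np.arange(6), np.array([2, 4])) == [0, 0, 0, 1, 1, 2].
# compute_gini is N * (1 - sum_c p_c**2): for labels [0, 0, 1, 1] it gives
# 4 - (2**2 + 2**2) / 4 = 2.0.
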

def compute_binary_gini(labels: np.ndarray) -> float:
    N = len(labels)
    if N == 0:
        return 0
    G = N - labels.sum() ** 2 / N
    return G


def gini_impurity(
    thresholds: np.ndarray,
    data: np.ndarray,
    labels: np.ndarray,
    classes: Optional[List[int]] = None,
    reduction: Optional[str] = "sum",
    padded: bool = True,
) -> float:
    """
    Calculate the Gini impurity of a dataset split on a set of thresholds.

    Args:
        thresholds (np.ndarray): The thresholds to split the data on.
        data (np.ndarray): The data to split.
        labels (np.ndarray): The labels for the data.
        classes (Optional[List[int]]): The classes to consider. If None, all classes are used.
        reduction (Optional[str]): The reduction to apply to the impurity. One of "none", "sum", or "mean".
        padded (bool): Whether the thresholds are already padded with `[-0.5, data.max() + 0.5]`.

    Returns:
        float: The Gini impurity (a list of per-class impurities if reduction is "none").
    """
    G = []
    if not padded:
        thresholds = np.insert(
            thresholds, [0, len(thresholds)], [-0.5, data.max() + 0.5]
        )
    if classes is None:
        classes = np.arange(len(thresholds) - 1)
    else:
        classes = np.asarray(classes)
    if data.ndim == 1:
        data = np.expand_dims(data, 0)
    masks = np.logical_and(
        data > thresholds[classes, None],
        data <= thresholds[classes + 1, None],
    )
    for i, c in enumerate(classes):
        G.append(compute_binary_gini(np.where(labels[masks[i]] == c, 1, 0)))

    if reduction is None or reduction == "none":
        return G
    elif reduction == "sum":
        return sum(G)
    elif reduction == "mean":
        return sum(G) / len(G)
    else:
        raise NotImplementedError


def step_impurity(
    thresholds: np.ndarray,
    data: np.ndarray,
    labels: np.ndarray,
    classes: Optional[List[int]] = None,
) -> List[float]:
    """
    Calculate the step-wise Gini impurity of a dataset split on a set of thresholds.

    Args:
        thresholds (np.ndarray): The thresholds to split the data on.
        data (np.ndarray): The data to split.
        labels (np.ndarray): The labels for the data.
        classes (Optional[List[int]]): The classes to consider. If None, all classes are used.

    Returns:
        List[float]: The pair-wise sums of neighboring per-class Gini impurities.
    """
    G = gini_impurity(thresholds, data, labels, reduction=None, classes=classes)
    out = []
    for i in range(len(G) - 1):
        out.append(G[i] + G[i + 1])
    return out

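
# Example: with padded thresholds [-0.5, 1.5, 3.5] over data [0, 1, 2, 3]
# and labels [0, 0, 1, 1], each bin is pure, so
#
#     gini_impurity(np.array([-0.5, 1.5, 3.5]), np.arange(4),
#                   np.array([0, 0, 1, 1]))  # -> 0.0
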

class PaddedArray:
    """
    A wrapper class for an array that allows for relative indexing.

    Args:
        array (np.ndarray): The array to wrap.
        range (Optional[Tuple[int, int]]): The range of the array to expose. Defaults to (1, -1).
    """

    def __init__(
        self, array: np.ndarray, range: Optional[Tuple[int, int]] = (1, -1)
    ) -> None:
        self.array = array
        mi, ma = range
        assert ma <= 0, "relative assignment only"
        self.range = mi, ma

    def __len__(self):
        return len(self.array) + self.range[1] - self.range[0]

    def _process_index(self, index):
        if isinstance(index, slice):
            index = slice(
                (index.start or 0) + self.range[0],
                self.range[0] + (len(self) if index.stop is None else index.stop),
                index.step,
            )
            if index.stop > len(self.array):
                raise IndexError
        else:
            index = index + self.range[0]
            if index > len(self):
                raise IndexError
        return index

    def __getitem__(self, index):
        index = self._process_index(index)
        return self.array[index]

    def __setitem__(self, index, value):
        self.array[self._process_index(index)] = value

    def copy(self):
        return PaddedArray(self.array.copy(), self.range)

    def toarray(self):
        return self.array[self.range[0] : self.range[1]]

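
# Example: the wrapper hides the padding entries, so index 0 addresses the
# first real threshold:
#
#     pa = PaddedArray(np.array([-0.5, 3.0, 7.0, 10.5]))
#     len(pa), pa[0], pa[0:2]  # -> (2, 3.0, array([3., 7.]))
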

class Staircase:
    """
    A class for learning a monotone staircase function over an integer domain.

    Args:
        domain: The number of points in the domain.
        n_classes: The number of classes.
    """

    def __init__(self, domain: int, n_classes: int) -> None:
        self.domain = domain
        self.classes = n_classes
        assert domain > 0
        assert n_classes > 0
        self.thresholds = self._back_thres = self._forward_thres = np.linspace(
            domain / n_classes, domain, n_classes - 1, endpoint=False
        )
        self.uncertainty = np.zeros_like(self.thresholds)

    def statistic_fit(
        self,
        data: np.ndarray,
        labels: np.ndarray,
    ):
        """
        Initialize the thresholds from robust per-class statistics.

        For each class the median position and a scaled median absolute
        deviation are computed; the forward/backward thresholds are set to
        the resulting upper/lower bounds.

        Args:
            data (np.ndarray): The input data.
            labels (np.ndarray): The labels corresponding to the data.

        Note:
            This method modifies the internal state of the object to set statistical thresholds.
        """
        onehot = np.eye(self.classes)[labels.reshape(-1)]
        onehot.reshape(list(labels.shape) + [self.classes])
        k = onehot * data.T.repeat(self.classes, 1)
        k[k == 0] = np.nan
        med = np.nanmedian(k, 0)
        for i in range(len(med)):
            if med[i] != med[i]:  # NaN check: fall back to the previous median
                med[i] = 0 if i == 0 else med[i - 1]
        mad = 5 * np.nan_to_num(
            np.nanmedian(np.absolute(k - np.nanmedian(k, 0)), 0),
            nan=self.domain / self.classes / 2,
        )
        arr = np.vstack(((med - mad)[:-1], (med + mad)[1:]))
        self._forward_thres[:] = arr.max(0)
        self._back_thres[:] = arr.min(0)

        self._stat_forward = self._forward_thres.copy()
        self._stat_back = self._back_thres.copy()

    def fit(
        self,
        data: np.ndarray,
        labels: np.ndarray,
        early_stop_after: int = 10,
        fixed: bool = True,
    ) -> None:
        """
        Fit the staircase thresholds to the data and labels.

        Starting from statistically initialized thresholds, each boundary is
        moved greedily to minimize the step-wise Gini impurity, with early
        stopping once the impurity stops improving.

        Args:
            data (np.ndarray): The input data.
            labels (np.ndarray): The labels corresponding to the data.
            early_stop_after (int, optional): The number of consecutive non-improving steps before stopping. Default is 10.
            fixed (bool, optional): Whether to enforce a minimum step width before early stopping. Default is True.

        Note:
            This method modifies the internal state of the object to set the thresholds.
        """
        assert data.ndim == 1
        assert labels.ndim <= 2
        if self.classes == 1:
            self.thresholds = np.array([0.5 + data.max()])
            self.uncertainty = np.zeros_like(self.thresholds)
        if data.ndim == 1:
            data = np.expand_dims(data, 0)
        thresholds = PaddedArray(
            np.insert(
                np.arange(self.domain - self.classes + 1, self.domain) - 1,
                [0, self.classes - 1],
                [-0.5, self.domain + 0.5],
            ).astype(int)
        )
        self._back_thres = thresholds.copy()
        self._forward_thres = thresholds.copy()
        self.statistic_fit(data, labels)
        last = -0.5
        for n in range(self.classes):
            G = np.inf
            Gis = deque([], early_stop_after)
            # forward pass
            if n < self.classes - 1:
                new_forward_n: float = self._forward_thres[n]
                for i in range(
                    max(0, self._back_thres[n - 1]) if n - 1 >= 0 else int(last),
                    min(self.domain, self._forward_thres[n + 1])
                    if n + 2 < self.classes
                    else self.domain - 1,
                ):
                    thresholds.array[n + 1] = i + 0.5
                    Gi = step_impurity(
                        thresholds.array, data, labels, classes=[n, n + 1]
                    )[0]
                    Gis.append(Gi)
                    if Gi <= G:
                        last = i + 0.5
                        new_forward_n = last
                        G = Gi
                    elif (
                        (not fixed or i - last > self.domain / self.classes)
                        and len(Gis) == early_stop_after
                        and all(
                            itertools.starmap(
                                operator.ge,
                                zip(Gis, itertools.islice(Gis, 1, early_stop_after)),
                            )
                        )
                    ):
                        break
                thresholds.array[n + 1] = new_forward_n
                self._forward_thres.array[n + 1] = new_forward_n
                self._back_thres.array[n + 1] = new_forward_n
            G = np.inf
        self._forward_thres = self._forward_thres.toarray().clip(
            min=0, max=self.domain - 1
        )
        self._back_thres = self._back_thres.toarray().clip(min=0, max=self.domain - 1)
        self.thresholds = (self._forward_thres + self._back_thres) / 2
        self.uncertainty = np.abs(self._forward_thres - self._back_thres) / 2

    @property
    def score(self):
        try:
            # `_data`/`_labels` are not stored by `fit`, so this currently
            # falls back to np.inf
            return gini_impurity(self.thresholds, self._data, self._labels) / len(
                self._data
            )
        except AttributeError:
            return np.inf

    def predict(self, x: np.ndarray) -> np.ndarray:
        return stair_func(x, self.get_boundaries())

    def __call__(self, *args):
        return self.predict(*args)

    def get_boundaries(self) -> np.ndarray:
        return self.thresholds.astype(int).clip(min=0, max=self.domain - 1) + 0.5
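
A small end-to-end sketch of the staircase fit on synthetic labels (the numbers are invented; it only assumes the module is importable as below):

import numpy as np
from nougat.dataset.staircase import Staircase

# 12 paragraphs over 3 pages, with one noisy label at position 6 that the
# monotone staircase fit has to smooth over
labels = np.array([0, 0, 0, 0, 1, 1, 2, 1, 2, 2, 2, 2])
x = np.arange(len(labels))
stairs = Staircase(len(labels), int(labels.max()) + 1)
stairs.fit(x, labels)
print(stairs.get_boundaries())  # approximate page-break positions
print(stairs.uncertainty)       # per-boundary search radius used downstream
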
nougat/dataset/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
nougat/dataset/utils/__init__.py
ADDED
@@ -0,0 +1,8 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from nougat.dataset.utils.latex_conversion import *
from nougat.dataset.utils.utils import *
nougat/dataset/utils/latex_conversion.py
ADDED
@@ -0,0 +1,146 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import re
from pylatexenc.latexencode import UnicodeToLatexEncoder
from pylatexenc.latex2text import LatexNodes2Text
from unidecode import unidecode

syn = [
    ("\\rbrack ", "] "),
    ("\\lbrack ", "[ "),
    ("\\lbrace ", "\\{ "),
    ("\\rbrace ", "\\} "),
    ("\\lnot ", "\\neg "),
    ("\\land ", "\\wedge "),
    ("\\vee ", "\\lor "),
    ("\\doublecup ", "\\Cup "),
    ("\\doublecap ", "\\Cap "),
    ("\\llless ", "\\lll "),
    ("\\gggtr ", "\\ggg "),
    ("\\doteqdot ", "\\Doteq "),
    ("\\ne ", "\\neq "),
    ("\\le ", "\\leq "),
    ("\\ge ", "\\geq "),
    ("\\leftarrow ", "\\gets "),
    ("\\rightarrow ", "\\to "),
    ("\\restriction ", "\\upharpoonright "),
    ("\\owns ", "\\ni "),
    ("\\textlnot ", "\\neg "),
    ("\\textellipsis ", "\\ldots "),
    ("\\textbullet ", "\\bullet "),
    ("\\plusmn ", "\\pm "),
    ("\\texttimes", "\\times"),
    ("\\textmu", "\\mu"),
    ("\\textendash", "-"),
    ("\\textemdash", "---"),
    ("\\>", "\\:"),
    ("\\medspace", "\\:"),
    ("\\thinspace", "\\,"),
    ("\\negthinspace", "\\!"),
    ("\\thickspace", "\\;"),
]
umlaut_mapping = {
    "textasciicircum": "^",
    "ddot": '"',
    "textasciidieresis": '"',
    "textasciicaron": "v ",
}
umlaut_keys = "|".join(umlaut_mapping.keys())
umlaut_regex = re.compile(rf"\s?\\({umlaut_keys})\s(\w)")
latex_comments = re.compile(r"(?<!\\)%.*\n")
toascii = UnicodeToLatexEncoder(
    non_ascii_only=True, unknown_char_policy="ignore", unknown_char_warning=False
)


def remove_style(string: str) -> str:
    return (
        string.replace("\\displaystyle", "")
        .replace("\\scriptstyle", "")
        .replace("\\scriptscriptstyle", "")
        .replace("\\textstyle", "")
    )


def replace_duplicate_definitions(string: str) -> str:
    """In LaTeX many commands are interchangeable. Map each group of synonyms to a single canonical form."""
    for pair in syn:
        string = string.replace(pair[0], pair[1])
    return string


def unicode_to_latex(s: str) -> str:
    s = re.sub(
        r"\s{2,}",
        " ",
        re.sub(
            r"\\ensuremath\s?\{\s?(.+?)\s?\}\s?",
            r" \1 ",
            toascii.unicode_to_latex(s.strip()),
        )
        .replace("}", " ")
        .replace("{", " "),
    )
    s = (
        s.strip()
        .replace(
            "\\textperiodcentered \\textperiodcentered \\textperiodcentered", "\\cdots"
        )
        .replace("\\textperiodcentered", "\\cdot")
        .replace("\\textquoteright", "'")
        .replace("\\textquoteleft", "'")
        .replace("\\textquotedblleft", '"')
        .replace("\\textquotedblright", '"')
    )
    s = umlaut_regex.sub(lambda x: "\\" + umlaut_mapping[x.group(1)] + x.group(2), s)
    s = replace_duplicate_definitions(s)
    s = unidecode(s)
    return s.replace("\u2009", " ")


latex_to_unicode = LatexNodes2Text()


def remove_line_breaks(string: str) -> str:
    string = latex_comments.sub("\n", string)
    return string.replace("\n", " ")


def normalize_tex(math: str, inline: bool) -> str:
    """
    Normalize TeX math expressions.

    This function takes a TeX math expression and performs various normalization steps to ensure
    consistency and proper formatting.

    Args:
        math (str): The input TeX math expression.
        inline (bool): Indicates whether the expression should be inline (True) or displayed (False).

    Returns:
        str: The normalized TeX math expression.
    """
    math = math.strip()
    if not math:
        return ""
    if math.startswith(r"\(") or math.startswith(r"\[") or math.startswith("$$"):
        math = math[2:]
    elif math.startswith("$"):
        math = math[1:]
    if math.endswith(r"\)") or math.endswith(r"\]") or math.endswith("$$"):
        math = math[:-2]
    elif math.endswith("$"):
        math = math[:-1]
    math = math.strip()
    if not math:
        return ""
    math = remove_line_breaks(math.strip())
    math = replace_duplicate_definitions(math)
    math = remove_style(math)
    if inline:
        return rf"\({math}\)"
    return rf"\[{math}\]"
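
Two illustrative calls (invented inputs) showing the normalization above in action:

from nougat.dataset.utils.latex_conversion import (
    normalize_tex,
    replace_duplicate_definitions,
)

print(normalize_tex(r"$\alpha \le \beta$", inline=True))      # \(\alpha \leq \beta\)
print(normalize_tex(r"\[ x \rightarrow y \]", inline=False))  # \[x \to y\]
print(replace_duplicate_definitions(r"a \ne b "))             # a \neq b
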
nougat/dataset/utils/pdf_text_extract.py
ADDED
@@ -0,0 +1,86 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from io import StringIO
from typing import List
import re
from pdfminer.converter import TextConverter
from pdfminer.layout import LAParams
from pdfminer.pdfdocument import PDFDocument
from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter
from pdfminer.pdfpage import PDFPage
from pdfminer.pdfparser import PDFParser


def replace_ligatures(text: str) -> str:
    # map typographic ligature characters back to their plain ASCII letters
    ligatures = {
        "ﬀ": "ff",
        "ﬁ": "fi",
        "ﬂ": "fl",
        "ﬃ": "ffi",
        "ﬄ": "ffl",
        "ﬅ": "ft",
        "ﬆ": "st",
        # "Ꜳ": "AA",
        # "Æ": "AE",
        "ꜳ": "aa",
    }
    for search, replace in ligatures.items():
        text = text.replace(search, replace)
    return text


def remove_hyphens(text: str) -> str:
    """
    Join words that were hyphenated across line breaks.

    This fails for:
    * Natural dashes: well-known, self-replication, use-cases, non-semantic,
      Post-processing, Window-wise, viewpoint-dependent
    * Trailing math operands: 2 - 4
    * Names: Lopez-Ferreras, VGG-19, CIFAR-100
    """
    lines = [line.rstrip() for line in text.split("\n")]

    # Find dashes
    line_numbers = []
    for line_no, line in enumerate(lines[:-1]):
        if line.endswith("-"):
            line_numbers.append(line_no)

    # Replace
    for line_no in line_numbers:
        lines = dehyphenate(lines, line_no)
    return "\n".join(lines)


def dehyphenate(lines: List[str], line_no: int) -> List[str]:
    next_line = lines[line_no + 1]
    word_suffix = next_line.split(" ")[0]

    lines[line_no] = lines[line_no][:-1] + word_suffix
    lines[line_no + 1] = lines[line_no + 1][len(word_suffix) :]
    return lines

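
# Example: the hyphen at the end of the first line is merged with the first
# word of the next line (which keeps its leading space):
#
#     dehyphenate(["convolu-", "tional networks"], 0)
#     # -> ["convolutional", " networks"]
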

def get_pages(pdf: str) -> List[str]:
    out = []
    with open(pdf, "rb") as in_file:
        parser = PDFParser(in_file)
        doc = PDFDocument(parser)
        rsrcmgr = PDFResourceManager()

        for page in PDFPage.create_pages(doc):
            output_string = StringIO()
            device = TextConverter(rsrcmgr, output_string, laparams=LAParams())
            interpreter = PDFPageInterpreter(rsrcmgr, device)
            interpreter.process_page(page)
            out.append(remove_hyphens(replace_ligatures(output_string.getvalue())))
    return out


def get_paragraphs(pdf: str) -> List[List[str]]:
    pages = get_pages(pdf)
    return [re.sub(r"\n{3,}", "\n\n", txt).split("\n\n") for txt in pages]
nougat/dataset/utils/utils.py
ADDED
@@ -0,0 +1,20 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import re


def remove_pretty_linebreaks(string: str) -> str:
    """Replaces linebreaks with spaces when there would be no
    difference between them in the markdown format.

    Args:
        string (str): String to process

    Returns:
        str: Formatted string
    """
    return re.sub(r"(?<!\n)\n([^\n\d\*#\[])", r" \1", string).strip()
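
A short demonstration of the rule (made-up string): single linebreaks inside running text are dropped, while paragraph breaks, list items and headings survive:

from nougat.dataset.utils.utils import remove_pretty_linebreaks

s = "This sentence was\nwrapped by the renderer.\n\nNew paragraph.\n* list item"
print(remove_pretty_linebreaks(s))
# This sentence was wrapped by the renderer.
#
# New paragraph.
# * list item
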
nougat/metrics.py
ADDED
@@ -0,0 +1,117 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import argparse
from multiprocessing import Pool
import re
from pathlib import Path
from collections import defaultdict
from typing import List

import numpy as np

import nltk
from nltk import edit_distance
from tqdm import tqdm

import orjson

inline_reg = re.compile(r"\\\((.*?)(?<!\\)\\\)")
display_reg = re.compile(r"\\\[(.+?)(?<!\\)\\\]")
table_reg = re.compile(r"\\begin\{tabular\}(.+?)(?:\\end\{tabular\}|$)", re.S)


def compute_metrics(pred, gt, minlen=4):
    """Compute edit distance, BLEU, METEOR and token-level precision/recall/F-measure between a prediction and its ground truth."""
    metrics = {}
    if len(pred) < minlen or len(gt) < minlen:
        return metrics
    metrics["edit_dist"] = edit_distance(pred, gt) / max(len(pred), len(gt))
    reference = gt.split()
    hypothesis = pred.split()
    metrics["bleu"] = nltk.translate.bleu([reference], hypothesis)
    try:
        metrics["meteor"] = nltk.translate.meteor([reference], hypothesis)
    except LookupError:
        metrics["meteor"] = np.nan
    reference = set(reference)
    hypothesis = set(hypothesis)
    metrics["precision"] = nltk.scores.precision(reference, hypothesis)
    metrics["recall"] = nltk.scores.recall(reference, hypothesis)
    metrics["f_measure"] = nltk.scores.f_measure(reference, hypothesis)
    return metrics

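
# Example: the edit distance is normalized by the longer string, so
#
#     compute_metrics("hello world", "hello word")["edit_dist"]  # 1/11
#
# while BLEU, METEOR and precision/recall/F-measure operate on
# whitespace-separated tokens.
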
47 |
+
def get_parser():
|
48 |
+
parser = argparse.ArgumentParser()
|
49 |
+
parser.add_argument("json", type=Path, help="results file")
|
50 |
+
parser.add_argument(
|
51 |
+
"-N", dest="N", type=int, help="number of samples", default=None
|
52 |
+
)
|
53 |
+
args = parser.parse_args()
|
54 |
+
d = orjson.loads(args.json.read_text(encoding="utf-8"))
|
55 |
+
args.pred = d["predictions"]
|
56 |
+
args.gt = d["ground_truths"]
|
57 |
+
if args.N is not None:
|
58 |
+
args.pred = args.pred[: args.N]
|
59 |
+
args.gt = args.gt[: args.N]
|
60 |
+
return args
|
61 |
+
|
62 |
+
|
63 |
+
def split_text(pages: List[str]):
|
64 |
+
"""
|
65 |
+
Split a list of pages into text, inline math, display math, and table blocks.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
pages: The pages to split.
|
69 |
+
"""
|
70 |
+
text, math, table = [], [], []
|
71 |
+
for page in pages:
|
72 |
+
for i, reg in enumerate([inline_reg, display_reg, table_reg]):
|
73 |
+
matches = "\n".join(reg.findall(page))
|
74 |
+
if i == 2:
|
75 |
+
table.append(matches)
|
76 |
+
elif i == 1:
|
77 |
+
math[-1] += matches
|
78 |
+
else:
|
79 |
+
math.append(matches)
|
80 |
+
page = reg.sub("", page)
|
81 |
+
text.append(page.strip())
|
82 |
+
|
83 |
+
return text, math, table
|
84 |
+
|
85 |
+
|
86 |
+
def get_metrics(gt: List[str], pred: List[str], pool: bool = True):
|
87 |
+
metrics = defaultdict(list)
|
88 |
+
if pool:
|
89 |
+
with Pool() as p:
|
90 |
+
_metrics = p.starmap(compute_metrics, iterable=zip(pred, gt))
|
91 |
+
else:
|
92 |
+
_metrics = [compute_metrics(p, g) for p, g in zip(pred, gt)]
|
93 |
+
for m in _metrics:
|
94 |
+
for key, value in m.items():
|
95 |
+
metrics[key].append(value)
|
96 |
+
return dict(metrics)
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == "__main__":
|
100 |
+
args = get_parser()
|
101 |
+
for name, entries in zip(["gt", "pred"], [args.gt, args.pred]):
|
102 |
+
full: Path = args.json.parent / (args.json.stem + "_" + name + "_full.mmd")
|
103 |
+
full.write_text("\n\n------------------\n\n".join(entries))
|
104 |
+
for i, (gt, pr) in enumerate(zip(split_text(args.gt), split_text(args.pred))):
|
105 |
+
sub = ["Text", "Math", "Tables"][i]
|
106 |
+
prpath: Path = args.json.parent / (
|
107 |
+
args.json.stem + "_pred_" + sub.lower() + ".mmd"
|
108 |
+
)
|
109 |
+
prpath.write_text("\n\n------------------\n\n".join(pr))
|
110 |
+
gtpath: Path = args.json.parent / (
|
111 |
+
args.json.stem + "_gt_" + sub.lower() + ".mmd"
|
112 |
+
)
|
113 |
+
gtpath.write_text("\n\n------------------\n\n".join(gt))
|
114 |
+
print("Results for", sub)
|
115 |
+
|
116 |
+
metrics = get_metrics(gt, pr)
|
117 |
+
print({key: sum(values) / len(values) for key, values in metrics.items()})
|
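For orientation (not part of the commit): a minimal sketch of how the metrics helpers above compose, assuming the module is importable as `nougat.metrics` and that the NLTK data required for METEOR may be absent (in which case `meteor` comes back as NaN, per the `LookupError` branch).

# Usage sketch -- illustrative only, strings below are made-up examples.
from nougat.metrics import compute_metrics, get_metrics

pred = ["The quick brown fox jumps over the lazy dog."]
gt = ["The quick brown fox jumped over the lazy dog."]

# Per-pair scores: normalized edit distance, BLEU, METEOR, and
# set-based precision/recall/F-measure over whitespace tokens.
print(compute_metrics(pred[0], gt[0]))

# Batched variant; pool=False sidesteps multiprocessing, e.g. in notebooks.
print(get_metrics(gt, pred, pool=False))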
nougat/model.py
ADDED
@@ -0,0 +1,702 @@
"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
Copyright (c) Meta Platforms, Inc. and affiliates.
"""
import logging
import math
import os
from typing import List, Optional, Union
from collections import defaultdict
from pathlib import Path

import numpy as np
from PIL import Image
import cv2
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import ImageOps
from timm.models.swin_transformer import SwinTransformer
from torchvision.transforms.functional import resize, rotate
from transformers import (
    PreTrainedTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    MBartConfig,
    MBartForCausalLM,
)
from transformers.file_utils import ModelOutput
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from nougat.postprocessing import postprocess
from nougat.transforms import train_transform, test_transform


class SwinEncoder(nn.Module):
    r"""
    Encoder based on SwinTransformer.
    Sets the initial weights and configuration from a pretrained SwinTransformer,
    then modifies the detailed configuration.

    Args:
        input_size: Input image size (height, width)
        align_long_axis: Whether to rotate the image if its height is greater than its width
        window_size: Window size of the SwinTransformer
        encoder_layer: Number of layers in the SwinTransformer encoder
        name_or_path: Name of a pretrained model, either registered on huggingface.co or saved locally;
            otherwise, `swin_base_patch4_window12_384` weights are used (via `timm`).
    """

    def __init__(
        self,
        input_size: List[int],
        align_long_axis: bool,
        window_size: int,
        encoder_layer: List[int],
        patch_size: int,
        embed_dim: int,
        num_heads: List[int],
        name_or_path: Union[str, bytes, os.PathLike] = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.align_long_axis = align_long_axis
        self.window_size = window_size
        self.encoder_layer = encoder_layer
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.model = SwinTransformer(
            img_size=self.input_size,
            depths=self.encoder_layer,
            window_size=self.window_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_classes=0,
        )

        # weight init with swin
        if not name_or_path:
            swin_state_dict = timm.create_model(
                "swin_base_patch4_window12_384", pretrained=True
            ).state_dict()
            new_swin_state_dict = self.model.state_dict()
            for x in new_swin_state_dict:
                if x.endswith("relative_position_index") or x.endswith("attn_mask"):
                    pass
                elif (
                    x.endswith("relative_position_bias_table")
                    and self.model.layers[0].blocks[0].attn.window_size[0] != 12
                ):
                    pos_bias = swin_state_dict[x].unsqueeze(0)[0]
                    old_len = int(math.sqrt(len(pos_bias)))
                    new_len = int(2 * window_size - 1)
                    pos_bias = pos_bias.reshape(1, old_len, old_len, -1).permute(
                        0, 3, 1, 2
                    )
                    pos_bias = F.interpolate(
                        pos_bias,
                        size=(new_len, new_len),
                        mode="bicubic",
                        align_corners=False,
                    )
                    new_swin_state_dict[x] = (
                        pos_bias.permute(0, 2, 3, 1)
                        .reshape(1, new_len**2, -1)
                        .squeeze(0)
                    )
                else:
                    new_swin_state_dict[x] = swin_state_dict[x]
            self.model.load_state_dict(new_swin_state_dict)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, num_channels, height, width)
        """
        x = self.model.patch_embed(x)
        x = self.model.pos_drop(x)
        x = self.model.layers(x)
        return x

    @staticmethod
    def crop_margin(img: Image.Image) -> Image.Image:
        data = np.array(img.convert("L"))
        data = data.astype(np.uint8)
        max_val = data.max()
        min_val = data.min()
        if max_val == min_val:
            return img
        data = (data - min_val) / (max_val - min_val) * 255
        gray = 255 * (data < 200).astype(np.uint8)

        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
        return img.crop((a, b, w + a, h + b))

    @property
    def to_tensor(self):
        if self.training:
            return train_transform
        else:
            return test_transform

    def prepare_input(
        self, img: Image.Image, random_padding: bool = False
    ) -> torch.Tensor:
        """
        Convert a PIL Image to a tensor of the specified input_size by applying the following steps:
            - resize
            - rotate (if align_long_axis is True and the image's long axis is not aligned with the canvas)
            - pad
        """
        if img is None:
            return
        # crop margins
        try:
            img = self.crop_margin(img.convert("RGB"))
        except OSError:
            # might throw an error for broken files
            return
        if img.height == 0 or img.width == 0:
            return
        if self.align_long_axis and (
            (self.input_size[0] > self.input_size[1] and img.width > img.height)
            or (self.input_size[0] < self.input_size[1] and img.width < img.height)
        ):
            img = rotate(img, angle=-90, expand=True)
        img = resize(img, min(self.input_size))
        img.thumbnail((self.input_size[1], self.input_size[0]))
        delta_width = self.input_size[1] - img.width
        delta_height = self.input_size[0] - img.height
        if random_padding:
            pad_width = np.random.randint(low=0, high=delta_width + 1)
            pad_height = np.random.randint(low=0, high=delta_height + 1)
        else:
            pad_width = delta_width // 2
            pad_height = delta_height // 2
        padding = (
            pad_width,
            pad_height,
            delta_width - pad_width,
            delta_height - pad_height,
        )
        return self.to_tensor(ImageOps.expand(img, padding))


class BARTDecoder(nn.Module):
    """
    Decoder based on multilingual BART.
    Sets the initial weights and configuration from a pretrained multilingual BART model,
    then modifies the detailed configuration for use as a Nougat decoder.

    Args:
        decoder_layer:
            Number of layers in the BARTDecoder
        max_position_embeddings:
            The maximum sequence length to be trained
        name_or_path:
            Name of a pretrained model, either registered on huggingface.co or saved locally;
            otherwise, `facebook/mbart-large-50` weights are used (via `transformers`)
    """

    def __init__(
        self,
        decoder_layer: int,
        max_position_embeddings: int,
        hidden_dimension: int = 1024,
        name_or_path: Union[str, bytes, os.PathLike] = None,
    ):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_position_embeddings
        if not name_or_path:
            tokenizer_file = Path(__file__).parent / "dataset" / "tokenizer.json"
        else:
            tokenizer_file = Path(name_or_path) / "tokenizer.json"
        if not tokenizer_file.exists():
            raise ValueError("Could not find tokenizer file")
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_file))
        self.tokenizer.pad_token = "<pad>"
        self.tokenizer.bos_token = "<s>"
        self.tokenizer.eos_token = "</s>"
        self.tokenizer.unk_token = "<unk>"

        self.model = MBartForCausalLM(
            config=MBartConfig(
                is_decoder=True,
                is_encoder_decoder=False,
                add_cross_attention=True,
                decoder_layers=self.decoder_layer,
                max_position_embeddings=self.max_position_embeddings,
                vocab_size=len(self.tokenizer),
                scale_embedding=True,
                add_final_layer_norm=True,
                d_model=hidden_dimension,
            )
        )
        self.model.config.is_encoder_decoder = True  # to get cross-attention
        self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
        self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference

        if not name_or_path:
            bart_state_dict = MBartForCausalLM.from_pretrained(
                "facebook/mbart-large-50"
            ).state_dict()
            new_bart_state_dict = self.model.state_dict()
            for x in new_bart_state_dict:
                if (
                    x.endswith("embed_positions.weight")
                    and self.max_position_embeddings != 1024
                ):
                    new_bart_state_dict[x] = torch.nn.Parameter(
                        self.resize_bart_abs_pos_emb(
                            bart_state_dict[x],
                            self.max_position_embeddings
                            + 2,  # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
                        )
                    )
                elif x.endswith("embed_tokens.weight") or x.endswith("lm_head.weight"):
                    new_bart_state_dict[x] = bart_state_dict[x][
                        : len(self.tokenizer), :
                    ]
                else:
                    new_bart_state_dict[x] = bart_state_dict[x]
            self.model.load_state_dict(new_bart_state_dict, strict=False)

    def add_special_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to the tokenizer and resize the token embeddings.
        """
        newly_added_num = self.tokenizer.add_special_tokens(
            {"additional_special_tokens": sorted(set(list_of_tokens))}
        )
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))

    def prepare_inputs_for_inference(
        self,
        input_ids: torch.Tensor,
        encoder_outputs: torch.Tensor,
        past=None,
        past_key_values=None,
        use_cache: bool = None,
        attention_mask: torch.Tensor = None,
    ):
        """
        Args:
            input_ids: (batch_size, sequence_length)

        Returns:
            input_ids: (batch_size, sequence_length)
            attention_mask: (batch_size, sequence_length)
            encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
        """
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
        past = past or past_key_values
        if past is not None:
            input_ids = input_ids[:, -1:]
        output = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "use_cache": use_cache,
            "encoder_hidden_states": encoder_outputs.last_hidden_state,
        }
        return output

    def forward(
        self,
        input_ids,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = None,
        output_attentions: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = None,
    ):
        return self.model.forward(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @staticmethod
    def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
        """
        Resize position embeddings:
        truncate if the sequence length of the MBart backbone is greater than the given max_length,
        otherwise interpolate to max_length.
        """
        if weight.shape[0] > max_length:
            weight = weight[:max_length, ...]
        else:
            weight = (
                F.interpolate(
                    weight.permute(1, 0).unsqueeze(0),
                    size=max_length,
                    mode="linear",
                    align_corners=False,
                )
                .squeeze(0)
                .permute(1, 0)
            )
        return weight


class NougatConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`NougatModel`]. It is used to
    instantiate a Nougat model according to the specified arguments, defining the model architecture.

    Args:
        input_size:
            Input image size (canvas size) of Nougat.encoder, a SwinTransformer in this codebase
        align_long_axis:
            Whether to rotate the image if its height is greater than its width
        window_size:
            Window size of Nougat.encoder, a SwinTransformer in this codebase
        encoder_layer:
            Depth of each Nougat.encoder layer, a SwinTransformer in this codebase
        decoder_layer:
            Number of hidden layers in the Nougat.decoder, such as BART
        max_position_embeddings:
            Trained max position embeddings in the Nougat decoder;
            if not specified, it defaults to max_length
        max_length:
            Maximum sequence length to train with
        name_or_path:
            Name of a pretrained model, either registered on huggingface.co or saved locally
    """
    model_type = "nougat"

    def __init__(
        self,
        input_size: List[int] = [896, 672],
        align_long_axis: bool = False,
        window_size: int = 7,
        encoder_layer: List[int] = [2, 2, 14, 2],
        decoder_layer: int = 10,
        max_position_embeddings: int = None,
        max_length: int = 4096,
        name_or_path: Union[str, bytes, os.PathLike] = "",
        patch_size: int = 4,
        embed_dim: int = 128,
        num_heads: List[int] = [4, 8, 16, 32],
        hidden_dimension: int = 1024,
        **kwargs,
    ):
        super().__init__()
        self.input_size = input_size
        self.align_long_axis = align_long_axis
        self.window_size = window_size
        self.encoder_layer = encoder_layer
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = (
            max_length if max_position_embeddings is None else max_position_embeddings
        )
        self.max_length = max_length
        self.name_or_path = name_or_path
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.hidden_dimension = hidden_dimension


class RunningVarTorch:
    def __init__(self, L=15, norm=False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        assert x.dim() == 1
        if self.values is None:
            self.values = x[:, None]
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, x[:, None]), 1)
        else:
            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

    def variance(self):
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        else:
            return torch.var(self.values, 1)


class StoppingCriteriaScores(StoppingCriteria):
    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0


def batch(l, b=15):
    subs = []
    for i in range(len(l) - b):
        subs.append(l[i : i + b])
    return subs


def subdiv(l, b=10):
    subs = []
    for i in range(len(l) - b):
        subs.append(l[: i + b])
    return subs


class NougatModel(PreTrainedModel):
    r"""
    Nougat: Neural Optical UnderstandinG for Academic documents.
    The encoder converts an image of an academic document into a series of embeddings.
    Then, the decoder generates a sequence of tokens based on the encoder's output.
    This sequence can be translated into a structured markup language format.
    """
    config_class = NougatConfig
    base_model_prefix = "nougat"

    def __init__(self, config: NougatConfig):
        super().__init__(config)
        self.config = config
        self.encoder = SwinEncoder(
            input_size=self.config.input_size,
            align_long_axis=self.config.align_long_axis,
            window_size=self.config.window_size,
            encoder_layer=self.config.encoder_layer,
            name_or_path=self.config.name_or_path,
            patch_size=self.config.patch_size,
            embed_dim=self.config.embed_dim,
            num_heads=self.config.num_heads,
        )
        self.decoder = BARTDecoder(
            max_position_embeddings=self.config.max_position_embeddings,
            decoder_layer=self.config.decoder_layer,
            name_or_path=self.config.name_or_path,
            hidden_dimension=self.config.hidden_dimension,
        )

    def forward(
        self,
        image_tensors: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Calculate a loss given an input image and a desired token sequence;
        the model is trained in a teacher-forcing manner.

        Args:
            image_tensors: (batch_size, num_channels, height, width)
            decoder_input_ids: (batch_size, sequence_length)
        """
        encoder_outputs = self.encoder(image_tensors)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids[:, :-1].contiguous(),
            encoder_hidden_states=encoder_outputs,
            attention_mask=attention_mask[:, :-1],
            labels=decoder_input_ids[:, 1:].contiguous(),
        )
        return decoder_outputs

    def _init_weights(self, *args, **kwargs):
        return

    def inference(
        self,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
        return_attentions: bool = False,
        early_stopping: bool = True,
    ):
        """
        Generate a token sequence in an auto-regressive manner.

        Args:
            image: input document image (PIL.Image)
            image_tensors: (1, num_channels, height, width);
                computed from `image` if not provided
        """
        output = {
            "predictions": list(),
            "sequences": list(),
            "repeats": list(),
            "repetitions": list(),
        }
        if image is None and image_tensors is None:
            logging.warn("Image not found")
            return output

        if image_tensors is None:
            image_tensors = self.encoder.prepare_input(image).unsqueeze(0)

        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)

        image_tensors = image_tensors.to(self.device)

        last_hidden_state = self.encoder(image_tensors)

        encoder_outputs = ModelOutput(
            last_hidden_state=last_hidden_state, attentions=None
        )

        if len(encoder_outputs.last_hidden_state.size()) == 1:
            encoder_outputs.last_hidden_state = (
                encoder_outputs.last_hidden_state.unsqueeze(0)
            )

        # get decoder output
        decoder_output = self.decoder.model.generate(
            encoder_outputs=encoder_outputs,
            min_length=1,
            max_length=self.config.max_length,
            pad_token_id=self.decoder.tokenizer.pad_token_id,
            eos_token_id=self.decoder.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[
                [self.decoder.tokenizer.unk_token_id],
            ],
            return_dict_in_generate=True,
            output_scores=True,
            output_attentions=return_attentions,
            do_sample=False,
            stopping_criteria=StoppingCriteriaList(
                [StoppingCriteriaScores()] if early_stopping else []
            ),
        )
        output["repetitions"] = decoder_output.sequences.clone()
        output["sequences"] = decoder_output.sequences.clone()
        batch_size = len(decoder_output.sequences)

        logits = torch.stack(decoder_output.scores, 1).cpu().max(-1)
        values = logits.values
        indices = logits.indices

        for b in range(batch_size):
            mask = indices[b] != self.decoder.tokenizer.pad_token_id
            N = mask.sum().item()
            var = np.array(
                [np.var(s) / len(s) for s in batch(values[b, mask].float().numpy())]
            )
            if len(var) < 10:
                output["repeats"].append(None)
                continue
            varvar = np.array([np.var(v) for v in subdiv(var[::-1])][::-1])
            minlen = 120
            if (
                indices[b] == self.decoder.tokenizer.eos_token_id
            ).any() and N + 1 < indices.shape[1]:
                # there is an end to the generation, likely no repetitions
                output["repeats"].append(None)
                continue
            small_var = np.where(varvar < 0.045)[0]
            if early_stopping and len(small_var) > 1:
                if np.all(np.diff(small_var) < 2):
                    idx = int(min(max(small_var[0], 1) * 1.08 + minlen, 4095))
                    if idx / N > 0.9:  # at most last bit
                        output["repeats"].append(None)
                        continue
                    elif small_var[0] < 30:
                        idx = 0
                    logging.warn("Found repetitions in sample %i" % b)
                    output["repeats"].append(idx)
                    output["sequences"][b, idx:] = self.decoder.tokenizer.pad_token_id
                    output["repetitions"][b, :idx] = self.decoder.tokenizer.pad_token_id
                else:
                    output["repeats"].append(None)
            else:
                output["repeats"].append(None)
        output["repetitions"] = self.decoder.tokenizer.batch_decode(
            output["repetitions"], skip_special_tokens=True
        )
        output["predictions"] = postprocess(
            self.decoder.tokenizer.batch_decode(
                output["sequences"], skip_special_tokens=True
            ),
            markdown_fix=False,
        )

        if return_attentions:
            output["attentions"] = {
                "self_attentions": decoder_output.decoder_attentions,
                "cross_attentions": decoder_output.cross_attentions,
            }

        return output

    @classmethod
    def from_pretrained(
        cls,
        model_path: Union[str, bytes, os.PathLike],
        *model_args,
        **kwargs,
    ):
        r"""
        Instantiate a pretrained Nougat model from a pretrained model configuration.

        Args:
            model_path:
                Name of a pretrained model, either registered on huggingface.co or saved locally.
        """
        model = super(NougatModel, cls).from_pretrained(
            model_path, *model_args, **kwargs
        )

        # truncate or interpolate position embeddings of decoder
        max_length = kwargs.get("max_length", model.config.max_position_embeddings)
        if (
            max_length != model.config.max_position_embeddings
        ):  # if the max_length of the trained model differs from the max_length you want to train with
            model.decoder.model.model.decoder.embed_positions.weight = torch.nn.Parameter(
                model.decoder.resize_bart_abs_pos_emb(
                    model.decoder.model.model.decoder.embed_positions.weight,
                    max_length
                    + 2,  # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
                )
            )
        model.config.max_position_embeddings = max_length

        return model
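For orientation (not part of the commit): a minimal end-to-end inference sketch for the model above. It assumes `NougatModel` is re-exported by the package `__init__`, and that a checkpoint directory with `config.json`, `pytorch_model.bin` and `tokenizer.json` already exists; the paths are placeholders.

# Inference sketch -- "path/to/checkpoint" and "page.png" are placeholders.
from PIL import Image
from nougat import NougatModel

model = NougatModel.from_pretrained("path/to/checkpoint")
model.eval()  # use test_transform in prepare_input

page = Image.open("page.png")
out = model.inference(image=page)  # preprocessing happens inside
print(out["predictions"][0])       # postprocessed markup for the page
print(out["repeats"][0])           # None unless repetitions were detected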
nougat/postprocessing.py
ADDED
@@ -0,0 +1,508 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Union, List
import re
import os
import numpy as np
from nltk.corpus import words
from multiprocessing import Pool
from functools import partial
from Levenshtein import ratio


reference_pattern = re.compile(r"^\* \[\d+\]", flags=re.M)


def markdown_compatible(s: str) -> str:
    """
    Make text compatible with Markdown formatting.

    This function makes various text formatting adjustments to make it compatible with Markdown.

    Args:
        s (str): The input text to be made Markdown-compatible.

    Returns:
        str: The Markdown-compatible text.
    """
    # equation tag
    s = re.sub(
        r"^\(([\d.]+[a-zA-Z]?)\) \\\[(.+?)\\\]$", r"\[\2 \\tag{\1}\]", s, flags=re.M
    )
    s = re.sub(
        r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\)$", r"\[\1 \\tag{\2}\]", s, flags=re.M
    )
    s = re.sub(
        r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\) (\\\[.+?\\\])$",
        r"\[\1 \\tag{\2}\] \3",
        s,
        flags=re.M,
    )  # multi line
    s = s.replace(r"\. ", ". ")
    # bold formatting
    s = s.replace(r"\bm{", r"\mathbf{").replace(r"{\\bm ", r"\mathbf{")
    # s = s.replace(r"\it{", r"\mathit{").replace(r"{\\it ", r"\mathit{")  # not needed
    s = re.sub(r"\\mbox{ ?\\boldmath\$(.*?)\$}", r"\\mathbf{\1}", s)
    # s=re.sub(r'\\begin{table}(.+?)\\end{table}\nTable \d+: (.+?)\n',r'\\begin{table}\1\n\\capation{\2}\n\\end{table}\n',s,flags=re.S)
    # s=re.sub(r'###### Abstract\n(.*?)\n\n',r'\\begin{abstract}\n\1\n\\end{abstract}\n\n',s,flags=re.S)
    # urls
    s = re.sub(
        r"((?:http|ftp|https):\/\/(?:[\w_-]+(?:(?:\.[\w_-]+)+))(?:[\w.,@?^=%&:\/~+#-]*[\w@?^=%&\/~+#-]))",
        r"[\1](\1)",
        s,
    )
    # algorithms
    s = re.sub(r"```\s*(.+?)\s*```", r"```\n\1\n```", s, flags=re.S)
    # lists

    return s


def find_next_punctuation(s: str, start_inx=0):
    """
    Find the index of the next punctuation mark.

    Args:
        s: String to examine
        start_inx: Index at which to start
    """

    for i in range(start_inx, len(s)):
        if s[i] in [".", "?", "!", "\n"]:
            return i

    return None


def find_last_punctuation(s: str, start_inx=0):
    """
    Find the index of the last punctuation mark before start_inx.

    Args:
        s: String to examine
        start_inx: Index before which to look
    """

    for i in range(start_inx - 1, 0, -1):
        if s[i] in [".", "?", "!", "\n"]:
            return i

    return None


def truncate_repetitions(s: str, min_len=30):
    """
    Attempt to truncate repeating segments in the input string.

    This function looks for the longest repeating substring at the end of the input string and truncates
    it to appear only once. To be considered for removal, repetitions need to be continuous.

    Args:
        s (str): The input raw prediction to be truncated.
        min_len (int): The minimum length of the repeating segment.

    Returns:
        str: The input string with repeated segments truncated.
    """
    s_lower = s.lower()
    s_len = len(s_lower)

    if s_len < 2 * min_len:
        return s

    # try to find a length at which the tail is repeating
    max_rep_len = None
    for rep_len in range(min_len, int(s_len / 2)):
        # check if there is a repetition at the end
        same = True
        for i in range(0, rep_len):
            if s_lower[s_len - rep_len - i - 1] != s_lower[s_len - i - 1]:
                same = False
                break

        if same:
            max_rep_len = rep_len

    if max_rep_len is None:
        return s

    lcs = s_lower[-max_rep_len:]

    # remove all but the last repetition
    st = s
    st_lower = s_lower
    while st_lower.endswith(lcs):
        st = st[:-max_rep_len]
        st_lower = st_lower[:-max_rep_len]

    # this is the tail with the repetitions
    repeating_tail = s_lower[len(st_lower) :]

    # add until next punctuation and make sure last sentence is not repeating
    st_lower_out = st_lower
    while True:
        sentence_end = find_next_punctuation(s_lower, len(st_lower_out))
        sentence_start = find_last_punctuation(s_lower, len(st_lower_out))
        if sentence_end and sentence_start:
            sentence = s_lower[sentence_start:sentence_end]
            st_lower_out = s_lower[: sentence_end + 1]
            if sentence in repeating_tail:
                break
        else:
            break

    s_out = s[: len(st_lower_out)]

    return s_out


def close_envs(s: str) -> str:
    """Checks if table envs are opened but not closed. Appends the closing statements and returns the new string."""
    envs = ("bmatrix", "pmatrix", "matrix", "tabular", "table")
    for env in envs:
        begins, ends = s.count(r"\begin{%s}" % env), s.count(r"\end{%s}" % env)
        if begins > ends:
            s += (r"\end{%s}" % env) * (begins - ends)
    return s


def remove_numbers(lines):
    def _clean(s):
        return re.sub(r"(?:[\d_]|\*\*)", "", s).strip()

    if type(lines) is str:
        return _clean(lines)
    out = []
    for l in lines:
        out.append(_clean(l))
    return out


def get_slices(lines, clean_lines):
    """
    Get slices of text based on specific criteria within the lines.

    This function identifies and returns slices of text from the input lines based on certain conditions.

    Args:
        lines (list of str): The list of lines containing the text.
        clean_lines (list of str): A cleaned version of the text (without numbers).

    Returns:
        list of tuple: A list of tuples representing the start and end indices of text slices.
    """
    inds = np.zeros(len(lines))
    for i in range(len(lines) - 1):
        j = i + 1
        while not clean_lines[j] and j < len(lines) - 1:
            j += 1
        if (
            len(clean_lines[i]) < 200
            and len(clean_lines[i]) > 3
            and len(clean_lines[j]) < 200
            and len(clean_lines[j]) > 3
            and not clean_lines[i].startswith("[MISSING_PAGE")
            and (
                clean_lines[i] == clean_lines[j]
                or ratio(clean_lines[i], clean_lines[j]) > 0.9
            )
        ):
            inds[i:j] = 1
    ids = np.where(inds)[0]
    slices = []
    if len(ids) == 0:
        return slices
    j0 = 0
    for j, x in enumerate(np.diff(ids) > 3):
        if x:
            slices.append((ids[j0], ids[j] + 2))
            j0 = j + 1
    slices.append((ids[j0], ids[-1] + 2))
    return [sli for sli in slices if sli[1] - sli[0] > 15]


def remove_slice_from_lines(lines, clean_text, sli) -> str:
    """
    Remove a slice of text from the lines based on specific criteria.

    This function identifies a slice of text within the lines and removes it based on certain conditions.

    Args:
        lines (list of str): The list of lines containing the text.
        clean_text (list of str): A cleaned version of the text (without numbers).
        sli (tuple): A tuple representing the start and end indices of the slice to be removed.

    Returns:
        str: The removed slice of text as a single string.
    """
    base = clean_text[sli[0]]
    section = list(sli)
    check_start_flag = False
    # backwards pass
    for i in range(max(0, sli[0] - 1), max(0, sli[0] - 5), -1):
        if not lines[i]:
            continue
        if lines[i] == "## References":
            section[0] = i
            break
        elif ratio(base, remove_numbers(lines[i])) < 0.9:
            section[0] = i + 1
            potential_ref = remove_numbers(lines[max(0, i - 1)].partition("* [")[-1])
            if (
                len(potential_ref) >= 0.75 * len(base)
                and ratio(base, potential_ref) < 0.9
            ):
                section[0] = i
                check_start_flag = True
            break
    # forward pass
    for i in range(min(len(lines), sli[1]), min(len(lines), sli[1] + 5)):
        if ratio(base, remove_numbers(lines[i])) < 0.9:
            section[1] = i
            break
    if len(lines) <= section[1]:
        section[1] = len(lines) - 1
    to_delete = "\n".join(lines[section[0] : section[1] + 1])
    # cut off next page content
    itera, iterb = enumerate(lines[section[1] - 1]), enumerate(lines[section[1]])
    while True:
        try:
            (ia, a) = next(itera)
            while a.isnumeric():
                (ia, a) = next(itera)
            (ib, b) = next(iterb)
            while b.isnumeric():
                (ib, b) = next(iterb)
            if a != b:
                break
        except StopIteration:
            break
    if check_start_flag and "* [" in to_delete:
        to_delete = "* [" + to_delete.partition("* [")[-1]
    try:
        delta = len(lines[section[1]]) - ib - 1
        if delta > 0:
            to_delete = to_delete[:-delta]
    except UnboundLocalError:
        pass

    return to_delete.strip()


def remove_hallucinated_references(text: str) -> str:
    """
    Remove hallucinated or missing references from the text.

    This function identifies and removes references that are marked as missing or hallucinated
    from the input text.

    Args:
        text (str): The input text containing references.

    Returns:
        str: The text with hallucinated references removed.
    """
    lines = text.split("\n")
    if len(lines) == 0:
        return ""
    clean_lines = remove_numbers(lines)
    slices = get_slices(lines, clean_lines)
    to_delete = []
    for sli in slices:
        to_delete.append(remove_slice_from_lines(lines, clean_lines, sli))
    for to_delete in reversed(to_delete):
        text = text.replace(to_delete, "\n\n[MISSING_PAGE_POST]\n\n")
    text = re.sub(
        r"## References\n+\[MISSING_PAGE_POST(:\d+)?\]",
        "\n\n[MISSING_PAGE_POST\\1]",
        text,
    )
    return text


def postprocess_single(generation: str, markdown_fix: bool = True) -> str:
    """
    Postprocess a single generated text.

    Args:
        generation (str): The generated text to be postprocessed.
        markdown_fix (bool, optional): Whether to perform Markdown formatting fixes. Default is True.

    Returns:
        str: The postprocessed text.
    """
    generation = re.sub(
        r"(?:\n|^)#+ \d*\W? ?(.{100,})", r"\n\1", generation
    )  # overly long section titles are probably not real titles
    generation = generation.strip()
    generation = generation.replace("\n* [leftmargin=*]\n", "\n")
    generation = re.sub(
        r"^#+ (?:\.?(?:\d|[ixv])+)*\s*(?:$|\n\s*)", "", generation, flags=re.M
    )
    # most likely hallucinated titles
    lines = generation.split("\n")
    if (
        lines[-1].startswith("#")
        and lines[-1].lstrip("#").startswith(" ")
        and len(lines) > 1
    ):
        print("INFO: likely hallucinated title at the end of the page: " + lines[-1])
        generation = "\n".join(lines[:-1])
    # obvious repetition detection
    generation = truncate_repetitions(generation)
    # Reference corrections
    generation = remove_hallucinated_references(generation)
    generation = re.sub(
        r"^\* \[\d+\](\s?[A-W]\.+\s?){10,}.*$", "", generation, flags=re.M
    )
    generation = re.sub(r"^(\* \[\d+\])\[\](.*)$", r"\1\2", generation, flags=re.M)
    generation = re.sub(r"(^\w\n\n|\n\n\w$)", "", generation)
    # pmc math artifact correction
    generation = re.sub(
        r"([\s.,()])_([a-zA-Z0-9])__([a-zA-Z0-9]){1,3}_([\s.,:()])",
        r"\1\(\2_{\3}\)\4",
        generation,
    )
    generation = re.sub(
        r"([\s.,\d])_([a-zA-Z0-9])_([\s.,\d;])", r"\1\(\2\)\3", generation
    )
    # footnote mistakes
    generation = re.sub(
        r"(\nFootnote .*?:) (?:footnotetext|thanks):\W*(.*(?:\n\n|$))",
        r"\1 \2",
        generation,
    )
    # TODO Come up with footnote formatting inside a table
    generation = re.sub(r"\[FOOTNOTE:.+?\](.*?)\[ENDFOOTNOTE\]", "", generation)
    # itemize post processing
    for match in reversed(
        list(
            re.finditer(
                r"(?:^)(-|\*)?(?!-|\*) ?((?:\d|[ixv])+ )?.+? (-|\*) (((?:\d|[ixv])+)\.(\d|[ixv]) )?.*(?:$)",
                generation,
                flags=re.I | re.M,
            )
        )
    ):
        start, stop = match.span()
        delim = match.group(3) + " "
        splits = match.group(0).split(delim)
        replacement = ""
        if match.group(1) is not None:
            splits = splits[1:]
            delim1 = match.group(1) + " "
        else:
            delim1 = ""
            # too many false positives
            continue
        pre, post = generation[:start], generation[stop:]
        for i, item in enumerate(splits):
            level = 0
            potential_numeral, _, rest = item.strip().partition(" ")
            if not rest:
                continue
            if re.match(
                r"^[\dixv]+((?:\.[\dixv])?)+$", potential_numeral, flags=re.I | re.M
            ):
                level = potential_numeral.count(".")

            replacement += (
                ("\n" if i > 0 else "")
                + ("\t" * level)
                + (delim if i > 0 or start == 0 else delim1)
                + item.strip()
            )
        if post == "":
            post = "\n"
        generation = pre + replacement + post

    if generation.endswith((".", "}")):
        generation += "\n\n"
    if re.match(r"[A-Z0-9,;:]$", generation):
        # add a space in case there is a comma or word ending
        generation += " "
    elif generation.startswith(("#", "**", "\\begin")):
        generation = "\n\n" + generation
    elif generation.split("\n")[-1].startswith(("#", "Figure", "Table")):
        generation = generation + "\n\n"
    else:
        try:
            last_word = generation.split(" ")[-1]
            if last_word in words.words():
                generation += " "
        except LookupError:
            # add a space just in case. Will split words but better than concatenating them
            generation += " "
            # download for the next time
            import nltk

            nltk.download("words")
    # table corrections
    # remove obviously wrong tables
    for l in generation.split("\n"):
        if (
            l.count("\\begin{tabular}") > 15
            or l.count("\\multicolumn") > 60
            or l.count("&") > 400
        ):
            generation = generation.replace(l, "")
    # whitespace corrections
    generation = generation.replace(
        "\\begin{table} \\begin{tabular}", "\\begin{table}\n\\begin{tabular}"
    )
    generation = generation.replace(
        "\\end{tabular} \\end{table}", "\\end{tabular}\n\\end{table}"
    )
    generation = generation.replace("\\end{table} Tab", "\\end{table}\nTab")
    generation = re.sub(r"(^.+)\\begin{tab", r"\1\n\\begin{tab", generation, flags=re.M)

    generation = generation.replace(
        r"\begin{tabular}{l l} & \\ \end{tabular}", ""
    ).replace("\\begin{tabular}{}\n\n\\end{tabular}", "")
    generation = generation.replace("\\begin{array}[]{", "\\begin{array}{")
    generation = re.sub(
        r"\\begin{tabular}{([clr ]){2,}}\s*[& ]*\s*(\\\\)? \\end{tabular}",
        "",
        generation,
    )
    generation = re.sub(r"(\*\*S\. A\. B\.\*\*\n+){2,}", "", generation)
    generation = re.sub(r"^#+( [\[\d\w])?$", "", generation, flags=re.M)
    generation = re.sub(r"^\.\s*$", "", generation, flags=re.M)
    generation = re.sub(r"\n{3,}", "\n\n", generation)
    if markdown_fix:
        return markdown_compatible(generation)
    else:
        return generation


def postprocess(
    generation: Union[str, List[str]], markdown_fix: bool = True
) -> Union[str, List[str]]:
    """
    Postprocess generated text or a list of generated texts.

    This function can be used to perform postprocessing on generated text, such as fixing Markdown formatting.

    Args:
        generation (Union[str, List[str]]): The generated text or a list of generated texts.
        markdown_fix (bool, optional): Whether to perform Markdown formatting fixes. Default is True.

    Returns:
        Union[str, List[str]]: The postprocessed text or list of postprocessed texts.
    """
    if type(generation) == list:
        if os.environ.get("NOUGAT_MULTIPROCESSING"):
            with Pool(int(os.environ.get("NOUGAT_MULTIPROCESSING"))) as p:
                return p.map(
                    partial(postprocess_single, markdown_fix=markdown_fix), generation
                )
        else:
            return [
                postprocess_single(s, markdown_fix=markdown_fix) for s in generation
            ]
    else:
        return postprocess_single(generation, markdown_fix=markdown_fix)
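For orientation (not part of the commit): a minimal sketch of the postprocessing entry point above; the input string is a made-up example of a numbered display equation as Nougat might emit it.

# Usage sketch -- illustrative input only.
from nougat.postprocessing import postprocess

raw = "(1.2) \\[E=mc^{2}\\]"
# Single string in, single string out; markdown_fix additionally runs
# markdown_compatible(), e.g. rewriting "(1.2) \[...\]" into "\[... \tag{1.2}\]".
print(postprocess(raw, markdown_fix=True))

# Lists are handled elementwise; set NOUGAT_MULTIPROCESSING=<n> in the
# environment to fan the work out over a process pool instead.
print(postprocess([raw, raw], markdown_fix=False))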
nougat/transforms.py
ADDED
@@ -0,0 +1,173 @@
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
|
4 |
+
This source code is licensed under the MIT license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
# Implements image augmentation
|
8 |
+
|
9 |
+
import albumentations as alb
|
10 |
+
from albumentations.pytorch import ToTensorV2
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
14 |
+
|
15 |
+
|
16 |
+
def alb_wrapper(transform):
|
17 |
+
def f(im):
|
18 |
+
return transform(image=np.asarray(im))["image"]
|
19 |
+
|
20 |
+
return f
|
21 |
+
|
22 |
+
|
23 |
+
class Erosion(alb.ImageOnlyTransform):
|
24 |
+
"""
|
25 |
+
Apply erosion operation to an image.
|
26 |
+
|
27 |
+
Erosion is a morphological operation that shrinks the white regions in a binary image.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
scale (int or tuple/list of int): The scale or range for the size of the erosion kernel.
|
31 |
+
If an integer is provided, a square kernel of that size will be used.
|
32 |
+
If a tuple or list is provided, it should contain two integers representing the minimum
|
33 |
+
and maximum sizes for the erosion kernel.
|
34 |
+
always_apply (bool, optional): Whether to always apply this transformation. Default is False.
|
35 |
+
p (float, optional): The probability of applying this transformation. Default is 0.5.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
numpy.ndarray: The transformed image.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, scale, always_apply=False, p=0.5):
|
42 |
+
super().__init__(always_apply=always_apply, p=p)
|
43 |
+
if type(scale) is tuple or type(scale) is list:
|
44 |
+
assert len(scale) == 2
|
45 |
+
self.scale = scale
|
46 |
+
else:
|
47 |
+
self.scale = (scale, scale)
|
48 |
+
|
49 |
+
def apply(self, img, **params):
|
50 |
+
kernel = cv2.getStructuringElement(
|
51 |
+
cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
|
52 |
+
)
|
53 |
+
img = cv2.erode(img, kernel, iterations=1)
|
54 |
+
return img
|
55 |
+
|
56 |
+
|
57 |
+
class Dilation(alb.ImageOnlyTransform):
|
58 |
+
"""
|
59 |
+
Apply dilation operation to an image.
|
60 |
+
|
61 |
+
    Dilation is a morphological operation that expands the white regions in a binary image.

    Args:
        scale (int or tuple/list of int): The scale or range for the size of the dilation kernel.
            If an integer is provided, a square kernel of that size will be used.
            If a tuple or list is provided, it should contain two integers representing the minimum
            and maximum sizes for the dilation kernel.
        always_apply (bool, optional): Whether to always apply this transformation. Default is False.
        p (float, optional): The probability of applying this transformation. Default is 0.5.

    Returns:
        numpy.ndarray: The transformed image.
    """

    def __init__(self, scale, always_apply=False, p=0.5):
        super().__init__(always_apply=always_apply, p=p)
        if type(scale) is tuple or type(scale) is list:
            assert len(scale) == 2
            self.scale = scale
        else:
            self.scale = (scale, scale)

    def apply(self, img, **params):
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
        )
        img = cv2.dilate(img, kernel, iterations=1)
        return img


class Bitmap(alb.ImageOnlyTransform):
    """
    Apply a bitmap-style transformation to an image.

    This transformation replaces all pixel values below a certain threshold with a specified value.

    Args:
        value (int, optional): The value to replace pixels below the threshold with. Default is 0.
        lower (int, optional): The threshold value below which pixels will be replaced. Default is 200.
        always_apply (bool, optional): Whether to always apply this transformation. Default is False.
        p (float, optional): The probability of applying this transformation. Default is 0.5.

    Returns:
        numpy.ndarray: The transformed image.
    """

    def __init__(self, value=0, lower=200, always_apply=False, p=0.5):
        super().__init__(always_apply=always_apply, p=p)
        self.lower = lower
        self.value = value

    def apply(self, img, **params):
        img = img.copy()
        img[img < self.lower] = self.value
        return img


train_transform = alb_wrapper(
    alb.Compose(
        [
            Bitmap(p=0.05),
            alb.OneOf([Erosion((2, 3)), Dilation((2, 3))], p=0.02),
            alb.Affine(shear={"x": (0, 3), "y": (-3, 0)}, cval=(255, 255, 255), p=0.03),
            alb.ShiftScaleRotate(
                shift_limit_x=(0, 0.04),
                shift_limit_y=(0, 0.03),
                scale_limit=(-0.15, 0.03),
                rotate_limit=2,
                border_mode=0,
                interpolation=2,
                value=(255, 255, 255),
                p=0.03,
            ),
            alb.GridDistortion(
                distort_limit=0.05,
                border_mode=0,
                interpolation=2,
                value=(255, 255, 255),
                p=0.04,
            ),
            alb.Compose(
                [
                    alb.Affine(
                        translate_px=(0, 5), always_apply=True, cval=(255, 255, 255)
                    ),
                    alb.ElasticTransform(
                        p=1,
                        alpha=50,
                        sigma=120 * 0.1,
                        alpha_affine=120 * 0.01,
                        border_mode=0,
                        value=(255, 255, 255),
                    ),
                ],
                p=0.04,
            ),
            alb.RandomBrightnessContrast(0.1, 0.1, True, p=0.03),
            alb.ImageCompression(95, p=0.07),
            alb.GaussNoise(20, p=0.08),
            alb.GaussianBlur((3, 3), p=0.03),
            alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ]
    )
)
test_transform = alb_wrapper(
    alb.Compose(
        [
            alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ]
    )
)
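A minimal usage sketch for the two pipelines, assuming alb_wrapper (defined earlier in this file) exposes the composed albumentations pipeline as a plain callable on image arrays; "page.png" is a hypothetical page render, not part of the repository.

import numpy as np
from PIL import Image

page = np.asarray(Image.open("page.png").convert("RGB"))  # hypothetical input image
noisy = train_transform(page)  # randomly degraded, normalized CHW tensor for training
clean = test_transform(page)   # normalization + tensor conversion only, for evaluation

Note the small probabilities (p=0.02..0.08): most augmentations fire rarely, so a typical training sample receives at most one or two of the degradations.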
nougat/utils/__init__.py
ADDED
File without changes

nougat/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (149 Bytes)

nougat/utils/__pycache__/checkpoint.cpython-310.pyc
ADDED
Binary file (3.84 kB)

nougat/utils/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (8.96 kB)

nougat/utils/checkpoint.py
ADDED
@@ -0,0 +1,119 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Optional
import requests
import os
import tqdm
import io
from pathlib import Path
import torch

BASE_URL = "https://github.com/facebookresearch/nougat/releases/download"
MODEL_TAG = "0.1.0-small"


# source: https://stackoverflow.com/a/71459251
def download_as_bytes_with_progress(url: str, name: Optional[str] = None) -> bytes:
    """
    Download a file from a URL and return the contents as bytes, with a progress bar.

    Args:
        url: The URL of the file to download.
        name: The label to display on the progress bar. If None, the URL is used.

    Returns:
        bytes: The contents of the file.
    """
    resp = requests.get(url, stream=True, allow_redirects=True)
    total = int(resp.headers.get("content-length", 0))
    bio = io.BytesIO()
    if name is None:
        name = url
    with tqdm.tqdm(
        desc=name,
        total=total,
        unit="b",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in resp.iter_content(chunk_size=65536):
            bar.update(len(chunk))
            bio.write(chunk)
    return bio.getvalue()


def download_checkpoint(checkpoint: Path, model_tag: str = MODEL_TAG):
    """
    Download the Nougat model checkpoint.

    This function downloads the Nougat model checkpoint from GitHub.

    Args:
        checkpoint (Path): The directory to download the checkpoint files into.
        model_tag (str): The model tag to download. Default is "0.1.0-small".
    """
    print("downloading nougat checkpoint version", model_tag, "to path", checkpoint)
    files = [
        "config.json",
        "pytorch_model.bin",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]
    for file in files:
        download_url = f"{BASE_URL}/{model_tag}/{file}"
        binary_file = download_as_bytes_with_progress(download_url, file)
        if len(binary_file) > 15:  # sanity check: skip empty or error responses
            (checkpoint / file).write_bytes(binary_file)


def torch_hub(model_tag: Optional[str] = MODEL_TAG) -> Path:
    """Return the default checkpoint location inside the torch hub cache."""
    old_path = Path(torch.hub.get_dir() + "/nougat")
    if model_tag is None:
        model_tag = MODEL_TAG
    hub_path = old_path.with_name(f"nougat-{model_tag}")
    if old_path.exists():
        # migrate the old un-tagged cache directory to the new naming scheme
        old_path.rename(old_path.with_name("nougat-0.1.0-small"))
    return hub_path


def get_checkpoint(
    checkpoint_path: Optional[os.PathLike] = None,
    model_tag: str = MODEL_TAG,
    download: bool = True,
) -> Path:
    """
    Get the path to the Nougat model checkpoint.

    This function retrieves the path to the Nougat model checkpoint. If the checkpoint does not
    exist or is empty, it can optionally download the checkpoint.

    Args:
        checkpoint_path (Optional[os.PathLike]): The path to the checkpoint. If not provided,
            it will check the "NOUGAT_CHECKPOINT" environment variable or use the default location.
            Default is None.
        model_tag (str): The model tag to download. Default is "0.1.0-small".
        download (bool): Whether to download the checkpoint if it doesn't exist or is empty.
            Default is True.

    Returns:
        Path: The path to the Nougat model checkpoint.
    """
    checkpoint = Path(
        checkpoint_path or os.environ.get("NOUGAT_CHECKPOINT", torch_hub(model_tag))
    )
    if checkpoint.exists() and checkpoint.is_file():
        checkpoint = checkpoint.parent
    if download and (not checkpoint.exists() or len(os.listdir(checkpoint)) < 5):
        checkpoint.mkdir(parents=True, exist_ok=True)
        download_checkpoint(checkpoint, model_tag=model_tag or MODEL_TAG)
    return checkpoint


if __name__ == "__main__":
    get_checkpoint()
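A short usage sketch for the helpers above. Only calls defined in this file are used; the printed file names follow directly from the download list in download_checkpoint.

from nougat.utils.checkpoint import get_checkpoint

ckpt = get_checkpoint()  # resolves the checkpoint dir, downloading "0.1.0-small" on first use
print(ckpt)  # defaults to a "nougat-<tag>" directory inside the torch hub cache
print(sorted(p.name for p in ckpt.iterdir()))  # config.json, pytorch_model.bin, tokenizer.json, ...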
nougat/utils/dataset.py
ADDED
@@ -0,0 +1,280 @@
"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
Copyright (c) Meta Platforms, Inc. and affiliates.
"""
import logging
import os
from math import prod
from pathlib import Path
from functools import partial
import random
from typing import Callable, Dict, List, Optional, Tuple
from PIL import Image, UnidentifiedImageError

import torch
import pypdf
import orjson
from torch.utils.data import Dataset
from transformers.modeling_utils import PreTrainedModel
from nougat.dataset.rasterize import rasterize_paper


class ImageDataset(torch.utils.data.Dataset):
    """
    Dataset for processing a list of images using a preparation function.

    This dataset takes a list of image paths and applies a preparation function to each image.

    Args:
        img_list (list): List of image paths.
        prepare (Callable): A preparation function to process the images.

    Attributes:
        img_list (list): List of image paths.
        prepare (Callable): The preparation function.
    """

    def __init__(self, img_list, prepare: Callable):
        super().__init__()
        self.img_list = img_list
        self.prepare = prepare

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def ignore_none_collate(batch):
        if batch is None:
            return
        try:
            batch = [x for x in batch if x is not None and x[0] is not None]
            if len(batch) == 0:
                return
            return torch.utils.data.dataloader.default_collate(batch)
        except AttributeError:
            pass

    def __getitem__(self, idx):
        try:
            img = Image.open(self.img_list[idx])
            return self.prepare(img)
        except Exception as e:
            logging.error(e)


class LazyDataset(Dataset):
    """
    Lazy loading dataset for processing PDF documents.

    This dataset allows lazy loading of PDF documents and provides access to processed images
    using a specified preparation function.

    Args:
        pdf (str): Path to the PDF document.
        prepare (Callable): A preparation function to process the images.
        pages (Optional[List[int]]): Indices of the pages to rasterize. Default is all pages.

    Attributes:
        name (str): Name of the PDF document.
    """

    def __init__(self, pdf, prepare: Callable, pages: Optional[List[int]] = None):
        super().__init__()
        self.prepare = prepare
        self.name = str(pdf)
        self.init_fn = partial(rasterize_paper, pdf, pages=pages)
        self.dataset = None
        self.size = len(pypdf.PdfReader(pdf).pages) if pages is None else len(pages)

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        if i == 0 or self.dataset is None:
            # rasterize the PDF only when a page is actually requested
            self.dataset = ImageDataset(self.init_fn(), self.prepare)
        if 0 <= i < self.size:
            # the PDF name is attached to the last page so the consumer
            # knows when a document is complete
            return self.dataset[i], self.name if i == self.size - 1 else ""
        else:
            raise IndexError

    @staticmethod
    def ignore_none_collate(batch):
        if batch is None:
            return None, None
        try:
            _batch = []
            for i, x in enumerate(batch):
                image, name = x
                if image is not None:
                    _batch.append(x)
                elif name:
                    if i > 0:
                        _batch[-1] = (_batch[-1][0], name)
                    elif len(batch) > 1:
                        _batch.append((batch[1][0] * 0, name))
            if len(_batch) == 0:
                return None, None
            return torch.utils.data.dataloader.default_collate(_batch)
        except AttributeError:
            pass
        return None, None


class SciPDFDataset(Dataset):
    """
    Custom dataset for scientific PDF data.

    This dataset loads data from JSONL files and provides access to images, ground truth,
    and metadata.

    Args:
        path_to_index (str): Path to the index file.
        split (str, optional): Split of the dataset (e.g., "train", "test"). Default is "train".
        root_name (str, optional): Root directory name. Default is an empty string.
        template (str, optional): Template for split naming. Default is "%s".

    Attributes:
        empty_sample: Placeholder for empty samples.
    """

    empty_sample = None

    def __init__(
        self,
        path_to_index: str,
        split: str = "train",
        root_name="",
        template="%s",
    ) -> None:
        super().__init__()
        self.path_to_index = Path(path_to_index)
        self.root_name = root_name
        self.path_to_root = self.path_to_index.parent
        if split not in self.path_to_index.stem:
            pti = self.path_to_root / (template % split + ".jsonl")
            if pti.exists():
                self.path_to_index = pti
            else:
                raise ValueError(f'Dataset file for split "{split}" not found: {pti}')
        self.dataset_file = None  # opened lazily, once per worker (multiprocessing)
        # load seek map
        seek_path = self.path_to_root / (self.path_to_index.stem + ".seek.map")
        if seek_path.exists():
            self.seek_map = orjson.loads(seek_path.open().read())
        else:
            raise ValueError(
                'No "%s" found in %s' % (seek_path.name, str(self.path_to_root))
            )
        self.dataset_length = len(self.seek_map)

    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, index: int) -> Dict:
        position = self.seek_map[index]
        if self.dataset_file is None:
            self.dataset_file = self.path_to_index.open()
        self.dataset_file.seek(position)
        line = self.dataset_file.readline()
        try:
            data: Dict = orjson.loads(line)
        except Exception as e:
            logging.info(
                "JSONL for sample %i could not be loaded at position %i: %s\n%s",
                index,
                position,
                str(e),
                line,
            )
            return self.empty_sample
        img_path: Path = self.path_to_root / self.root_name / data.pop("image")
        if not img_path.exists():
            logging.info("Sample %s could not be found.", img_path)
            return self.empty_sample
        try:
            img = Image.open(img_path)
        except UnidentifiedImageError:
            logging.info("Image %s could not be opened.", img_path)
            return self.empty_sample
        return {"image": img, "ground_truth": data.pop("markdown"), "meta": data}

    def __iter__(self):
        for i in range(self.dataset_length):
            yield self[i]


class NougatDataset(Dataset):
    """
    Args:
        dataset_path: the path to the jsonl file
    """

    def __init__(
        self,
        dataset_path: str,
        nougat_model: PreTrainedModel,
        max_length: int,
        split: str = "train",
        root_name: str = "arxiv",
    ):
        super().__init__()
        self.nougat_model = nougat_model
        self.max_length = max_length
        self.split = split
        self.perturb = bool(os.environ.get("NOUGAT_PERTURB"))
        # TODO improve naming conventions
        template = "%s"
        self.dataset = SciPDFDataset(
            dataset_path, split=self.split, template=template, root_name=root_name
        )
        self.dataset_length = len(self.dataset)

    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Load image from image_path of given dataset_path and convert into input_tensor and labels.
        Convert gt data into input_ids (tokenized string).

        Returns:
            input_tensor : preprocessed image
            input_ids : tokenized gt_data
            attention_mask : attention mask for input_ids
        """
        sample = self.dataset[idx]
        if sample is None:
            # if the sample is broken, choose another one at random
            return self[random.randint(0, self.dataset_length - 1)]
        if sample["image"] is None or prod(sample["image"].size) == 0:
            input_tensor = None
        else:
            input_tensor = self.nougat_model.encoder.prepare_input(
                sample["image"], random_padding=self.split == "train"
            )

        tokenizer_out = self.nougat_model.decoder.tokenizer(
            sample["ground_truth"],
            max_length=self.max_length,
            padding="max_length",
            return_token_type_ids=False,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokenizer_out["input_ids"].squeeze(0)
        attention_mask = tokenizer_out["attention_mask"].squeeze(0)
        # randomly perturb ground truth tokens
        if self.split == "train" and self.perturb:
            # replace a geometrically distributed number of tokens with random token ids
            unpadded_length = attention_mask.sum()
            while random.random() < 0.1:
                try:
                    pos = random.randint(1, unpadded_length - 2)
                    token = random.randint(
                        23, len(self.nougat_model.decoder.tokenizer) - 1
                    )
                    input_ids[pos] = token
                except ValueError:
                    break
        return input_tensor, input_ids, attention_mask
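SciPDFDataset achieves O(1) random access into a large JSONL index by pairing it with a ".seek.map" file: a JSON list of byte offsets, one per line. The repository ships nougat/dataset/gen_seek.py, presumably for this purpose; the sketch below is only an illustration of the layout, and build_seek_map is a hypothetical helper, not the repo's API.

import orjson
from pathlib import Path

def build_seek_map(jsonl_path: Path) -> Path:
    """Write a list of line-start byte offsets next to a JSONL index file."""
    offsets, pos = [], 0
    with jsonl_path.open("rb") as f:
        for line in f:
            offsets.append(pos)
            pos += len(line)
    # SciPDFDataset expects "<stem>.seek.map" next to the index file
    seek_path = jsonl_path.parent / (jsonl_path.stem + ".seek.map")
    seek_path.write_bytes(orjson.dumps(offsets))
    return seek_path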
nougat/utils/device.py
ADDED
@@ -0,0 +1,38 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
import logging


def default_batch_size():
    if torch.cuda.is_available():
        # heuristic: roughly 0.3 batch elements per GB of total VRAM
        batch_size = int(
            torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1000 * 0.3
        )
        if batch_size == 0:
            logging.warning("GPU VRAM is too small. Computing on CPU.")
    elif torch.backends.mps.is_available():
        # no equivalent memory query exists for MPS, so choose bs=4 heuristically
        batch_size = 4
    else:
        # no established good value for CPU; running on CPU is not recommended
        batch_size = 1
        logging.warning("No GPU found. Conversion on CPU is very slow.")
    return batch_size


def move_to_device(model, bf16: bool = True, cuda: bool = True):
    try:
        # prefer Apple MPS when present; the bf16 cast is skipped there
        if torch.backends.mps.is_available():
            return model.to("mps")
    except AttributeError:
        pass
    if bf16:
        model = model.to(torch.bfloat16)
    if cuda and torch.cuda.is_available():
        model = model.to("cuda")
    return model
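A short sketch tying the two helpers together, assuming a checkpoint is already resolvable via get_checkpoint; this mirrors what predict.py below does by hand.

from nougat import NougatModel
from nougat.utils.checkpoint import get_checkpoint
from nougat.utils.device import default_batch_size, move_to_device

model = NougatModel.from_pretrained(get_checkpoint()).eval()
model = move_to_device(model)  # MPS if available, otherwise bf16 cast (+ CUDA if present)
batch_size = max(default_batch_size(), 1)  # the VRAM heuristic can return 0 on small GPUs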
predict.py
ADDED
@@ -0,0 +1,172 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import sys
from pathlib import Path
import logging
import re
import argparse
from functools import partial
import torch
from torch.utils.data import ConcatDataset
from tqdm import tqdm
from nougat import NougatModel
from nougat.utils.dataset import LazyDataset
from nougat.utils.checkpoint import get_checkpoint
from nougat.postprocessing import markdown_compatible
import fitz

logging.basicConfig(level=logging.INFO)

if torch.cuda.is_available():
    # heuristic: roughly 0.3 batch elements per GB of total VRAM
    BATCH_SIZE = int(
        torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1000 * 0.3
    )
    if BATCH_SIZE == 0:
        logging.warning("GPU VRAM is too small. Computing on CPU.")
else:
    # no established good value for CPU; running on CPU is not recommended
    BATCH_SIZE = 1
    logging.warning("No GPU found. Conversion on CPU is very slow.")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batchsize",
        "-b",
        type=int,
        default=BATCH_SIZE,
        help="Batch size to use.",
    )
    parser.add_argument(
        "--checkpoint",
        "-c",
        type=Path,
        default=None,
        help="Path to checkpoint directory.",
    )
    parser.add_argument("--out", "-o", type=Path, help="Output directory.")
    parser.add_argument(
        "--recompute",
        action="store_true",
        help="Recompute already computed PDFs, discarding previous predictions.",
    )
    parser.add_argument(
        "--markdown",
        action="store_true",
        help="Add postprocessing step for markdown compatibility.",
    )
    parser.add_argument("pdf", nargs="+", type=Path, help="PDF(s) to process.")
    args = parser.parse_args()
    if args.checkpoint is None or not args.checkpoint.exists():
        args.checkpoint = get_checkpoint(args.checkpoint)
    if args.out is None:
        logging.warning("No output directory. Output will be printed to console.")
    else:
        if not args.out.exists():
            logging.info("Output directory does not exist. Creating output directory.")
            args.out.mkdir(parents=True)
        if not args.out.is_dir():
            logging.error("Output has to be a directory.")
            sys.exit(1)
    if len(args.pdf) == 1 and not args.pdf[0].suffix == ".pdf":
        # a single non-PDF argument is treated as a text file listing one PDF path per line
        try:
            args.pdf = [
                Path(l) for l in open(args.pdf[0]).read().split("\n") if len(l) > 0
            ]
        except Exception:
            pass
    return args


def main():
    args = get_args()
    model = NougatModel.from_pretrained(args.checkpoint).to(torch.bfloat16)
    if args.batchsize > 0:
        if torch.cuda.is_available():
            model.to("cuda")
    else:
        # set batch size to 1; unclear whether CPU conversion benefits from larger batches
        args.batchsize = 1
    model.eval()
    datasets = []
    for pdf in args.pdf:
        if not pdf.exists():
            continue
        if args.out:
            out_path = args.out / pdf.with_suffix(".mmd").name
            if out_path.exists() and not args.recompute:
                logging.info(
                    f"Skipping {pdf.name}, already computed. Run with --recompute to convert again."
                )
                continue
        try:
            dataset = LazyDataset(
                pdf, partial(model.encoder.prepare_input, random_padding=False)
            )
        except fitz.fitz.FileDataError:
            logging.info(f"Could not load file {str(pdf)}.")
            continue
        datasets.append(dataset)
    if len(datasets) == 0:
        return
    dataloader = torch.utils.data.DataLoader(
        ConcatDataset(datasets),
        batch_size=args.batchsize,
        shuffle=False,
        collate_fn=LazyDataset.ignore_none_collate,
    )

    predictions = []
    file_index = 0
    page_num = 0
    for i, (sample, is_last_page) in enumerate(tqdm(dataloader)):
        model_output = model.inference(image_tensors=sample)
        # check if the model output is faulty
        for j, output in enumerate(model_output["predictions"]):
            if page_num == 0:
                logging.info(
                    "Processing file %s with %i pages"
                    % (datasets[file_index].name, datasets[file_index].size)
                )
            page_num += 1
            if output.strip() == "[MISSING_PAGE_POST]":
                # uncaught repetitions -- most likely an empty page
                predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n")
            elif model_output["repeats"][j] is not None:
                if model_output["repeats"][j] > 0:
                    # the output is most likely incomplete and was truncated
                    logging.warning(f"Skipping page {page_num} due to repetitions.")
                    predictions.append(f"\n\n[MISSING_PAGE_FAIL:{page_num}]\n\n")
                else:
                    # the page is too different from the training domain;
                    # this can happen e.g. for cover pages
                    predictions.append(
                        f"\n\n[MISSING_PAGE_EMPTY:{i*args.batchsize+j+1}]\n\n"
                    )
            else:
                if args.markdown:
                    output = markdown_compatible(output)
                predictions.append(output)
            if is_last_page[j]:
                out = "".join(predictions).strip()
                out = re.sub(r"\n{3,}", "\n\n", out).strip()
                if args.out:
                    out_path = args.out / Path(is_last_page[j]).with_suffix(".mmd").name
                    out_path.parent.mkdir(parents=True, exist_ok=True)
                    out_path.write_text(out, encoding="utf-8")
                else:
                    print(out, "\n\n")
                predictions = []
                page_num = 0
                file_index += 1


if __name__ == "__main__":
    main()
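predict.py is the command-line entry point; a minimal programmatic sketch of the same pipeline for a single document follows. It uses only calls that appear above; "paper.pdf" is a hypothetical input, the batch size is illustrative, and a CUDA GPU is assumed.

from functools import partial

import torch
from nougat import NougatModel
from nougat.utils.checkpoint import get_checkpoint
from nougat.utils.dataset import LazyDataset

model = NougatModel.from_pretrained(get_checkpoint()).to(torch.bfloat16).to("cuda").eval()
dataset = LazyDataset(
    "paper.pdf", partial(model.encoder.prepare_input, random_padding=False)
)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=4, shuffle=False, collate_fn=LazyDataset.ignore_none_collate
)
pages = []
for sample, is_last_page in loader:
    # each batch yields page images; predictions come back in page order
    pages.extend(model.inference(image_tensors=sample)["predictions"])
print("".join(pages).strip())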