zphilip committed on
Commit
9d1fa0d
1 Parent(s): 876fac2

adding part 1

Files changed (42)
  1. .gitattributes +2 -0
  2. app.py +297 -0
  3. nougat/__init__.py +15 -0
  4. nougat/__pycache__/__init__.cpython-310.pyc +0 -0
  5. nougat/__pycache__/_version.cpython-310.pyc +0 -0
  6. nougat/__pycache__/model.cpython-310.pyc +0 -0
  7. nougat/__pycache__/postprocessing.cpython-310.pyc +0 -0
  8. nougat/__pycache__/transforms.cpython-310.pyc +0 -0
  9. nougat/_version.py +8 -0
  10. nougat/dataset/__init__.py +0 -0
  11. nougat/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
  12. nougat/dataset/__pycache__/rasterize.cpython-310.pyc +0 -0
  13. nougat/dataset/create_index.py +173 -0
  14. nougat/dataset/gen_seek.py +36 -0
  15. nougat/dataset/parser/__init__.py +0 -0
  16. nougat/dataset/parser/document.py +703 -0
  17. nougat/dataset/parser/html2md.py +67 -0
  18. nougat/dataset/parser/latexml_parser.py +441 -0
  19. nougat/dataset/parser/markdown.py +396 -0
  20. nougat/dataset/pdffigures.py +71 -0
  21. nougat/dataset/rasterize.py +81 -0
  22. nougat/dataset/split_htmls_to_pages.py +219 -0
  23. nougat/dataset/split_md_to_pages.py +477 -0
  24. nougat/dataset/splitter.py +393 -0
  25. nougat/dataset/staircase.py +314 -0
  26. nougat/dataset/tokenizer.json +0 -0
  27. nougat/dataset/utils/__init__.py +8 -0
  28. nougat/dataset/utils/latex_conversion.py +146 -0
  29. nougat/dataset/utils/pdf_text_extract.py +86 -0
  30. nougat/dataset/utils/utils.py +20 -0
  31. nougat/metrics.py +117 -0
  32. nougat/model.py +702 -0
  33. nougat/postprocessing.py +508 -0
  34. nougat/transforms.py +173 -0
  35. nougat/utils/__init__.py +0 -0
  36. nougat/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  37. nougat/utils/__pycache__/checkpoint.cpython-310.pyc +0 -0
  38. nougat/utils/__pycache__/dataset.cpython-310.pyc +0 -0
  39. nougat/utils/checkpoint.py +119 -0
  40. nougat/utils/dataset.py +280 -0
  41. nougat/utils/device.py +38 -0
  42. predict.py +172 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.bin filter=lfs diff=lfs merge=lfs -text
37
+ *.pdf filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,297 @@
1
+ import gradio as gr
2
+ import subprocess
3
+ import uuid
4
+ import os
5
+ import requests
6
+ import re
7
+
8
+ os.environ['http_proxy'] = ""
9
+ os.environ['https_proxy'] = ""
10
+
11
+ """
12
+ Copyright (c) Meta Platforms, Inc. and affiliates.
13
+
14
+ This source code is licensed under the MIT license found in the
15
+ LICENSE file in the root directory of this source tree.
16
+ """
17
+ import sys
18
+ from pathlib import Path
19
+ import logging
20
+ import re
21
+ import argparse
22
+ import re
23
+ from functools import partial
24
+ import torch
25
+ from torch.utils.data import ConcatDataset
26
+ from tqdm import tqdm
27
+ from nougat import NougatModel
28
+ from nougat.utils.dataset import LazyDataset
29
+ from nougat.utils.checkpoint import get_checkpoint
30
+ from nougat.postprocessing import markdown_compatible
31
+ import fitz
32
+
33
+ logging.basicConfig(level=logging.INFO)
34
+ if torch.cuda.is_available():
35
+ BATCH_SIZE = int(
36
+ torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1000 * 0.3
37
+ )
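+ # BATCH_SIZE works out to roughly 0.3 x the total GPU memory in GB (total_memory is reported in bytes)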
38
+ if BATCH_SIZE == 0:
39
+ logging.warning("GPU VRAM is too small. Computing on CPU.")
40
+ else:
41
+ # don't know what a good value is here. Would not recommend to run on CPU
42
+ BATCH_SIZE = 1
43
+ logging.warning("No GPU found. Conversion on CPU is very slow.")
44
+
45
+ def nougat_predict(input_files, output_path, checkpoint, batchsize, markdown,recompute):
46
+ print(f'*** nougat predict with input :{input_files} ***')
47
+ model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
48
+ if batchsize > 0:
49
+ if torch.cuda.is_available():
50
+ model.to("cuda")
51
+ else:
52
+ # set batch size to 1. Need to check if there are benefits for CPU conversion for >1
53
+ batchsize = 1
54
+ model.eval()
55
+ datasets = []
56
+ for pdf in input_files:
57
+ #if not pdf.exists():
58
+ if not os.path.exists(pdf):
59
+ continue
60
+ if output_path:
61
+ out_path = output_path / pdf.with_suffix(".mmd").name
62
+ print(out_path)
63
+ if out_path.exists() and not recompute:
64
+ logging.info(
65
+ f"Skipping {pdf.name}, already computed. Run with --recompute to convert again."
66
+ )
67
+ continue
68
+ try:
69
+ dataset = LazyDataset(
70
+ pdf, partial(model.encoder.prepare_input, random_padding=False)
71
+ )
72
+ except fitz.fitz.FileDataError:
73
+ logging.info(f"Could not load file {str(pdf)}.")
74
+ continue
75
+ datasets.append(dataset)
76
+ if len(datasets) == 0:
77
+ print(f'*** nougat out files :{out_path} ***')
78
+ return out_path
79
+ dataloader = torch.utils.data.DataLoader(
80
+ ConcatDataset(datasets),
81
+ batch_size=batchsize,
82
+ shuffle=False,
83
+ collate_fn=LazyDataset.ignore_none_collate,
84
+ )
85
+
86
+ predictions = []
87
+ file_index = 0
88
+ page_num = 0
89
+ for i, (sample, is_last_page) in enumerate(tqdm(dataloader)):
90
+ model_output = model.inference(image_tensors=sample)
91
+ # check if model output is faulty
92
+ for j, output in enumerate(model_output["predictions"]):
93
+ if page_num == 0:
94
+ logging.info(
95
+ "Processing file %s with %i pages"
96
+ % (datasets[file_index].name, datasets[file_index].size)
97
+ )
98
+ page_num += 1
99
+ if output.strip() == "[MISSING_PAGE_POST]":
100
+ # uncaught repetitions -- most likely empty page
101
+ predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n")
102
+ elif model_output["repeats"][j] is not None:
103
+ if model_output["repeats"][j] > 0:
104
+ # If we end up here, it means the output is most likely not complete and was truncated.
105
+ logging.warning(f"Skipping page {page_num} due to repetitions.")
106
+ predictions.append(f"\n\n[MISSING_PAGE_FAIL:{page_num}]\n\n")
107
+ else:
108
+ # If we end up here, it means the document page is too different from the training domain.
109
+ # This can happen e.g. for cover pages.
110
+ predictions.append(
111
+ f"\n\n[MISSING_PAGE_EMPTY:{i*args.batchsize+j+1}]\n\n"
112
+ )
113
+ else:
114
+ if markdown:
115
+ output = markdown_compatible(output)
116
+ predictions.append(output)
117
+ if is_last_page[j]:
118
+ out = "".join(predictions).strip()
119
+ out = re.sub(r"\n{3,}", "\n\n", out).strip()
120
+ if output_path:
121
+ out_path = output_path / Path(is_last_page[j]).with_suffix(".mmd").name
122
+ out_path.parent.mkdir(parents=True, exist_ok=True)
123
+ out_path.write_text(out, encoding="utf-8")
124
+ else:
125
+ print(out, "\n\n")
126
+ predictions = []
127
+ page_num = 0
128
+ file_index += 1
129
+ print(f'the generated markdown file is : {out_path}')
130
+ return out_path
131
+
132
+ def get_pdf(pdf_link):
133
+ # Generate a unique filename
134
+ unique_filename = f"input/downloaded_paper_{uuid.uuid4().hex}.pdf"
135
+
136
+ # Send a GET request to the PDF link
137
+ response = requests.get(pdf_link)
138
+
139
+ if response.status_code == 200:
140
+ # Save the PDF content to a local file
141
+ with open(unique_filename, 'wb') as pdf_file:
142
+ pdf_file.write(response.content)
143
+ print("PDF downloaded successfully.")
144
+ else:
145
+ print("Failed to download the PDF.")
146
+ return unique_filename #.split('/')[-1][:-4]
147
+
148
+
149
+ def nougat_ocr(file_name):
150
+
151
+ #unique_filename = f"/content/output/downloaded_paper_{uuid.uuid4().hex}.pdf"
152
+ # Command to run
153
+ cli_command = [
154
+ 'nougat',
155
+ #'--out', unique_filename,
156
+ '--out', 'output',
157
+ 'pdf', f'{file_name}',
158
+ '--checkpoint', 'nougat',
159
+ '--markdown'
160
+ ]
161
+
162
+ # Run the command and capture its output
163
+ #completed_process =
164
+ subprocess.run(cli_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
165
+
166
+ return #unique_filename
167
+
168
+ import pathlib
169
+ def predict(pdf_file, pdf_link):
170
+ print("*************** inference ******************")
171
+ if pdf_file is None:
172
+ if pdf_link == '':
173
+ print("No file is uploaded and No link is provided")
174
+ return "No data provided. Upload a pdf file or provide a pdf link and try again!"
175
+ else:
176
+ print(f'pdf_link is - {pdf_link}')
177
+ file_name = get_pdf(pdf_link)
178
+ print(f'file_name is - {file_name}')
179
+ else:
180
+ print(pdf_file)
181
+ file_name = pdf_file.name
182
+ print(file_name)
183
+ pdf_name = pdf_file.name.split('/')[-1].split('.')[0]
184
+ print(pdf_name)
185
+
186
+ # Call nougat
187
+ #nougat_ocr(file_name)
188
+ #nougat_predict(file_name)
189
+ input_files = [pathlib.Path(file_name)]  # wrap in a list so nougat_predict can iterate; Path() accepts both str and Path
190
+ #input_files = pathlib.Path(file_name),
191
+ output_path = pathlib.Path("./output")
192
+ checkpoint = pathlib.Path("./config1/")
193
+ config = pathlib.Path("./config1/config.json")
194
+ markdown = True
195
+ batchsize = BATCH_SIZE
196
+ output_files = nougat_predict(input_files=input_files, output_path=output_path, checkpoint = checkpoint, batchsize = batchsize, markdown = markdown, recompute=False)
197
+ print(f'the generated markdown file is : {output_files}')
198
+ #print("BACKKKK")
199
+
200
+ # Open the file for reading
201
+ file_name = file_name.split('/')[-1][:-4]
202
+ #with open(f'output/{file_name}.mmd', 'r') as file:
203
+ with open(output_files, 'r+') as file:
204
+ content = file.read()
205
+ # switch math delimiters
206
+ content = content.replace(r"\(", "\$").replace(r'\)', '\$').replace(r'\[', '\$\$').replace(r'\]', '\$\$')
207
+ print("***********************************")
208
+ print("convert successfully")
209
+ print("***********************************")
210
+
211
+ return content
212
+
213
+
214
+ def nougat_ocr1(file_name):
215
+ print('******* inside nougat_ocr *******')
216
+ # CLI Command to run
217
+ cli_command = [
218
+ 'python', 'predict.py',
219
+ '--out', 'output',
220
+ 'pdf', f'{file_name}',
221
+ '--checkpoint', '../config1/',
222
+ '--markdown'
223
+ ]
224
+
225
+ # Run the command and get .mmd file in an output folder
226
+ subprocess.run(cli_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
227
+ return
228
+
229
+
230
+ def predict1(pdf_file):
231
+ print('******* inside predict *******')
232
+ print(f"temporary file - {pdf_file.name}")
233
+ pdf_name = pdf_file.name.split('/')[-1].split('.')[0]
234
+ print(f"pdf file name - {pdf_name}")
235
+
236
+ #! Get prediction for a PDF using nougat
237
+ nougat_ocr(pdf_file.name)
238
+ print("BAACCKKK")
239
+
240
+ # Open the multimarkdown (.mmd) file for reading
241
+ with open(f'output/{pdf_name}.mmd', 'r') as file:
242
+ content = file.read()
243
+
244
+ return content
245
+
246
+ def process_example(pdf_file,pdf_link):
247
+ ocr_content = predict(pdf_file,pdf_link)
248
+ return gr.update(value=ocr_content)
249
+
250
+ css = """
251
+ #mkd {
252
+ height: 500px;
253
+ overflow: auto;
254
+ border: 1px solid #ccc;
255
+ }
256
+ """
257
+
258
+ with gr.Blocks(css=css) as demo:
259
+ gr.HTML("<h1><center>Nougat: Neural Optical Understanding for Academic Documents<center><h1>")
260
+ gr.HTML("<h3><center>Lukas Blecher et al. <a href='https://arxiv.org/pdf/2308.13418.pdf' target='_blank'>Paper</a>, <a href='https://facebookresearch.github.io/nougat/'>Project</a><center></h3>")
261
+
262
+ with gr.Row():
263
+ mkd = gr.Markdown('<h4><center>Upload a PDF</center></h4>',scale=1)
264
+ mkd = gr.Markdown('<h4><center><i>OR</i></center></h4>',scale=1)
265
+ mkd = gr.Markdown('<h4><center>Provide a PDF link</center></h4>',scale=1)
266
+
267
+ with gr.Row(equal_height=True):
268
+ pdf_file = gr.File(label='PDF📃', file_count='single', scale=1)
269
+ pdf_link = gr.Textbox(placeholder='Enter an Arxiv link here', label='PDF link🔗🌐', scale=1)
270
+
271
+ with gr.Row():
272
+ btn = gr.Button('Run NOUGAT🍫')
273
+ clr = gr.Button('Clear🚿')
274
+
275
+ output_headline = gr.Markdown("<h3>PDF converted to markup language through Nougat-OCR👇:</h3>")
276
+ parsed_output = gr.Markdown(elem_id='mkd', value='📃🔤OCR Output')
277
+
278
+ btn.click(predict, [pdf_file, pdf_link], parsed_output )
279
+ print('******* 1 *******')
280
+ clr.click(lambda : (gr.update(value=None),
281
+ gr.update(value=None),
282
+ gr.update(value=None)),
283
+ [],
284
+ [pdf_file, pdf_link, parsed_output]
285
+ )
286
+
287
+ gr.Examples(
288
+ [["./input/test.pdf", ""], [None, "https://arxiv.org/pdf/2308.08316.pdf"]],
289
+ inputs = [pdf_file, pdf_link],
290
+ outputs = parsed_output,
291
+ fn=process_example,
292
+ cache_examples=True,
293
+ label='Click on any Examples below to get Nougat OCR results quickly:'
294
+ )
295
+
296
+ demo.queue()
297
+ demo.launch(debug=True,share=True, server_name="0.0.0.0")
nougat/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """
2
+ Donut
3
+ Copyright (c) 2022-present NAVER Corp.
4
+ MIT License
5
+ Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ """
7
+ from .model import NougatConfig, NougatModel
8
+ from .utils.dataset import NougatDataset
9
+ from ._version import __version__
10
+
11
+ __all__ = [
12
+ "NougatConfig",
13
+ "NougatModel",
14
+ "NougatDataset",
15
+ ]
nougat/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (464 Bytes).
 
nougat/__pycache__/_version.cpython-310.pyc ADDED
Binary file (355 Bytes).
 
nougat/__pycache__/model.cpython-310.pyc ADDED
Binary file (19.9 kB).
 
nougat/__pycache__/postprocessing.cpython-310.pyc ADDED
Binary file (12.6 kB).
 
nougat/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (5.66 kB).
 
nougat/_version.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ __version__ = "0.1.17"
nougat/dataset/__init__.py ADDED
File without changes
nougat/dataset/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes).
 
nougat/dataset/__pycache__/rasterize.cpython-310.pyc ADDED
Binary file (2.82 kB).
 
nougat/dataset/create_index.py ADDED
@@ -0,0 +1,173 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ """
8
+ This script creates an index of all available pages and parses the meta data for all pages into a separate file.
9
+ Optionally TesseractOCR is called for each image.
10
+ """
11
+ import argparse
12
+ import json
13
+ from typing import Dict, List
14
+ import numpy as np
15
+ from pathlib import Path
16
+ import multiprocessing
17
+ from pebble import ProcessPool
18
+ from PIL import Image
19
+ import pytesseract
20
+ import re
21
+ import logging
22
+ from tqdm import tqdm
23
+
24
+
25
+ logging.basicConfig()
26
+ logger = logging.getLogger()
27
+ logger.setLevel(logging.INFO)
28
+
29
+
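+ # pdffigures2 boundaries are given in PDF points (1/72 inch); convert_pt2px maps them to pixels at the target dpi (dpi / 72 scaling)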
30
+ def convert_pt2px(pt, dpi=96):
31
+ if isinstance(pt, list):
32
+ return [round(dpi / 72 * p) for p in pt]
33
+ elif isinstance(pt, dict):
34
+ for k in pt:
35
+ pt[k] = round(dpi / 72 * pt[k])
36
+ return pt
37
+
38
+
39
+ def read_metadata(data: Dict) -> List[List[Dict]]:
40
+ N = data["num_pages"]
41
+ out = [[] for _ in range(N)]
42
+ # pdffigures2 meta data
43
+ if "pdffigures" in data and data["pdffigures"]:
44
+ for item in data["pdffigures"]:
45
+ p = item.pop("page", None)
46
+ if p is None or p >= N:
47
+ continue
48
+ item["source"] = "fig"
49
+ if "regionBoundary" in item:
50
+ item["regionBoundary"] = convert_pt2px(item["regionBoundary"])
51
+ if "captionBoundary" in item:
52
+ item["captionBoundary"] = convert_pt2px(item["captionBoundary"])
53
+ out[p].append(item)
54
+
55
+ return out
56
+
57
+
58
+ def index_paper(directory: Path, args: argparse.Namespace):
59
+ """
60
+ Pack all image-text pairs into a single h5 file and save it at `args.out`
61
+ """
62
+ paper = directory.name
63
+ markdowns = directory.glob("*.mmd")
64
+ meta_file = directory / "meta.json"
65
+ data_samples = []
66
+ if not meta_file.exists():
67
+ return
68
+ # load meta info
69
+ try:
70
+ meta = read_metadata(json.load(meta_file.open("r", encoding="utf-8")))
71
+ except json.JSONDecodeError:
72
+ return
73
+
74
+ for md_path in markdowns:
75
+ image = md_path.parent / (md_path.stem + ".png")
76
+ i = int(image.stem) - 1
77
+ if not image.exists():
78
+ continue
79
+ if i >= len(meta):
80
+ continue
81
+ data_sample = {}
82
+ ocr_path = image.parent / (image.stem + "_OCR.txt")
83
+ if args.tesseract and not ocr_path.exists():
84
+ try:
85
+ pil = Image.open(image)
86
+ ocr = pytesseract.image_to_string(pil, lang="eng", timeout=2)
87
+ ocr = re.sub(r"\n+\s+?([^\s])", r"\n\n\1", ocr).strip()
88
+ with ocr_path.open("w", encoding="utf-8") as f_ocr:
89
+ f_ocr.write(ocr)
90
+ except RuntimeError:
91
+ logger.info("Page %s of paper %s timed out", image.stem, paper)
92
+ pass
93
+ if ocr_path.exists():
94
+ data_sample["ocr"] = str(ocr_path.relative_to(args.root))
95
+ data_sample["image"] = str(image.relative_to(args.root))
96
+ data_sample["markdown"] = md_path.read_text(encoding="utf8").strip()
97
+ data_sample["meta"] = meta[i]
98
+ data_samples.append(data_sample)
99
+ return data_samples
100
+
101
+
102
+ def create_index(args):
103
+ if not args.dir.exists() or not args.dir.is_dir():
104
+ logger.error("%s does not exist or is no dir.", args.dir)
105
+ return
106
+ papers = []
107
+ depth = 0
108
+ p = args.dir
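+ # walk down the first entry at each level until a file is found, to determine how deep the paper directories are nested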
109
+ while True:
110
+ p = next(p.iterdir())
111
+ if p.is_file():
112
+ break
113
+ else:
114
+ depth += 1
115
+ papers = args.dir.glob("*/" * depth)
116
+ index = []
117
+ with ProcessPool(max_workers=args.workers) as pool:
118
+ tasks = {}
119
+ for j, paper in enumerate(papers):
120
+ fname = paper.name
121
+ tasks[fname] = pool.schedule(
122
+ index_paper,
123
+ args=[paper, args],
124
+ timeout=args.timeout,
125
+ )
126
+
127
+ for fname in tqdm(tasks):
128
+ try:
129
+ res = tasks[fname].result()
130
+ if res is None:
131
+ logger.info("%s is faulty", fname)
132
+ continue
133
+ index.append(res)
134
+ except TimeoutError:
135
+ logger.info("%s timed out", fname)
136
+
137
+ with args.out.open("w", encoding="utf-8") as f:
138
+ for item in index:
139
+ for page in item:
140
+ if len(page) == 0:
141
+ continue
142
+ f.write(json.dumps(page) + "\n")
143
+
144
+
145
+ if __name__ == "__main__":
146
+ parser = argparse.ArgumentParser()
147
+ parser.add_argument("--out", type=Path, required=True, help="Index file")
148
+ parser.add_argument(
149
+ "--dir", type=Path, required=True, help="Parent directory for input dirs"
150
+ )
151
+ parser.add_argument("--root", type=Path, default=None)
152
+ parser.add_argument(
153
+ "--tesseract",
154
+ action="store_true",
155
+ help="Tesseract OCR prediction for each page",
156
+ )
157
+ parser.add_argument(
158
+ "--workers",
159
+ type=int,
160
+ default=multiprocessing.cpu_count(),
161
+ help="How many processes to use",
162
+ )
163
+ parser.add_argument(
164
+ "--dpi", type=int, default=96, help="DPI the images were saved with"
165
+ )
166
+ parser.add_argument("--timeout", type=int, default=240, help="Max time per paper")
167
+ args = parser.parse_args()
168
+ if args.root is None:
169
+ args.root = args.dir
170
+ else:
171
+ # check if dir is subdir of root
172
+ args.dir.relative_to(args.root)
173
+ create_index(args)
nougat/dataset/gen_seek.py ADDED
@@ -0,0 +1,36 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ from tqdm import tqdm
8
+ import json
9
+ from pathlib import Path
10
+ import argparse
11
+
12
+
13
+ def get_args():
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument("src_file", nargs="+", type=Path, help="JSONL file in question")
16
+ args = parser.parse_args()
17
+ return args
18
+
19
+
20
+ if __name__ == "__main__":
21
+ args = get_args()
22
+ for file in args.src_file:
23
+ seek_map = []
24
+ seek_pos = 0
25
+ with open(file) as f:
26
+ with tqdm(smoothing=0.0) as pbar:
27
+ line = f.readline()
28
+ while line:
29
+ seek_map.append(seek_pos)
30
+ seek_pos = f.tell()
31
+ line = f.readline()
32
+ pbar.update(1)
33
+
34
+ out_file = file.parent / (file.stem + ".seek.map")
35
+ with open(out_file, "w") as f:
36
+ f.write(json.dumps(seek_map))
nougat/dataset/parser/__init__.py ADDED
File without changes
nougat/dataset/parser/document.py ADDED
@@ -0,0 +1,703 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ from collections import defaultdict
8
+ from copy import copy
9
+ import itertools
10
+ import re
11
+ from dataclasses import dataclass, field, asdict
12
+ from typing import (
13
+ Any,
14
+ List,
15
+ Dict,
16
+ Optional,
17
+ TypeVar,
18
+ Type,
19
+ Generic,
20
+ )
21
+ import numpy as np
22
+
23
+ import logging
24
+
25
+ logger = logging.getLogger()
26
+
27
+ from dataclasses import dataclass, field, asdict
28
+ from typing import List, Dict, TypeVar, Type, Generic
29
+
30
+ T = TypeVar("T")
31
+ EL = TypeVar("EL")
32
+
33
+
34
+ @dataclass
35
+ class Element(Generic[EL]):
36
+ """
37
+ Generic class representing an element with children in a tree-like structure.
38
+
39
+ Attributes:
40
+ parent (Element): The parent element.
41
+ children (List[Element]): List of child elements.
42
+ """
43
+
44
+ parent: "Element" = None
45
+ children: List["Element"] = field(default_factory=list)
46
+
47
+ @property
48
+ def plaintext(self):
49
+ return "".join([child.plaintext for child in self.children])
50
+
51
+ def append(self, child: EL) -> EL:
52
+ self.children.append(child)
53
+ child.parent = self
54
+ return child
55
+
56
+ def find_parent(self, class_or_tuple: Type[T]) -> T:
57
+ elem = self
58
+ while elem:
59
+ if isinstance(elem, class_or_tuple):
60
+ return elem
61
+ elem = elem.parent
62
+ return None
63
+
64
+
65
+ @dataclass
66
+ class UnknownElement(Element):
67
+ pass
68
+
69
+
70
+ @dataclass
71
+ class TextElement(Element):
72
+ content: str = ""
73
+
74
+ @property
75
+ def plaintext(self):
76
+ return self.content
77
+
78
+ def append(self, child: "Element"):
79
+ raise Exception(f"Cannot append elements to {self.__class__.__name__}")
80
+
81
+
82
+ @dataclass
83
+ class Math(Element):
84
+ pass
85
+
86
+
87
+ @dataclass
88
+ class PlaintextMath(Math):
89
+ pass
90
+
91
+
92
+ @dataclass
93
+ class LatexMath(Math):
94
+ inline: bool = True
95
+ code: str = ""
96
+
97
+ @property
98
+ def plaintext(self):
99
+ return self.code
100
+
101
+
102
+ @dataclass
103
+ class Author:
104
+ fullname: str = None
105
+ lastname: str = None
106
+ affiliation: str = None
107
+
108
+
109
+ @dataclass
110
+ class Link(Element):
111
+ target: str = None
112
+
113
+
114
+ @dataclass
115
+ class InlineRef(Element):
116
+ target: str = None
117
+
118
+ def as_dict(self):
119
+ return {
120
+ "target": self.target,
121
+ }
122
+
123
+
124
+ @dataclass
125
+ class Reference:
126
+ """
127
+ Data class representing a reference with various attributes.
128
+
129
+ Attributes:
130
+ title (Element): The title of the reference.
131
+ authors (List[Author]): List of authors of the reference.
132
+ ids (Dict[str, str]): Dictionary of identification information.
133
+ date (str): The publication date of the reference.
134
+ url (str): The URL link to the reference.
135
+ journal (str): The journal where the reference is published.
136
+ full_text (str): The full text content of the reference.
137
+
138
+ Methods:
139
+ as_dict(): Convert the reference object to a dictionary.
140
+ """
141
+
142
+ title: Element = None
143
+ authors: List[Author] = field(default_factory=list)
144
+ ids: Dict[str, str] = field(default_factory=dict)
145
+ date: str = None
146
+ url: str = None
147
+ journal: str = None
148
+ full_text: str = None
149
+
150
+ def as_dict(self):
151
+ return {
152
+ "title": self.title.plaintext,
153
+ "authors": [asdict(auth) for auth in self.authors],
154
+ "ids": self.ids,
155
+ "date": self.date,
156
+ "url": self.url,
157
+ "journal": self.journal,
158
+ "full_text": self.full_text,
159
+ }
160
+
161
+
162
+ @dataclass
163
+ class SpanElement(Element):
164
+ pass
165
+
166
+
167
+ @dataclass
168
+ class Italic(SpanElement):
169
+ pass
170
+
171
+
172
+ @dataclass
173
+ class Bold(SpanElement):
174
+ pass
175
+
176
+
177
+ @dataclass
178
+ class Superscript(SpanElement):
179
+ pass
180
+
181
+
182
+ @dataclass
183
+ class Subscript(SpanElement):
184
+ pass
185
+
186
+
187
+ @dataclass
188
+ class Paragraph(Element):
189
+ pass
190
+
191
+
192
+ @dataclass
193
+ class TableRow(Element):
194
+ cells: List[Element] = field(default_factory=list)
195
+
196
+ def add_cell(self, cell: Element):
197
+ self.cells.append(cell)
198
+ cell.parent = self
199
+ return cell
200
+
201
+ @property
202
+ def plaintext(self):
203
+ return "\t".join([cell.plaintext for cell in self.cells])
204
+
205
+
206
+ @dataclass
207
+ class TableHead(TableRow):
208
+ pass
209
+
210
+
211
+ @dataclass
212
+ class Table(Element):
213
+ id: str = None
214
+ header: Element = None
215
+ caption: Element = None
216
+ rows: List[TableRow] = field(default_factory=list)
217
+ keep_table: bool = False
218
+
219
+ def add_row(self, row: TableRow) -> TableRow:
220
+ self.rows.append(row)
221
+ row.parent = self
222
+ return row
223
+
224
+ @property
225
+ def plaintext(self):
226
+ return "\n".join([row.plaintext for row in self.rows])
227
+
228
+
229
+ @dataclass
230
+ class Equation(Element):
231
+ pass
232
+
233
+
234
+ @dataclass
235
+ class EquationList(Element):
236
+ equations: List[Equation] = field(default_factory=list)
237
+
238
+ def add_equation(self, eqn: Equation) -> Equation:
239
+ self.equations.append(eqn)
240
+ eqn.parent = self
241
+ return eqn
242
+
243
+ @property
244
+ def plaintext(self):
245
+ return "\n".join([eqn.plaintext for eqn in self.equations])
246
+
247
+
248
+ @dataclass
249
+ class Algorithm(Element):
250
+ caption: Element = None
251
+ lines: List[Element] = field(default_factory=list)
252
+ inline: bool = False
253
+
254
+ def add_line(self, line: Element) -> Element:
255
+ self.lines.append(line)
256
+ line.parent = self
257
+ return line
258
+
259
+ @property
260
+ def plaintext(self):
261
+ return "\n".join([line.plaintext for line in self.lines])
262
+
263
+
264
+ @dataclass
265
+ class Definition(Element):
266
+ term: Element = None
267
+ definition: Element = None
268
+
269
+ @property
270
+ def plaintext(self):
271
+ parts = []
272
+ if self.term:
273
+ parts.append(f"{self.term.plaintext}:")
274
+ if self.definition:
275
+ parts.append(self.definition.plaintext)
276
+ return " ".join(parts)
277
+
278
+
279
+ @dataclass
280
+ class DefinitionList(Element):
281
+ """
282
+ Data class representing a list of definitions with an optional header.
283
+
284
+ Attributes:
285
+ header (Element): The header element for the definition list.
286
+ items (List[Definition]): List of Definition elements.
287
+
288
+ Methods:
289
+ add_item(item: Definition) -> Definition: Add a definition item to the list.
290
+ """
291
+
292
+ header: Element = None
293
+ items: List[Element] = field(default_factory=list)
294
+
295
+ def add_item(self, item: Definition) -> Definition:
296
+ self.items.append(item)
297
+ item.parent = self
298
+ return item
299
+
300
+ @property
301
+ def plaintext(self):
302
+ parts = []
303
+ if self.header:
304
+ parts.append(self.header.plaintext)
305
+ parts.extend([df.plaintext for df in self.items])
306
+ return "\n".join(parts)
307
+
308
+
309
+ @dataclass
310
+ class Figure(Element):
311
+ id: str = None
312
+ header: Element = None
313
+ caption: Element = None
314
+
315
+
316
+ @dataclass
317
+ class Section(Element):
318
+ id: str = None
319
+ header: Element = None
320
+ level: int = 0
321
+ hnum: int = 1
322
+
323
+
324
+ @dataclass
325
+ class SectionHeader(Element):
326
+ id: str = None
327
+ header: Element = None
328
+ level: int = 0
329
+
330
+
331
+ @dataclass
332
+ class ListItem(Element):
333
+ label: str = ""
334
+
335
+
336
+ @dataclass
337
+ class ListContainer(Element):
338
+ level: int = 0
339
+ ordered: bool = False
340
+ items: List[Element] = field(default_factory=list)
341
+
342
+ def add_item(self, item: ListItem) -> ListItem:
343
+ self.items.append(item)
344
+ item.parent = self
345
+ return item
346
+
347
+ @property
348
+ def plaintext(self):
349
+ return "\n".join([item.plaintext for item in self.items])
350
+
351
+
352
+ @dataclass
353
+ class Footnote(Element):
354
+ id: str = None
355
+
356
+
357
+ @dataclass
358
+ class Document(Element, Reference):
359
+ abstract: Element = None
360
+ language: str = None
361
+ keywords: List[Element] = field(default_factory=list)
362
+ references: List[Reference] = field(default_factory=list)
363
+ inline_refs: List[InlineRef] = field(default_factory=list)
364
+ bib: Reference = None
365
+
366
+ def add_reference(self, reference):
367
+ self.references.append(reference)
368
+
369
+ def add_inline_ref(self, in_ref):
370
+ self.inline_refs.append(in_ref)
371
+
372
+ def set_bib(self, reference):
373
+ self.bib = reference
374
+
375
+
376
+ @dataclass
377
+ class Spec:
378
+ """
379
+ Data class representing specifications for table cells.
380
+
381
+ Attributes:
382
+ t (int): The top border size.
383
+ b (int): The bottom border size.
384
+ l (int): The left border size.
385
+ r (int): The right border size.
386
+ align (str): The alignment of the cell content ('c' for center, 'l' for left, 'r' for right,
387
+ or 'p{width}' for justified with a specified width).
388
+
389
+ Methods:
390
+ __hash__() -> int: Compute the hash of the specification.
391
+ __eq__(__o: object) -> bool: Check if two specifications are equal.
392
+ set_align(classes: List[str], style: Optional[str] = None) -> None:
393
+ Extract alignment information from HTML classes.
394
+ set_border(classes: List[str]) -> None: Automatically set border specifications.
395
+ set_attrs(attrs: Dict[str, Any]) -> None: Automatically set all attributes from HTML class attributes.
396
+ __str__() -> str: Get the string representation of the specification.
397
+ """
398
+
399
+ t: int = field(default=0, repr=False)
400
+ b: int = field(default=0, repr=False)
401
+ l: int = field(default=0)
402
+ r: int = field(default=0)
403
+ align: str = field(default="")
404
+
405
+ def __hash__(self) -> int:
406
+ return hash(repr(self))
407
+
408
+ def __eq__(self, __o: object) -> bool:
409
+ return repr(self) == repr(__o)
410
+
411
+ def set_align(self, classes: List[str], style: Optional[str] = None) -> None:
412
+ """extract alignment information from available classes (html)"""
413
+ aligns = [s for s in classes if "align" in s]
414
+ if len(aligns) == 0:
415
+ return
416
+ elif len(aligns) > 1:
417
+ logger.warn("Found multiple aligns in classes: %s", ", ".join(classes))
418
+ align = aligns[0]
419
+ if "center" in align or align == "c":
420
+ self.align = "c"
421
+ elif "left" in align or align == "l":
422
+ self.align = "l"
423
+ elif "right" in align or align == "r":
424
+ self.align = "r"
425
+ elif "justify" in align or align == "p":
426
+ # assert style is not None, "justify without style information"
427
+ if style is None:
428
+ self.align = "c"
429
+ else:
430
+ width = style.partition("width:")[2].partition(";")[0]
431
+ self.align = "p{%s}" % width
432
+ else:
433
+ logger.warn(
434
+ "only center, left, right, justify supported at the moment. Found %s",
435
+ align,
436
+ )
437
+ self.align = "c"
438
+
439
+ def set_border(self, classes: List[str]) -> None:
440
+ """automatically set spec with border classes e.g 'ltx_border_t'"""
441
+ for border in classes:
442
+ orientation = border.partition("border_")[2]
443
+ if len(orientation) > 0 and orientation[0] in "tbrl":
444
+ setattr(self, orientation[0], len(orientation))
445
+
446
+ def set_attrs(self, attrs: Dict[str, Any]) -> None:
447
+ """automatically set all attr from html class attributes"""
448
+ classes = attrs["class"]
449
+ style = attrs["style"] if "style" in attrs else None
450
+
451
+ self.set_align(classes, style=style)
452
+ self.set_border(classes)
453
+
454
+ def __str__(self) -> str:
455
+ if self.align:
456
+ return "|" * self.l + self.align + "|" * self.r
457
+ else:
458
+ # default center
459
+ return "|" * self.l + "c" + "|" * self.r
460
+
461
+
462
+ @dataclass
463
+ class TableCell(Element):
464
+ """
465
+ Represents a cell in an HTML table.
466
+
467
+ Attributes:
468
+ multicolumn (Optional[int]): The number of columns spanned by the cell.
469
+ multirow (Optional[int]): The number of rows spanned by the cell.
470
+ spec (Spec): The specification for the cell's formatting.
471
+ content (Element): The content of the cell.
472
+
473
+ Methods:
474
+ __post_init__(*args, **kwargs) -> None: Initialize the cell, ensuring that the spec property is not None.
475
+ __hash__() -> int: Compute the hash of the cell.
476
+ __eq__(__o: object) -> bool: Check if two cells are equal.
477
+ set_attrs(attrs: Dict[str, Any]) -> None: Set attributes for the cell from HTML attributes.
478
+ plaintext() -> str: Get the plaintext content of the cell.
479
+ """
480
+
481
+ multicolumn: Optional[int] = None
482
+ multirow: Optional[int] = None
483
+ spec: Spec = None
484
+ content: Element = None
485
+
486
+ def __post_init__(self, *args, **kwargs) -> None:
487
+ # spec property cannot be None
488
+ if self.spec is None:
489
+ self.spec = Spec()
490
+
491
+ def __hash__(self) -> int:
492
+ return hash(repr(self))
493
+
494
+ def __eq__(self, __o: object) -> bool:
495
+ return repr(self) == repr(__o)
496
+
497
+ def set_attrs(self, attrs: Dict[str, Any]) -> None:
498
+ if "colspan" in attrs:
499
+ self.multicolumn = int(attrs["colspan"])
500
+ if "rowspan" in attrs:
501
+ self.multirow = int(attrs["rowspan"])
502
+ self.spec.set_attrs(attrs)
503
+
504
+ @property
505
+ def plaintext(self):
506
+ if self.content is None:
507
+ return ""
508
+ return self.content.plaintext
509
+
510
+
511
+ @dataclass
512
+ class TableRow(Element):
513
+ """
514
+ Represents a row in an HTML table.
515
+
516
+ Attributes:
517
+ cells (List[TableCell]): The list of cells in the row.
518
+
519
+ Methods:
520
+ add_cell(cell: TableCell) -> TableCell: Add a cell to the row.
521
+ __iter__() -> Iterator: Iterate through the cells in the row.
522
+ __len__() -> int: Get the number of cells in the row.
523
+ __bool__() -> bool: Check if the row is not empty.
524
+ cum_cell_widths() -> List[int]: Get the cumulative cell widths.
525
+ cell_widths() -> List[int]: Get the widths of individual cells.
526
+ width() -> int: Get the total width of the row.
527
+ _hline(orientation: str) -> str: Determine horizontal lines to be inserted.
528
+ hline_above() -> str: Get the horizontal line description for the top of the row.
529
+ hline_below() -> str: Get the horizontal line description for the bottom of the row.
530
+ plaintext() -> str: Get the plaintext content of the row.
531
+ """
532
+
533
+ cells: List[TableCell] = field(default_factory=list)
534
+
535
+ def add_cell(self, cell: TableCell):
536
+ self.cells.append(cell)
537
+ cell.parent = self
538
+ return cell
539
+
540
+ def __iter__(self):
541
+ return iter(self.cells)
542
+
543
+ def __len__(self) -> int:
544
+ return len(self.cells)
545
+
546
+ def __bool__(self) -> bool:
547
+ return True
548
+
549
+ @property
550
+ def cum_cell_widths(self) -> List[int]:
551
+ return np.cumsum(self.cell_widths)
552
+
553
+ @property
554
+ def cell_widths(self) -> List[int]:
555
+ return [(cell.multicolumn or 1) for cell in self.cells]
556
+
557
+ @property
558
+ def width(self) -> int:
559
+ return sum(self.cell_widths)
560
+
561
+ def _hline(self, orientation: str) -> str:
562
+ """Figure out if and where horizontal lines need to be inserted.
563
+
564
+ Args:
565
+ orientation (str): Either 't' (top) or 'b' (bottom)
566
+
567
+ Returns:
568
+ str: Correct vertical line description for latex tables.
569
+ """
570
+ assert orientation == "t" or orientation == "b"
571
+ lines = []
572
+ for cell in self.cells:
573
+ lines.extend([getattr(cell.spec, orientation)] * (cell.multicolumn or 1))
574
+ lines.append(0)
575
+ indices = []
576
+ start = None
577
+ for i, v in enumerate(lines):
578
+ if v and start is None:
579
+ start = i
580
+ elif start is not None and not v:
581
+ indices.append((start, i - 1))
582
+ start = None
583
+ s = ""
584
+ for a, b in indices:
585
+ if b - a + 1 == self.width:
586
+ s += "\\hline " * lines[0]
587
+ else:
588
+ s += "\\cline{%i-%i} " % (a + 1, b + 1)
589
+ return s.strip()
590
+
591
+ @property
592
+ def hline_above(self) -> str:
593
+ return self._hline("t")
594
+
595
+ @property
596
+ def hline_below(self) -> str:
597
+ return self._hline("b")
598
+
599
+ @property
600
+ def plaintext(self) -> str:
601
+ return "\t".join([cell.plaintext for cell in self.cells])
602
+
603
+
604
+ @dataclass
605
+ class Tabular(Element):
606
+ rows: List[TableRow] = field(default_factory=list)
607
+ """
608
+ Represents a tabular structure, such as an HTML table.
609
+
610
+ Attributes:
611
+ rows (List[TableRow]): The list of rows in the tabular structure.
612
+
613
+ Methods:
614
+ add_row(row: TableRow) -> TableRow: Add a row to the tabular structure.
615
+ width() -> int: Get the maximum width of the tabular structure.
616
+ cols() -> List[List[TableCell]]: Get a list of columns in the tabular structure.
617
+ _square_table() -> None: Ensure the table has an equal number of columns in each row.
618
+ get_table_spec() -> str: Generate a LaTeX table specification based on cell alignments.
619
+ plaintext() -> str: Get the plaintext content of the tabular structure.
620
+ """
621
+
622
+ def add_row(self, row: TableRow) -> TableRow:
623
+ self.rows.append(row)
624
+ row.parent = self
625
+ return row
626
+
627
+ @property
628
+ def width(self) -> int:
629
+ if len(self.rows) > 0:
630
+ return max([r.width for r in self.rows])
631
+ else:
632
+ return 0
633
+
634
+ @property
635
+ def cols(self) -> List[List[TableCell]]:
636
+ return list(
637
+ map(
638
+ list,
639
+ itertools.zip_longest(*[r.cells for r in self.rows], fillvalue=None),
640
+ )
641
+ )
642
+
643
+ def _square_table(self) -> None:
644
+ """check if number of columns is equal for every row. Add placeholders for `\multirow` instances"""
645
+ for i, row in enumerate(self.rows):
646
+ for j, cell in enumerate(row.cells):
647
+ if cell.multirow is not None and cell.multirow > 1:
648
+ spec = copy(cell.spec)
649
+ # assume no hlines in multi cells: disable bottom lines for top and top lines for lower cells.
650
+ spec.t = 0
651
+ cell.spec.b = 0
652
+ for k in range(i + 1, i + cell.multirow):
653
+ if k < len(self.rows):
654
+ for _ in range(row.cell_widths[j]):
655
+ # add empty cell
656
+ self.rows[k].cells.insert(
657
+ j, TableCell(parent=self.rows[k], spec=spec)
658
+ )
659
+
660
+ def get_table_spec(self) -> str:
661
+ """Generates a LaTeX table spec."""
662
+ # First make table square
663
+ self._square_table()
664
+ # Find the most used spec in regular cells (no multi-col/row)
665
+ specs = [Spec() for _ in range(self.width)]
666
+ for i, col in enumerate(self.cols):
667
+ counts = defaultdict(int)
668
+ for cell in col:
669
+ if cell is None or cell.spec.align == "":
670
+ continue
671
+ if cell.multicolumn is None and cell.multirow is None:
672
+ counts[cell.spec] += 1
673
+ if len(counts) > 0:
674
+ specs[i] = max(counts, key=counts.get)
675
+ # convert all cells that don't match the column style into a multicol{1}{custom_spec}
676
+ for i, col in enumerate(self.cols):
677
+ for cell in col:
678
+ if cell is not None and cell.spec != specs[i]:
679
+ # check if there is text in the cell. If not alignment doesn't matter
680
+ if (
681
+ len(cell.children) == 0
682
+ and cell.spec.l == specs[i].l
683
+ and cell.spec.r == specs[i].r
684
+ ):
685
+ continue
686
+ # convert any standard cell into a multicol cell of width 1
687
+ if cell.multicolumn is None:
688
+ cell.multicolumn = 1
689
+ # generate final latex table spec
690
+ out = " ".join([str(spec) for spec in specs])
691
+ out = re.sub(r"(\|) +(\w)", r"\1\2", out)
692
+ out = re.sub(r"(\w) +(\|)", r"\1\2", out)
693
+ return out
694
+
695
+ @property
696
+ def plaintext(self):
697
+ return "\n".join([row.plaintext for row in self.rows])
698
+
699
+
700
+ @dataclass
701
+ class Table(Element):
702
+ id: str = None
703
+ caption: Element = None
nougat/dataset/parser/html2md.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import argparse
8
+ from pathlib import Path
9
+ from typing import List, Optional
10
+ from bs4 import BeautifulSoup
11
+ from tqdm import tqdm
12
+ import htmlmin
13
+ from nougat.dataset.parser.latexml_parser import parse_latexml, _clean_html_whitespace
14
+ from nougat.dataset.parser.markdown import format_document
15
+
16
+
17
+ def check_file_path(paths: List[Path], wdir: Optional[Path] = None) -> List[str]:
18
+ """
19
+ Checks if the given file paths exist.
20
+
21
+ Args:
22
+ paths: A list of file paths.
23
+ wdir: The working directory. If None, the current working directory is used.
24
+
25
+ Returns:
26
+ A list of file paths that exist.
27
+ """
28
+ files = []
29
+ for path in paths:
30
+ if type(path) == str:
31
+ if path == "":
32
+ continue
33
+ path = Path(path)
34
+ pathsi = [path] if wdir is None else [path, wdir / path]
35
+ for p in pathsi:
36
+ if p.exists():
37
+ files.append((p.resolve()))
38
+ elif "*" in path.name:
39
+ files.extend([(pi.resolve()) for pi in p.parent.glob(p.name)])
40
+ return list(set(files))
41
+
42
+
43
+ if __name__ == "__main__":
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument("--html", type=Path, nargs="+", help="HTML file", required=True)
46
+ parser.add_argument("--out", type=Path, help="Output file", required=True)
47
+ args = parser.parse_args()
48
+ args.html = check_file_path(args.html)
49
+ for f in tqdm(args.html):
50
+ html = BeautifulSoup(
51
+ htmlmin.minify(
52
+ open(f, "r", encoding="utf-8").read().replace("\xa0", " "),
53
+ remove_all_empty_space=1,
54
+ ),
55
+ features="html.parser",
56
+ )
57
+ try:
58
+ doc = parse_latexml(html)
59
+ except ValueError as e:
60
+ print(e)
61
+ continue
62
+ if doc is None:
63
+ continue
64
+ out, fig = format_document(doc, keep_refs=True)
65
+ outp = (args.out if args.out.is_dir() else args.out.parent) / (f.stem + ".mmd")
66
+ with open(outp, "w", encoding="utf-8") as f:
67
+ f.write(out)
nougat/dataset/parser/latexml_parser.py ADDED
@@ -0,0 +1,441 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import re
8
+ import sys
9
+ import requests
10
+ from typing import Optional, Set
11
+ from bs4 import BeautifulSoup, NavigableString
12
+ import soupsieve as sv
13
+
14
+ from nougat.dataset.parser.document import *
15
+
16
+
17
+ def printerr(*args, **kwargs):
18
+ # uncomment for debugging
19
+ # print(*args, **kwargs)
20
+ pass
21
+
22
+
23
+ latexml_wrapper_selector = sv.compile(
24
+ ", ".join(
25
+ [
26
+ ".ltx_engrafo_equation_container",
27
+ "tbody",
28
+ ".ltx_note_content",
29
+ ".ltx_role_footnote",
30
+ ".ltx_note_type",
31
+ ".ltx_theorem",
32
+ ".ltx_proof",
33
+ ".ltx_quote",
34
+ "blockquote",
35
+ ".ltx_inline-para",
36
+ ".ltx_inline-block",
37
+ ]
38
+ )
39
+ )
40
+ latexml_ignore_selector = sv.compile(".ltx_rule, .ltx_pagination.ltx_role_newpage")
41
+
42
+
43
+ def is_wrapper_element(element: BeautifulSoup) -> bool:
44
+ return latexml_wrapper_selector.match(element)
45
+
46
+
47
+ def ignore_element(element: BeautifulSoup) -> bool:
48
+ return latexml_ignore_selector.match(element)
49
+
50
+
51
+ def _get_classes(el: BeautifulSoup) -> Set[str]:
52
+ if not hasattr(el, "attrs"):
53
+ return set()
54
+ classes = el.attrs.get("class")
55
+ if classes is None:
56
+ return set()
57
+ return set(classes)
58
+
59
+
60
+ def _detach_selected(element: BeautifulSoup, selector: str) -> None:
61
+ for elem in element.select(selector):
62
+ elem.extract()
63
+
64
+
65
+ def parse_latexml_authors(ltx_authors: BeautifulSoup) -> List[Author]:
66
+ authors = Paragraph()
67
+ parse_latexml_children(ltx_authors, authors)
68
+ return authors
69
+
70
+
71
+ def parse_latexml_citations(cite: BeautifulSoup, parent: Element) -> None:
72
+ """
73
+ Parses LaTeXML citations and appends them as children to the given parent element.
74
+
75
+ Args:
76
+ cite (BeautifulSoup): The BeautifulSoup object containing the citation data.
77
+ parent (Element): The parent element to which the citations will be added as children.
78
+ """
79
+ parse_latexml_children(cite, parent)
80
+ if ("[" in parent.plaintext and "]" in parent.plaintext) or re.search(
81
+ r"[A-Za-z]", parent.plaintext
82
+ ):
83
+ return
84
+
85
+ parent.children.insert(0, TextElement(content="["))
86
+ parent.children.append(TextElement(content="]"))
87
+
88
+
89
+ def _clean_html_whitespace(text: str) -> str:
90
+ if text.strip():
91
+ text = re.sub(r"(^\n+|\n+$)", "\n", text)
92
+ else:
93
+ text = text.strip("\n")
94
+ text = re.sub(r"[ \t]+", " ", text)
95
+ return text
96
+
97
+
98
+ def parse_latexml_children(html: BeautifulSoup, parent: Element) -> None:
99
+ """
100
+ Parses LaTeXML children and appends them as appropriate elements to the given parent element.
101
+
102
+ Args:
103
+ html (BeautifulSoup): The BeautifulSoup object containing the HTML data.
104
+ parent (Element): The parent element to which the parsed children will be added.
105
+ """
106
+ if html is None:
107
+ return
108
+ for child in html.children:
109
+ classes = _get_classes(child)
110
+ if isinstance(child, NavigableString):
111
+ parent.append(TextElement(content=_clean_html_whitespace(str(child))))
112
+ elif sv.match(
113
+ "p, .ltx_p, div.ltx_para, span.ltx_para, section.ltx_paragraph", child
114
+ ):
115
+ paragraph = parent.append(Paragraph())
116
+ parse_latexml_children(child, paragraph)
117
+ elif sv.match(".ltx_tag", child):
118
+ if "ltx_tag_note" not in classes:
119
+ if sv.match(".ltx_tag_section", child):
120
+ child.string = child.string.upper()
121
+ elif sv.match(".ltx_tag_subsection", child):
122
+ child.string = ""
123
+ parse_latexml_children(child, parent)
124
+ elif "ltx_tag_bibitem" in classes:
125
+ parse_latexml_children(child, parent.append(SpanElement()))
126
+ elif sv.match(".ltx_note_outer", child):
127
+ # try to place the footnote outside the current paragraph
128
+ paragraph = parent.find_parent(Paragraph)
129
+ if paragraph is not None and paragraph.parent is not None:
130
+ footnote = paragraph.parent.append(Footnote())
131
+ else:
132
+ footnote = parent.append(Footnote())
133
+ parse_latexml_children(child, footnote)
134
+ elif sv.match(".ltx_note_content > .ltx_note_mark", child):
135
+ footnote = parent.find_parent(Footnote)
136
+ if footnote is not None:
137
+ footnote.id = child.get_text(strip=True)
138
+ else:
139
+ printerr("Unable to find footnote to set its id", file=sys.stderr)
140
+ parse_latexml_children(child, parent)
141
+ elif sv.match("sup", child):
142
+ sup = parent.append(Superscript())
143
+ parse_latexml_children(child, sup)
144
+ elif sv.match("sub", child):
145
+ sub = parent.append(Subscript())
146
+ parse_latexml_children(child, sub)
147
+ elif sv.match("span.ltx_Math, span.ltx_DisplayMath", child):
148
+ inline = "ltx_DisplayMath" not in classes
149
+ math_elem = child.select_one(".mjx-math")
150
+ if math_elem:
151
+ tex = math_elem.attrs["aria-label"]
152
+ if inline:
153
+ tex = rf"\({tex}\)"
154
+ else:
155
+ tex = rf"\[{tex}\]"
156
+ parent.append(LatexMath(code=tex, inline=inline))
157
+ elif sv.match("math.ltx_Math", child):
158
+ # not sure if the math tag LaTeXML version specific, but that seems to work
159
+ inline = True
160
+ if "display" in child.attrs:
161
+ inline = child.attrs["display"] == "inline"
162
+ tex = child.attrs["alttext"]
163
+ if inline:
164
+ tex = rf"\({tex}\)"
165
+ else:
166
+ tex = rf"\[{tex}\]"
167
+ parent.append(LatexMath(code=tex, inline=inline))
168
+ elif sv.match("a.ref", child):
169
+ link = parent.append(Link())
170
+ link.target = child.attrs.get("href")
171
+ parse_latexml_children(child, link)
172
+ elif sv.match(
173
+ ".ltx_ref.ltx_missing_citation, .ltx_ref.ltx_missing_label", child
174
+ ):
175
+ placeholder = child.get_text().strip()
176
+ resolved = False
177
+ if placeholder.isnumeric():
178
+ parent.append(TextElement(content=placeholder))
179
+ resolved = True
180
+ else:
181
+ target = child.attrs.get("href")
182
+ if target is not None:
183
+ potential_num = target.partition(".bib")[2]
184
+ if potential_num.isnumeric():
185
+ parent.append(TextElement(content=potential_num))
186
+ resolved = True
187
+ if not resolved:
188
+ raise ValueError("missing reference detected")
189
+ elif sv.match(
190
+ ".ltx_bibblock, .ltx_role_author, .ltx_contact, .ltx_role_email, .ltx_role_affiliation",
191
+ child,
192
+ ):
193
+ parse_latexml_children(child, parent.append(SpanElement()))
194
+ parent.append(TextElement(content="\n"))
195
+ elif sv.match(
196
+ ".ltx_authors, .ltx_personname, .ltx_role_creation.ltx_date, .ltx_engrafo_author_notes, .ltx_author_notes, .ltx_date.ltx_role_creation",
197
+ child,
198
+ ):
199
+ parse_latexml_children(child, parent.append(Paragraph()))
200
+ parent.append(TextElement(content="\n"))
201
+ elif sv.match(
202
+ ".ltx_author_before, .ltx_role_pubyear, .ltx_role_pagerange", child
203
+ ):
204
+ pass
205
+ elif sv.match("h1.ltx_title_document", child):
206
+ doc = parent.find_parent(Document)
207
+ if doc is not None:
208
+ if doc.title is None:
209
+ doc.title = SectionHeader(parent=doc)
210
+ doc.title.hnum = int(child.name[1])
211
+ parse_latexml_children(child, doc.title)
212
+ else:
213
+ printerr("Document title is already set", file=sys.stderr)
214
+ else:
215
+ printerr("Unable to find document to set title", file=sys.stderr)
216
+ elif sv.match("section", child):
217
+ if ".ltx_bibliography" not in classes:
218
+ section = parent.append(Section())
219
+ parse_latexml_children(child, section)
220
+ elif sv.match("h1, h2, h3, h4, h5, h6", child) and "ltx_title" in classes:
221
+ if {"ltx_title_theorem", "ltx_title_proof"} & classes:
222
+ parse_latexml_children(child, parent)
223
+ parent.append(TextElement(content=": "))
224
+ elif isinstance(parent, Section):
225
+ parent.hnum = int(child.name[1])
226
+ if parent.header is None:
227
+ parent.header = SpanElement()
228
+ parse_latexml_children(child, parent.header)
229
+ else:
230
+ printerr("Dangling title element", file=sys.stderr)
231
+ parse_latexml_children(child, parent)
232
+ elif sv.match(".ltx_TOC.ltx_toc_toc", child):
233
+ s = parent.append(Section(hnum=6, header=TextElement(content="Contents")))
234
+ parse_latexml_children(child, s.append(Paragraph()))
235
+ elif sv.match(
236
+ "ul.ltx_itemize, ul.ltx_toclist, ul.ltx_biblist, ol.ltx_enumerate", child
237
+ ):
238
+ lst = parent.append(ListContainer())
239
+ lst.ordered = child.name == "ol"
240
+ parent_list = parent.find_parent(ListContainer)
241
+ lst.level = parent_list.level + 1 if parent_list is not None else 1
242
+ parse_latexml_children(child, lst)
243
+ elif sv.match("li.ltx_item, li.ltx_tocentry, li.ltx_bibitem", child):
244
+ lst = parent.find_parent(ListContainer)
245
+ if lst is not None:
246
+ item = lst.add_item(ListItem())
247
+ parse_latexml_children(child, item)
248
+ else:
249
+ printerr("List item outside list", file=sys.stderr)
250
+ elif sv.match("cite", child):
251
+ span = parent.append(SpanElement())
252
+ parse_latexml_citations(child, span)
253
+ elif sv.match("a.ltx_ref", child):
254
+ target = child.attrs.get("href")
255
+ if target.startswith("#bib"): # citation link
256
+ in_ref = parent.append(InlineRef())
257
+ in_ref.target = target
258
+ text = child.get_text()
259
+ in_ref.target = target
260
+ if text.strip().isnumeric():
261
+ in_ref.append(TextElement(content=text))
262
+ elif re.search(r"[A-Za-z][:;.,_]?\d", text):
263
+ # probably a broken citation, go with link number instead
264
+ in_ref.append(
265
+ TextElement(
266
+ content=re.sub(r"\D", "", target.partition(".bib")[2])
267
+ )
268
+ )
269
+ else:
270
+ raise ValueError('unusable reference "%s"' % text)
271
+ doc = parent.find_parent(Document)
272
+ if doc:
273
+ doc.add_inline_ref(in_ref)
274
+ else:
275
+ link = parent.append(Link())
276
+ link.target = target
277
+ parse_latexml_children(child, link)
278
+ elif sv.match("a", child) and len(classes) == 0:
279
+ target = child.attrs.get("href")
280
+ parse_latexml_children(child, parent.append(Link(target=target)))
281
+ elif sv.match(".ltx_eqn_table", child):
282
+ eqn_list = parent.append(EquationList())
283
+ parse_latexml_children(child, eqn_list)
284
+ elif sv.match(".ltx_eqn_row", child):
285
+ eqn_list = parent.find_parent(EquationList)
286
+ if eqn_list is not None:
287
+ eqn = eqn_list.add_equation(Equation())
288
+ parse_latexml_children(child, eqn)
289
+ else:
290
+ printerr("Dangling equation row", file=sys.stderr)
291
+ parse_latexml_children(child, parent)
292
+ elif sv.match(".ltx_eqn_cell", child):
293
+ parse_latexml_children(child, parent)
294
+ elif sv.match("table, span.ltx_tabular, div.ltx_tabular", child):
295
+ tabular = parent.append(Tabular())
296
+ parse_latexml_children(child, tabular)
297
+ elif sv.match("thead.ltx_thead", child):
298
+ table = parent.find_parent(Tabular)
299
+ if table is not None:
300
+ parse_latexml_children(child, table)
301
+ else:
302
+ printerr("Table header element outside table", file=sys.stderr)
303
+ elif sv.match("tbody.ltx_tbody", child):
304
+ parse_latexml_children(child, parent)
305
+ elif sv.match("tr.ltx_tr", child):
306
+ table = parent.find_parent(Tabular)
307
+ if table is not None:
308
+ row = table.add_row(TableRow())
309
+ parse_latexml_children(child, row)
310
+ else:
311
+ printerr("TableRow element outside table", file=sys.stderr)
312
+ elif sv.match("td.ltx_td, th.ltx_th", child):
313
+ row = parent.find_parent(TableRow)
314
+ if row is not None:
315
+ cell = TableCell()
316
+ cell.set_attrs(child.attrs)
317
+ row.add_cell(cell)
318
+ parse_latexml_children(child, cell)
319
+ else:
320
+ printerr("TableData element outside table row", file=sys.stderr)
321
+ elif sv.match("span.ltx_text, em.ltx_emph", child):
322
+ if (
323
+ child.find_parent(ListItem) is None
324
+ or child.get_text() != "[label=0)]"
325
+ or child.get_text() != "[leftmargin=*] "
326
+ ):
327
+ if "ltx_font_italic" in classes:
328
+ elem = Italic()
329
+ elif "ltx_font_bold" in classes:
330
+ elem = Bold()
331
+ else:
332
+ elem = SpanElement()
333
+ parent.append(elem)
334
+ parse_latexml_children(child, elem)
335
+ else:
336
+ parent.find_parent(ListContainer).items.pop()
337
+ elif sv.match("figure.ltx_table", child):
338
+ figure = parent.append(Table())
339
+ if "id" in child.attrs:
340
+ figure.id = child.attrs["id"]
341
+ parse_latexml_children(child, figure)
342
+ elif sv.match("figure.ltx_figure", child):
343
+ figure = parent.append(Figure())
344
+ if "id" in child.attrs:
345
+ figure.id = child.attrs["id"]
346
+ parse_latexml_children(child, figure)
347
+ elif sv.match("figure.ltx_float", child):
348
+ parse_latexml_children(child, parent)
349
+ elif sv.match(".ltx_listing", child):
350
+ alg = parent.append(Algorithm())
351
+ parse_latexml_children(child, alg)
352
+ elif sv.match(".ltx_listingline", child):
353
+ alg = parent.find_parent(Algorithm)
354
+ if alg is not None:
355
+ line = alg.add_line(Element())
356
+ parse_latexml_children(child, line)
357
+ else:
358
+ printerr("Listing line outside algorithm environment", file=sys.stderr)
359
+ elif sv.match("dl.ltx_description", child):
360
+ def_list = parent.append(DefinitionList())
361
+ parse_latexml_children(child, def_list)
362
+ elif sv.match("dt.ltx_item", child):
363
+ def_list = parent.find_parent(DefinitionList)
364
+ if def_list is not None:
365
+ item = def_list.add_item(Definition())
366
+ item.term = SpanElement(parent=item)
367
+ parse_latexml_children(child, item.term)
368
+ else:
369
+ printerr("Found dangling definition term", file=sys.stderr)
370
+ elif sv.match("dd.ltx_item", child):
371
+ def_list = parent.find_parent(DefinitionList)
372
+ if def_list is not None:
373
+ if def_list.items and def_list.items[-1].definition is None:
374
+ item = def_list.items[-1]
375
+ else:
376
+ printerr("Found definition without term", file=sys.stderr)
377
+ item = def_list.add_item(Definition())
378
+ item.definition = SpanElement(parent=item)
379
+ parse_latexml_children(child, item.definition)
380
+ else:
381
+ printerr("Found dangling definition", file=sys.stderr)
382
+ parse_latexml_children(child, parent)
383
+ elif sv.match("figcaption", child):
384
+ fig = parent.find_parent((Figure, Table))
385
+ if fig is not None:
386
+ if fig.caption is None:
387
+ fig.caption = Paragraph(parent=fig)
388
+ parse_latexml_children(child, fig.caption)
389
+ fig.caption.append(TextElement(content="\n"))
390
+ else:
391
+ printerr("Figure caption outside figure element", file=sys.stderr)
392
+ para = parent.append(Paragraph())
393
+ parse_latexml_children(child, para)
394
+ elif sv.match(".ltx_break", child):
395
+ parent.append(TextElement(content="\n\n"))
396
+ elif sv.match(".ltx_abstract, .ltx_acknowledgements", child):
397
+ abstract = parent.append(Section())
398
+ parse_latexml_children(child, abstract)
399
+ elif sv.match(".ltx_ERROR", child):
400
+ printerr(
401
+ f"LaTeX error element: {child.get_text(strip=True)}", file=sys.stderr
402
+ )
403
+ elif is_wrapper_element(child):
404
+ parse_latexml_children(child, parent)
405
+ elif ignore_element(child):
406
+ continue
407
+ else:
408
+ printerr(
409
+ f"Unknown LaTeXML element <{child.name}> with classes {', '.join(classes)}",
410
+ file=sys.stderr,
411
+ )
412
+ elem = parent.append(UnknownElement())
413
+ parse_latexml_children(child, elem)
414
+
415
+
416
+ # TODO: move this somewhere else, so I can use it with plaintext too
417
+ sess = requests.Session()
418
+
419
+
420
+ def parse_latexml_references(html: BeautifulSoup, doc: Document) -> None:
421
+ for child in html.select("li.ltx_bibitem"):
422
+ child.attrs.get("id")
423
+ ref_text = child.get_text(strip=False).replace("\n", " ")
424
+ reference = Reference()
425
+ reference.title = TextElement(content=child.get_text(strip=True))
426
+ doc.add_reference(reference)
427
+
428
+
429
+ def parse_latexml(
430
+ html: BeautifulSoup,
431
+ ) -> Optional[Document]:
432
+ if html.article is None:
433
+ printerr("Missing article element", file=sys.stderr)
434
+ return None
435
+ doc = Document()
436
+ parse_latexml_children(html.article, doc)
437
+ parse_latexml_references(
438
+ html.article,
439
+ doc,
440
+ )
441
+ return doc
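The parser above turns a LaTeXML-generated HTML article into the intermediate `Document` tree. A minimal usage sketch (not part of the committed files; the input file name is a placeholder):

```python
# Sketch only: parse a LaTeXML HTML file into a Document tree.
from bs4 import BeautifulSoup
from nougat.dataset.parser.latexml_parser import parse_latexml

with open("paper.html", "r", encoding="utf-8") as f:  # placeholder path
    html = BeautifulSoup(f.read().replace("\xa0", " "), features="html.parser")

doc = parse_latexml(html)  # returns None if the page has no <article> element
if doc is not None:
    print("document parsed, title set:", doc.title is not None)
```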
nougat/dataset/parser/markdown.py ADDED
@@ -0,0 +1,396 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ from typing import Iterable, List, Optional, Tuple
8
+ import re
9
+ from uuid import uuid4
10
+ from nougat.dataset.utils import normalize_tex
11
+ from nougat.dataset.parser.document import *
12
+ from nougat.dataset.parser.latexml_parser import _clean_html_whitespace
13
+ from unidecode import unidecode
14
+
15
+ SUPERSCRIPT_MAP = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")
16
+ SUBSCRIPT_MAP = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉")
17
+ figure_regex = re.compile(r"\[(FOOTNOTE|FIGURE|TABLE)(.*?)\](.*?)\[END\1\]", re.S)
18
+ conv = {
19
+ "&": r"\&",
20
+ "%": r"\%",
21
+ "$": r"\$",
22
+ "#": r"\#",
23
+ "_": r"\_",
24
+ "{": r"\{",
25
+ "}": r"\}",
26
+ "~": r"\textasciitilde{}",
27
+ "^": r"\^{}",
28
+ "\\": r"\textbackslash{}",
29
+ "<": r"\textless{}",
30
+ ">": r"\textgreater{}",
31
+ }
32
+ regex = re.compile(
33
+ "|".join(
34
+ re.escape(str(key)) for key in sorted(conv.keys(), key=lambda item: -len(item))
35
+ )
36
+ )
37
+
38
+
39
+ def remove_trailing_whitespace(parts: List[str]) -> None:
40
+ """Removes whitespace elements in list inplace"""
41
+ for s in reversed(parts):
42
+ if s.rstrip() == "":
43
+ del parts[-1]
44
+ else:
45
+ break
46
+
47
+
48
+ def remove_line_breaks(parts: List[str]):
49
+ out = []
50
+ for s in parts:
51
+ out.append(s.replace("\n", " "))
52
+ return out
53
+
54
+
55
+ def leading_trailing_whitespace(
56
+ parts: List[str],
57
+ ) -> Tuple[List[str], List[str], List[str]]:
58
+ """splits the list into three parts. The first and last return elements are made up only of whitespace
59
+
60
+ Args:
61
+ parts (List[str]): List to split.
62
+
63
+ Returns:
64
+ Tuple[List[str], List[str], List[str]]: The split list parts.
65
+ """
66
+ lead = []
67
+ trail = []
68
+ out_slice = [None, None]
69
+ for i, s in enumerate(parts):
70
+ if s.strip() == "":
71
+ lead.append(s)
72
+ out_slice[0] = i + 1
73
+ else:
74
+ break
75
+ for i, s in enumerate(reversed(parts)):
76
+ if s.strip() == "":
77
+ trail.append(s)
78
+ out_slice[1] = -1 - i
79
+ else:
80
+ break
81
+ return lead, parts[slice(*out_slice)], trail[::-1]
82
+
83
+
84
+ def latex_escape(string: str) -> str:
85
+ return regex.sub(lambda match: conv[match.group()], string)
86
+
87
+
88
+ def is_empty(content: List) -> bool:
89
+ """Used to determine if a Section is empty"""
90
+ empty = True
91
+ for part in content:
92
+ if len(part.strip()):
93
+ empty = False
94
+ break
95
+ return empty
96
+
97
+
98
+ def format_element(
99
+ element: Element, keep_refs: bool = False, latex_env: bool = False
100
+ ) -> List[str]:
101
+ """
102
+ Formats a given Element into a list of formatted strings.
103
+
104
+ Args:
105
+ element (Element): The element to be formatted.
106
+ keep_refs (bool, optional): Whether to keep references in the formatting. Default is False.
107
+ latex_env (bool, optional): Whether to use LaTeX environment formatting. Default is False.
108
+
109
+ Returns:
110
+ List[str]: A list of formatted strings representing the formatted element.
111
+ """
112
+ if isinstance(element, TextElement):
113
+ if latex_env:
114
+ return [latex_escape(element.content)]
115
+ else:
116
+ return [element.content]
117
+ if isinstance(element, Bold):
118
+ parts = format_children(element, keep_refs, latex_env)
119
+ if element.find_parent(Algorithm) is not None:
120
+ return parts
121
+ lead, text, tail = leading_trailing_whitespace("".join(parts))
122
+ return [*lead, "**", *remove_line_breaks(text), "**", *tail]
123
+ if isinstance(element, Italic):
124
+ parts = format_children(element, keep_refs, latex_env)
125
+ if element.find_parent(Algorithm) is not None:
126
+ return parts
127
+ lead, text, tail = leading_trailing_whitespace("".join(parts))
128
+ return [*lead, "_", *remove_line_breaks(text), "_", *tail]
129
+ if isinstance(element, PlaintextMath):
130
+ return format_children(element, keep_refs) + ["\n"]
131
+ if isinstance(element, Paragraph):
132
+ return format_children(element, keep_refs, latex_env) + ["\n\n"]
133
+ if isinstance(element, TableCell):
134
+ parts = format_children(element, keep_refs, latex_env)
135
+ remove_trailing_whitespace(parts)
136
+ if element.multirow is not None:
137
+ parts.insert(0, "\\multirow{%i}{*}{" % (element.multirow))
138
+ parts.append("}")
139
+ if element.multicolumn is not None:
140
+ parts.insert(
141
+ 0, "\\multicolumn{%i}{%s}{" % (element.multicolumn, element.spec)
142
+ )
143
+ parts.append("}")
144
+ return parts
145
+ if isinstance(element, TableRow):
146
+ parts = []
147
+ if element.hline_above:
148
+ parts.append(element.hline_above + "\n")
149
+ parts.extend(
150
+ remove_line_breaks(
151
+ format_iterator(element.cells, keep_refs, latex_env, join=" & ")
152
+ )
153
+ )
154
+ parts.append(r" \\")
155
+ parts.append((" " + element.hline_below).rstrip())
156
+ return parts
157
+ if isinstance(element, Tabular):
158
+ parts = [
159
+ "\\begin{tabular}",
160
+ "{%s}\n" % element.get_table_spec(),
161
+ ]
162
+ parts.extend(format_iterator(element.rows, keep_refs, True, join="\n"))
163
+ parts.append("\n\\end{tabular}\n")
164
+ return parts
165
+ if isinstance(element, Table):
166
+ parts = [
167
+ "[TABLE%s]\n\\begin{table}\n"
168
+ % (str(uuid4())[:5] if element.id is None else ":" + str(element.id))
169
+ ]
170
+ parts.extend(format_children(element, keep_refs, latex_env))
171
+ caption_parts = format_element(element.caption, keep_refs, latex_env)
172
+ remove_trailing_whitespace(caption_parts)
173
+ parts.append("\\end{table}\n")
174
+ if len(caption_parts) > 0:
175
+ parts.extend(caption_parts + ["\n"])
176
+ parts.append("[ENDTABLE]\n\n")
177
+ return parts
178
+ if isinstance(element, Figure):
179
+ parts = format_element(element.caption, keep_refs)
180
+ remove_trailing_whitespace(parts)
181
+ return (
182
+ [
183
+ "[FIGURE%s]\n"
184
+ % (str(uuid4())[:5] if element.id is None else ":" + str(element.id))
185
+ ]
186
+ + parts
187
+ + ["\n[ENDFIGURE]\n\n"]
188
+ )
189
+ if isinstance(element, SectionHeader):
190
+ parts = ["# "]
191
+ if element.id:
192
+ parts.append(f"{element.id.upper()} ")
193
+ if element.header:
194
+ header = format_element(element.header, keep_refs)
195
+ else:
196
+ header = format_iterator(element.children, keep_refs)
197
+ _, title, _ = leading_trailing_whitespace("".join(header))
198
+ parts.append(title)
199
+ parts.append("\n\n")
200
+ return parts
201
+ if isinstance(element, Section):
202
+ children_parts = format_children(element, keep_refs)
203
+ if is_empty(children_parts):
204
+ return []
205
+ if element.header:
206
+ parts = [f"\n\n{'#'*element.hnum} "]
207
+ _, title, _ = leading_trailing_whitespace(
208
+ "".join(format_element(element.header, keep_refs))
209
+ )
210
+ parts.append(title)
211
+ parts.append("\n\n")
212
+ else:
213
+ parts = []
214
+ return parts + children_parts
215
+ if isinstance(element, Footnote):
216
+ if element.id is not None:
217
+ foot = f"\n[FOOTNOTE:{element.id}]Footnote {element.id}: "
218
+ else:
219
+ foot = "\n[FOOTNOTE:%s]Footnote: " % (str(uuid4())[:5])
220
+ return [foot] + format_children(element, keep_refs) + ["[ENDFOOTNOTE]\n\n"]
221
+ if isinstance(element, ListContainer):
222
+ items = [
223
+ (
224
+ item.label,
225
+ "".join(format_element(item, keep_refs)).strip().replace("\n", " "),
226
+ )
227
+ for item in element.items
228
+ ]
229
+ parts = ["\n"]
230
+ indent = " " * max(element.level - 1, 0)
231
+ for i, (label, item) in enumerate(items, 1):
232
+ if label:
233
+ bullet = label
234
+ else:
235
+ bullet = f"{i}." if element.ordered else "*"
236
+ parts.append(f"{indent}{bullet} {item}\n")
237
+ parts.append("\n")
238
+ return parts
239
+ if isinstance(element, Equation):
240
+ # an equation consists of multiple display-style TeX formulas and an optional equation label
241
+ parts = []
242
+ for child in element.children:
243
+ if isinstance(child, LatexMath):
244
+ tex = normalize_tex(
245
+ "".join(format_element(child, keep_refs)).strip(" \n"), inline=False
246
+ )
247
+ parts.append(tex)
248
+ else:
249
+ text = "".join(format_element(child, keep_refs))
250
+ if text:
251
+ parts.append(text)
252
+ lead, eqs, tail = leading_trailing_whitespace(parts)
253
+ s = " ".join(eqs).replace(r"\] \[", " ")
254
+ return [*lead, s, *tail]
255
+ if isinstance(element, EquationList):
256
+ parts = ["\n"]
257
+ items = element.equations
258
+ items = ["".join(format_element(item, keep_refs)).rstrip() for item in items]
259
+ items = [item + "\n" for item in items if item]
260
+ if items:
261
+ parts.extend(items)
262
+ parts.append("\n")
263
+ return parts
264
+ if isinstance(element, Algorithm):
265
+ parts = []
266
+ items = element.lines
267
+ items = ["".join(format_element(item, keep_refs)).rstrip() for item in items]
268
+ if element.inline:
269
+ items = [item for item in items if item]
270
+ else:
271
+ items = [item + "\n" for item in items if item]
272
+ if items:
273
+ prepend = "`" if element.inline else "\n```\n"
274
+ parts.append(prepend)
275
+ parts.extend(items)
276
+ append = "`" if element.inline else "```\n\n"
277
+ parts.append(append)
278
+ return parts
279
+ if isinstance(element, DefinitionList):
280
+ parts = ["\n"]
281
+ if element.header is not None:
282
+ parts.extend(format_element(element.header, keep_refs))
283
+ parts.append("\n")
284
+ items = [
285
+ "".join(format_element(item, keep_refs)).rstrip() for item in element.items
286
+ ]
287
+ items = [item + "\n" for item in items if item]
288
+ if items:
289
+ parts.extend(items)
290
+ parts.append("\n")
291
+ return parts
292
+ if isinstance(element, Definition):
293
+ parts = []
294
+ if element.term is not None:
295
+ term = (
296
+ "".join(format_element(element.term, keep_refs)).rstrip(" \n\t:") + ": "
297
+ )
298
+ # maths in wiki might be inside a definition without a term
299
+ if term.strip() != ":":
300
+ parts.append(term)
301
+ if element.definition is not None:
302
+ definition = "".join(format_element(element.definition, keep_refs)).rstrip()
303
+ parts.append(definition)
304
+ if parts:
305
+ parts.append("\n")
306
+ return parts
307
+ if isinstance(element, LatexMath):
308
+ parts = []
309
+ if not element.inline:
310
+ parts.append("\n\n")
311
+ parts.append(normalize_tex(element.code, element.inline).strip())
312
+ if not element.inline:
313
+ parts.append("\n\n")
314
+ return parts
315
+ if isinstance(element, (Superscript, Subscript)):
316
+ content = element.plaintext
317
+ if content.strip().isdigit():
318
+ script_map = (
319
+ SUBSCRIPT_MAP if isinstance(element, Subscript) else SUPERSCRIPT_MAP
320
+ )
321
+ return [content.translate(script_map)]
322
+ else:
323
+ return format_children(element, keep_refs)
324
+ if isinstance(element, InlineRef):
325
+ parts = format_children(element, keep_refs)
326
+ return parts
327
+ return format_children(element, keep_refs, latex_env)
328
+
329
+
330
+ def format_iterator(
331
+ iterator: Iterable,
332
+ keep_refs: bool = False,
333
+ latex_env: bool = False,
334
+ join: Optional[str] = None,
335
+ ) -> List[str]:
336
+ """
337
+ The `format_iterator` function takes an iterator and formats its elements, optionally joining them with a specified string.
338
+
339
+ :param iterator: The `iterator` parameter is an iterable object that contains the elements to be formatted. It could be a list, tuple, set, or any other iterable object
340
+ :type iterator: Iterable
341
+ :param keep_refs: The `keep_refs` parameter is a boolean flag that determines whether references to other elements should be preserved in the formatted output. If `keep_refs` is set to `True`, the references will be included in the output. If `keep_refs` is set to `False` (default), the, defaults to False
342
+ :type keep_refs: bool (optional)
343
+ :param latex_env: The `latex_env` parameter is a boolean flag that determines whether the output should be formatted as LaTeX code. If `latex_env` is set to `True`, the output will be formatted using LaTeX syntax. If `latex_env` is set to `False` (default), the output will be, defaults to False
344
+ :type latex_env: bool (optional)
345
+ :param join: The `join` parameter is an optional string that specifies the delimiter to be used when joining the formatted elements of the iterator into a single string. If `join` is provided, it will be inserted between each formatted element. If `join` is not provided, the formatted elements will be returned as
346
+ :type join: Optional[str]
347
+ :return: The function `format_iterator` returns a list of strings.
348
+ """
349
+ parts = []
350
+ for child in iterator:
351
+ parts.extend(format_element(child, keep_refs, latex_env))
352
+ if join is not None:
353
+ parts.append(join)
354
+ if join is not None:
355
+ parts = parts[:-1]
356
+ return parts
357
+
358
+
359
+ def format_children(
360
+ element: Element, keep_refs: bool = False, latex_env: bool = False
361
+ ) -> List[str]:
362
+ if element is None:
363
+ return []
364
+ return format_iterator(element.children, keep_refs, latex_env)
365
+
366
+
367
+ def format_document(
368
+ doc: Document, keep_refs: bool = False
369
+ ) -> Tuple[str, Dict[str, str]]:
370
+ """
371
+ The `format_document` function takes a `doc` object of type `Document` and a boolean `keep_refs` as input and returns a tuple containing the formatted text of the document and a dictionary of figures found in the document.
372
+
373
+ :param doc: The parsed `Document` tree to format
374
+ :type doc: Document
375
+ :param keep_refs: The `keep_refs` parameter is a boolean flag that determines whether to keep references in the formatted document or not. If `keep_refs` is set to `True`, the references will be included in the formatted document. If `keep_refs` is set to `False`, the references will be excluded, defaults to False
376
+ :type keep_refs: bool (optional)
377
+ :return: The function `format_document` returns a tuple containing two elements: a formatted text document and a dictionary of figures.
378
+ """
379
+ parts = []
380
+
381
+ if doc.title:
382
+ parts.extend([*format_element(doc.title), "\n"])
383
+ parts.append("\n")
384
+ parts.extend(format_children(doc, keep_refs))
385
+ text = "".join(parts)
386
+ text = text.replace("\xa0", " ") # replace non-breakable spaces
387
+ text = re.sub(r" $", "", text, flags=re.MULTILINE)
388
+ text = re.sub(r"\n[\t ]*$", "\n", text, flags=re.MULTILINE)
389
+ text = re.sub(r"(?<!\n) {2,}", " ", text)
390
+ text = re.sub(r"\n{3,}", "\n\n", text).lstrip()
391
+ figures = {unidecode(m[0] + m[1]): m[2].strip() for m in figure_regex.findall(text)}
392
+ text = figure_regex.sub(
393
+ r"[\1\2][END\1]",
394
+ text,
395
+ )
396
+ return text, figures
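Together with the LaTeXML parser, `format_document` yields the page-unsplit markdown plus a map of figure/table/footnote placeholders. A rough sketch of that pipeline (file names are placeholders; this mirrors how `split_htmls_to_pages.py` below uses it):

```python
# Sketch: Document -> markdown text + figure/footnote map.
from bs4 import BeautifulSoup
from nougat.dataset.parser.latexml_parser import parse_latexml
from nougat.dataset.parser.markdown import format_document

html = BeautifulSoup(open("paper.html", encoding="utf-8").read(), features="html.parser")
doc = parse_latexml(html)
if doc is not None:
    text, figures = format_document(doc, keep_refs=True)
    # `figures` maps placeholders such as "FIGURE:S1.F1" to their caption text,
    # which split_markdown() later re-inserts on the matching page.
    with open("paper.mmd", "w", encoding="utf-8") as f:
        f.write(text)
```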
nougat/dataset/pdffigures.py ADDED
@@ -0,0 +1,71 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import os
8
+ import subprocess
9
+ import logging
10
+
11
+ PDFFIGURES2_JAR_PATH = os.environ.get("PDFFIGURES_PATH", None)
12
+ logger = logging.getLogger()
13
+ if PDFFIGURES2_JAR_PATH is None:
14
+ logger.warning(
15
+ "You need to configure the path to the pdffigures2 executable in this file (nougat/dataset/pdffigures.py) or set the environment variable 'PDFFIGURES_PATH'."
16
+ )
17
+
18
+
19
+ def call_pdffigures(
20
+ pdf_path: str, figures_dir: str, timeout: int = 30, verbose: bool = False
21
+ ):
22
+ """
23
+ Extract figures from a PDF file using pdffigures2.
24
+
25
+ Args:
26
+ pdf_path (str): The path to the PDF file.
27
+ figures_dir (str): The directory where the figures will be extracted.
28
+ timeout (int, optional): The timeout in seconds for the pdffigures2 command. Defaults to 30.
29
+ verbose (bool, optional): Whether to print the output of the pdffigures2 command. Defaults to False.
30
+
31
+ Returns:
32
+ str: The path to the JSON file containing the extracted figures.
33
+ """
34
+ os.makedirs(figures_dir, exist_ok=True)
35
+ kwargs = (
36
+ {} if verbose else {"stderr": subprocess.DEVNULL, "stdout": subprocess.DEVNULL}
37
+ )
38
+ if PDFFIGURES2_JAR_PATH is None:
39
+ return
40
+ process = subprocess.Popen(
41
+ "java"
42
+ " -jar {pdffigures_jar_path}"
43
+ " -d {figures_dir}/"
44
+ " -c"
45
+ " -q"
46
+ " {pdf_path}".format(
47
+ pdffigures_jar_path=PDFFIGURES2_JAR_PATH,
48
+ pdf_path=pdf_path,
49
+ figures_dir=figures_dir,
50
+ ),
51
+ shell=True,
52
+ **kwargs
53
+ )
54
+
55
+ try:
56
+ exit_code = process.wait(timeout=timeout)
57
+ if exit_code != 0:
58
+ logger.error("Extracting figures from file %s failed.", pdf_path)
59
+ return False
60
+ except subprocess.TimeoutExpired as e:
61
+ logger.error(
62
+ "pdffigures2 command did not terminate in 30 seconds, "
63
+ "terminating. Error: %s",
64
+ e,
65
+ )
66
+ process.terminate() # give up
67
+ return False
68
+ pdf_name = os.path.basename(pdf_path).partition(".pdf")[0]
69
+ dest_file = os.path.join(figures_dir, (pdf_name + ".json"))
70
+
71
+ return dest_file
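`call_pdffigures` shells out to the pdffigures2 jar, and the jar location is read at import time, so the environment variable must be set before importing the module. A hypothetical invocation (jar path and file names are placeholders):

```python
# Sketch: PDFFIGURES_PATH must be set before the module is imported,
# because PDFFIGURES2_JAR_PATH is resolved at import time.
import os
os.environ["PDFFIGURES_PATH"] = "/opt/pdffigures2/pdffigures2-assembly.jar"  # placeholder

from nougat.dataset.pdffigures import call_pdffigures

json_file = call_pdffigures("paper.pdf", "figures", timeout=60)
# Expected: "figures/paper.json" on success, False on failure, None if no jar is configured.
print(json_file)
```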
nougat/dataset/rasterize.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import argparse
8
+ import logging
9
+ import pypdfium2
10
+ from pathlib import Path
11
+ from tqdm import tqdm
12
+ import io
13
+ from typing import Optional, List, Union
14
+
15
+ logging.getLogger("pypdfium2").setLevel(logging.WARNING)
16
+
17
+
18
+ def rasterize_paper(
19
+ pdf: Union[Path, bytes],
20
+ outpath: Optional[Path] = None,
21
+ dpi: int = 96,
22
+ return_pil=False,
23
+ pages=None,
24
+ ) -> Optional[List[io.BytesIO]]:
25
+ """
26
+ Rasterize a PDF file to PNG images.
27
+
28
+ Args:
29
+ pdf (Path): The path to the PDF file.
30
+ outpath (Optional[Path], optional): The output directory. If None, the PIL images will be returned instead. Defaults to None.
31
+ dpi (int, optional): The output DPI. Defaults to 96.
32
+ return_pil (bool, optional): Whether to return the PIL images instead of writing them to disk. Defaults to False.
33
+ pages (Optional[List[int]], optional): The pages to rasterize. If None, all pages will be rasterized. Defaults to None.
34
+
35
+ Returns:
36
+ Optional[List[io.BytesIO]]: The rendered pages as in-memory BMP buffers if `return_pil` is True, otherwise None.
37
+ """
38
+ pils = []
39
+ if outpath is None:
40
+ return_pil = True
41
+ try:
42
+ if isinstance(pdf, (str, Path)):
43
+ pdf = pypdfium2.PdfDocument(pdf)
44
+ if pages is None:
45
+ pages = range(len(pdf))
46
+ renderer = pdf.render(
47
+ pypdfium2.PdfBitmap.to_pil,
48
+ page_indices=pages,
49
+ scale=dpi / 72,
50
+ )
51
+ for i, image in zip(pages, renderer):
52
+ if return_pil:
53
+ page_bytes = io.BytesIO()
54
+ image.save(page_bytes, "bmp")
55
+ pils.append(page_bytes)
56
+ else:
57
+ image.save((outpath / ("%02d.png" % (i + 1))), "png")
58
+ except Exception as e:
59
+ logging.error(e)
60
+ if return_pil:
61
+ return pils
62
+
63
+
64
+ if __name__ == "__main__":
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument("--pdfs", nargs="+", type=Path, help="PDF files", required=True)
67
+ parser.add_argument("--out", type=Path, help="Output dir", default=None)
68
+ parser.add_argument(
69
+ "--dpi", type=int, default=96, help="What resolution the pages will be saved"
70
+ )
71
+ parser.add_argument(
72
+ "--pages", type=int, nargs="+", default=None, help="list of page numbers"
73
+ )
74
+ args = parser.parse_args()
75
+ if args.pages:
76
+ args.pages = [p - 1 for p in args.pages]
77
+ for pdf_file in tqdm(args.pdfs):
78
+ assert pdf_file.exists() and pdf_file.is_file()
79
+ outpath: Path = args.out or (pdf_file.parent / pdf_file.stem)
80
+ outpath.mkdir(exist_ok=True)
81
+ rasterize_paper(pdf_file, outpath, pages=args.pages, dpi=args.dpi)
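For quick checks, `rasterize_paper` can also be called directly to get the rendered pages in memory instead of PNG files on disk; the path below is a placeholder:

```python
# Sketch: render the first two pages at 96 DPI and load them with PIL.
from pathlib import Path
from PIL import Image
from nougat.dataset.rasterize import rasterize_paper

buffers = rasterize_paper(Path("paper.pdf"), outpath=None, dpi=96, pages=[0, 1])
if buffers:  # each element is an io.BytesIO holding a BMP-encoded page
    img = Image.open(buffers[0])
    print(img.size)
```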
nougat/dataset/split_htmls_to_pages.py ADDED
@@ -0,0 +1,219 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import argparse
8
+ from io import BytesIO
9
+ import multiprocessing
10
+ from pebble import ProcessPool
11
+ from concurrent.futures import TimeoutError
12
+ from tqdm import tqdm
13
+ from typing import Tuple
14
+ import os
15
+ from pathlib import Path
16
+ import logging
17
+ import pypdf
18
+ from PIL import Image
19
+ import pytesseract
20
+ from nougat.dataset.split_md_to_pages import *
21
+ from nougat.dataset.parser.html2md import *
22
+ from nougat.dataset.pdffigures import call_pdffigures
23
+
24
+ logging.basicConfig()
25
+ logger = logging.getLogger()
26
+ logger.setLevel(logging.INFO)
27
+
28
+
29
+ def process_paper(
30
+ fname: str,
31
+ pdf_file: Path,
32
+ html_file: Path,
33
+ json_file: Path,
34
+ args: argparse.Namespace,
35
+ ) -> Tuple[int, int]:
36
+ """
37
+ Process a single paper.
38
+
39
+ Args:
40
+ fname (str): The paper's filename.
41
+ pdf_file (Path): The path to the PDF file.
42
+ html_file (Path): The path to the HTML file.
43
+ json_file (Path): The path to the JSON file containing the extracted figures.
44
+ args (argparse.Namespace): The command-line arguments.
45
+
46
+ Returns:
47
+ Tuple[int, int]: The number of total pages and the number of recognized pages.
48
+ """
49
+ total_pages = 0
50
+ num_recognized_pages = 0
51
+ try:
52
+ pdf = pypdf.PdfReader(pdf_file)
53
+ total_pages = len(pdf.pages)
54
+ outpath: Path = args.out / fname
55
+ # skip this paper if already processed
56
+ dirs_with_same_stem = list(args.out.glob(fname.partition("v")[0] + "*"))
57
+ if (
58
+ len(dirs_with_same_stem) > 0
59
+ and len(list(dirs_with_same_stem[0].iterdir())) > 0
60
+ and not args.recompute
61
+ ):
62
+ logger.info(
63
+ "%s (or another version thereof) already processed. Skipping paper",
64
+ fname,
65
+ )
66
+ return total_pages, len(list(outpath.glob("*.mmd")))
67
+ html = BeautifulSoup(
68
+ htmlmin.minify(
69
+ open(html_file, "r", encoding="utf-8").read().replace("\xa0", " "),
70
+ remove_all_empty_space=True,
71
+ ),
72
+ features="html.parser",
73
+ )
74
+ doc = parse_latexml(html)
75
+ if doc is None:
76
+ return
77
+ out, fig = format_document(doc, keep_refs=True)
78
+
79
+ if args.markdown:
80
+ md_out = args.markdown / (fname + ".mmd")
81
+ with open(md_out, "w", encoding="utf-8") as f:
82
+ f.write(out)
83
+
84
+ if json_file is None:
85
+ json_file = call_pdffigures(pdf_file, args.figure)
86
+ if json_file:
87
+ figure_info = json.load(open(json_file, "r", encoding="utf-8"))
88
+ else:
89
+ figure_info = None
90
+ split = split_markdown(
91
+ out, pdf_file, figure_info=figure_info, doc_fig=fig, min_score=0.9
92
+ )
93
+ if split is None:
94
+ return
95
+ pages, meta = split
96
+ num_recognized_pages = sum([len(p) > 0 for p in pages])
97
+ if all([len(p) == 0 for p in pages]):
98
+ return
99
+ os.makedirs(outpath, exist_ok=True)
100
+ recognized_indices = []
101
+ for i, content in enumerate(pages):
102
+ with (outpath / "meta.json").open("w", encoding="utf-8") as f:
103
+ f.write(json.dumps(meta))
104
+ if content:
105
+ if re.search(r"\[(?:\?\?(?:. )?)+\]", content):
106
+ # there are wrongly parsed references in the page eg [??].
107
+ continue
108
+ with (outpath / ("%02d.mmd" % (i + 1))).open(
109
+ "w", encoding="utf-8"
110
+ ) as f:
111
+ f.write(content)
112
+ recognized_indices.append(i)
113
+ rasterize_paper(pdf_file, outpath, dpi=args.dpi, pages=recognized_indices)
114
+ if args.tesseract:
115
+ for i in recognized_indices:
116
+ ocr = pytesseract.image_to_string(
117
+ Image.open((outpath / ("%02d.png" % (i + 1)))), lang="eng"
118
+ )
119
+ ocr = re.sub(r"\n+\s+?([^\s])", r"\n\n\1", ocr).strip()
120
+ with (outpath / ("%02d_OCR.txt" % (i + 1))).open(
121
+ "w", encoding="utf-8"
122
+ ) as f_ocr:
123
+ f_ocr.write(ocr)
124
+ except Exception as e:
125
+ logger.error(e)
126
+
127
+ return total_pages, num_recognized_pages
128
+
129
+
130
+ def process_htmls(args):
131
+ for input_dir in (args.pdfs, args.html):
132
+ if not input_dir.exists() or not input_dir.is_dir():
133
+ logger.error("%s does not exist or is not a directory.", input_dir)
134
+ return
135
+ htmls: List[Path] = args.html.glob("*.html")
136
+ args.out.mkdir(exist_ok=True)
137
+ if args.markdown:
138
+ args.markdown.mkdir(exist_ok=True)
139
+
140
+ with ProcessPool(max_workers=args.workers) as pool:
141
+ total_pages, total_pages_extracted = 0, 0
142
+ tasks = {}
143
+ for j, html_file in enumerate(htmls):
144
+ fname = html_file.stem
145
+ pdf_file = args.pdfs / (fname + ".pdf")
146
+ if not pdf_file.exists():
147
+ logger.info("%s pdf could not be found.", fname)
148
+ continue
149
+ json_file = args.figure / (fname + ".json")
150
+ if not json_file.exists():
151
+ logger.info("%s figure json could not be found.", fname)
152
+ json_file = None
153
+ tasks[fname] = pool.schedule(
154
+ process_paper,
155
+ args=[fname, pdf_file, html_file, json_file, args],
156
+ timeout=args.timeout,
157
+ )
158
+
159
+ for fname in tqdm(tasks):
160
+ try:
161
+ res = tasks[fname].result()
162
+ if res is None:
163
+ logger.info("%s is faulty", fname)
164
+ continue
165
+ num_pages, num_recognized_pages = res
166
+ total_pages += num_pages
167
+ total_pages_extracted += num_recognized_pages
168
+ logger.info(
169
+ "%s: %i/%i pages recognized. Percentage: %.2f%%",
170
+ fname,
171
+ num_recognized_pages,
172
+ num_pages,
173
+ (100 * num_recognized_pages / max(1, num_pages)),
174
+ )
175
+ except TimeoutError:
176
+ logger.info("%s timed out", fname)
177
+ if total_pages > 0:
178
+ logger.info(
179
+ "In total: %i/%i pages recognized. Percentage: %.2f%%",
180
+ total_pages_extracted,
181
+ total_pages,
182
+ (100 * total_pages_extracted / max(1, total_pages)),
183
+ )
184
+
185
+
186
+ if __name__ == "__main__":
187
+ parser = argparse.ArgumentParser()
188
+ parser.add_argument("--html", type=Path, help="HTML files", required=True)
189
+ parser.add_argument("--pdfs", type=Path, help="PDF files", required=True)
190
+ parser.add_argument("--out", type=Path, help="Output dir", required=True)
191
+ parser.add_argument("--recompute", action="store_true", help="recompute all splits")
192
+ parser.add_argument(
193
+ "--markdown", type=Path, help="Markdown output dir", default=None
194
+ )
195
+ parser.add_argument(
196
+ "--figure",
197
+ type=Path,
198
+ help="Figure info JSON dir",
199
+ )
200
+ parser.add_argument(
201
+ "--workers",
202
+ type=int,
203
+ default=multiprocessing.cpu_count(),
204
+ help="How many processes to use",
205
+ )
206
+ parser.add_argument(
207
+ "--dpi", type=int, default=96, help="What resolution the pages will be saved at"
208
+ )
209
+ parser.add_argument(
210
+ "--timeout", type=float, default=120, help="max time per paper in seconds"
211
+ )
212
+ parser.add_argument(
213
+ "--tesseract",
214
+ action="store_true",
215
+ help="Tesseract OCR prediction for each page",
216
+ )
217
+ args = parser.parse_args()
218
+ print(args)
219
+ process_htmls(args)
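The script above is normally run from the command line, but `process_htmls` can also be driven programmatically by handing it an `argparse.Namespace` with the same fields. All directory names here are placeholders:

```python
# Sketch: programmatic equivalent of running split_htmls_to_pages.py with CLI flags.
import argparse
from pathlib import Path
from nougat.dataset.split_htmls_to_pages import process_htmls

args = argparse.Namespace(
    html=Path("htmls"),         # LaTeXML HTML files, one per paper
    pdfs=Path("pdfs"),          # matching PDFs (same file stem)
    out=Path("pages"),          # per-paper output directories with NN.mmd / NN.png
    markdown=Path("markdown"),  # optional full-document .mmd dumps
    figure=Path("figures"),     # pdffigures2 JSON output
    recompute=False,
    workers=4,
    dpi=96,
    timeout=120.0,
    tesseract=False,
)
process_htmls(args)
```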
nougat/dataset/split_md_to_pages.py ADDED
@@ -0,0 +1,477 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import argparse
8
+ from collections import Counter
9
+ from copy import deepcopy
10
+ import json
11
+ import math
12
+ from operator import itemgetter
13
+ import re
14
+ from typing import Dict, List, Tuple, Union, Optional
15
+ import os
16
+ import pypdf
17
+ from unidecode import unidecode
18
+ import Levenshtein
19
+
20
+ import numpy as np
21
+ from sklearn.feature_extraction.text import CountVectorizer
22
+ from sklearn.feature_extraction.text import TfidfTransformer
23
+ from sklearn.linear_model import SGDClassifier
24
+
25
+ from nougat.dataset.staircase import Staircase
26
+ from nougat.dataset.splitter import (
27
+ Splitter,
28
+ get_first_last,
29
+ get_glob_index,
30
+ )
31
+ from nougat.dataset.utils import unicode_to_latex, remove_pretty_linebreaks
32
+ from nougat.dataset.utils.pdf_text_extract import get_pages, get_paragraphs
33
+ from nougat.dataset.rasterize import rasterize_paper
34
+
35
+
36
+ class BagOfWords:
37
+ """
38
+ A bag-of-words model for text classification.
39
+
40
+ Args:
41
+ sentences (List[str]): The training sentences.
42
+ target (Optional[List[int]]): The target labels for the training sentences. Defaults to None.
43
+
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ sentences: List[str],
49
+ target: Optional[List[int]] = None,
50
+ ) -> None:
51
+ self.sentences = sentences
52
+ self.target = target
53
+ self.train()
54
+
55
+ def train(self):
56
+ if self.target is None:
57
+ self.target = np.arange(len(self.sentences))
58
+ self.count_vect = CountVectorizer()
59
+ X_train_counts = self.count_vect.fit_transform(self.sentences)
60
+ self.tfidf_transformer = TfidfTransformer(use_idf=True)
61
+ X_train_tfidf = self.tfidf_transformer.fit_transform(X_train_counts)
62
+ self.clf = SGDClassifier(
63
+ loss="hinge",
64
+ penalty="l2",
65
+ alpha=1e-3,
66
+ random_state=42,
67
+ max_iter=5,
68
+ tol=None,
69
+ )
70
+ self.clf.fit(X_train_tfidf, self.target)
71
+
72
+ def __call__(
73
+ self, text: Union[str, List[str]], lob_probs: bool = False
74
+ ) -> np.ndarray:
75
+ if type(text) == str:
76
+ text = [text]
77
+ X_new_counts = self.count_vect.transform(text)
78
+ X_new_tfidf = self.tfidf_transformer.transform(X_new_counts)
79
+ if lob_probs:
80
+ return self.clf.predict_log_proba(X_new_tfidf)
81
+ else:
82
+ return self.clf.predict(X_new_tfidf)
83
+
84
+
85
+ def remove_short_seqs(seqs: List[str], minimum: int = 10) -> List[str]:
86
+ """Remove sequences shorter than the specified minimum length."""
87
+ out = []
88
+ for seq in seqs:
89
+ if len(seq) > minimum:
90
+ out.append(seq)
91
+ return out
92
+
93
+
94
+ def find_figures(
95
+ pdf_pages: List[List[str]], figure_info: Union[Dict, List]
96
+ ) -> List[Tuple[int, int, int, int]]:
97
+ """ "
98
+ Find the locations of figures in a PDF file.
99
+
100
+ Args:
101
+ pdf_pages (List[List[str]]): The text of the PDF pages.
102
+ figure_info (Union[Dict, List]): A dictionary or list of dictionaries, where each dictionary
103
+ specifies the information about a figure, such as its caption, page number, and bounding box.
104
+
105
+ Returns:
106
+ List[Tuple[int, int, int, int]]: A list of tuples, where each tuple contains the figure index, page number,
107
+ start position, and match length of the figure caption within the page text.
108
+ """
109
+ figure_locations = []
110
+ iterator = figure_info.values() if type(figure_info) == dict else [figure_info]
111
+ for figure_list in iterator:
112
+ for i, f in enumerate(figure_list):
113
+ if "caption" in f:
114
+ fig_string = f["caption"]
115
+ elif "text" in f:
116
+ fig_string = f["text"]
117
+ else:
118
+ continue
119
+ fig_string = unicode_to_latex(fig_string)
120
+ if f["page"] >= len(pdf_pages):
121
+ continue
122
+ block, score = Splitter.fuzzysearch(
123
+ "\n".join(pdf_pages[f["page"]]),
124
+ fig_string,
125
+ )
126
+ if score > 0.8 and block[2] > 0:
127
+ figure_locations.append((i, f["page"], block[0], block[2]))
128
+ return figure_locations
129
+
130
+
131
+ def flatten(l: List) -> List:
132
+ return [item for sublist in l for item in sublist]
133
+
134
+
135
+ def get_doc_text(
136
+ pdf: str,
137
+ splitn: bool = True,
138
+ split_block: bool = True,
139
+ minlen: Optional[int] = 10,
140
+ ) -> List[List[str]]:
141
+ """
142
+ Get the text from a PDF document.
143
+
144
+ Args:
145
+ pdf (str): Path to the PDF document.
146
+ splitn (bool): Whether to split the text into lines. Defaults to True.
147
+ split_block (bool): Whether to split the text into blocks. Defaults to True.
148
+ minlen (Optional[int]): The minimum length of a line or block. Defaults to 10.
149
+
150
+ Returns:
151
+ List[List[str]]: The text of the PDF document, either as a list of lines or a list of blocks.
152
+ """
153
+ document_lines = []
154
+ if split_block:
155
+ pages = get_paragraphs(pdf)
156
+ else:
157
+ pages = [get_pages(pdf)]
158
+ for blocks in pages:
159
+ page_lines = []
160
+ for block in blocks:
161
+ if splitn:
162
+ page_lines.extend(block.split("\n"))
163
+ else:
164
+ page_lines.append(block)
165
+ if splitn:
166
+ page_lines = remove_short_seqs(page_lines, minlen)
167
+ document_lines.append(page_lines)
168
+ return document_lines
169
+
170
+
171
+ def clean_pdf_text(pages: List[List[str]], num_words: int = 10) -> List[List[str]]:
172
+ """
173
+ Clean the text of a PDF document by removing frequent words from the beginning and end of each page.
174
+
175
+ Args:
176
+ pages (List[List[str]]): The text of the PDF document, as a list of lists of strings.
177
+ num_words (int, optional): The number of words to consider at the beginning and end of each page. Defaults to 10.
178
+
179
+ Returns:
180
+ List[List[str]]: The cleaned text of the PDF document.
181
+ """
182
+ words = []
183
+ for page in pages:
184
+ first = get_first_last(
185
+ " ".join(page).lower(), num_words=num_words, first_only=True
186
+ )
187
+ words.extend(first.split(" "))
188
+ word_counts = Counter(words)
189
+ common_words = [
190
+ "the",
191
+ "of",
192
+ "a",
193
+ "and",
194
+ "to",
195
+ "in",
196
+ "is",
197
+ "that",
198
+ "for",
199
+ "are",
200
+ "this",
201
+ "we",
202
+ "figure",
203
+ "fig.",
204
+ "",
205
+ ]
206
+ frequent_words = []
207
+ for w, f in word_counts.items():
208
+ if w in common_words or w.startswith("\\"):
209
+ continue
210
+ if f / len(pages) >= 0.4:
211
+ frequent_words.append(w)
212
+ if len(frequent_words) == 0:
213
+ return pages
214
+ # remove frequent words from page beginning/end
215
+ for i in range(len(pages)):
216
+ page = pages[i]
217
+ stop = 0
218
+ page_num_words = 0
219
+ for p in page:
220
+ page_num_words += len(p.split(" "))
221
+ stop += 1
222
+ if page_num_words >= num_words:
223
+ break
224
+ for w in frequent_words:
225
+ for j in range(stop):
226
+ if w == "-": # probably page number - \d -
227
+ pages[i][j] = re.sub(
228
+ r"-\s*\d{1,3}\s*-", "", pages[i][j], flags=re.IGNORECASE
229
+ )
230
+ pages[i][j] = re.sub(re.escape(w), "", pages[i][j], flags=re.IGNORECASE)
231
+ return pages
232
+
233
+
234
+ def split_markdown(
235
+ doc: str,
236
+ pdf_file: str,
237
+ figure_info: Optional[List[Dict]] = None,
238
+ doc_fig: Dict[str, str] = {},
239
+ minlen: int = 3,
240
+ min_num_words: int = 22,
241
+ doc_paragraph_chars: int = 1000,
242
+ min_score: float = 0.75,
243
+ staircase: bool = True,
244
+ ) -> Tuple[List[str], Dict]:
245
+ """
246
+ Split a PDF document into Markdown paragraphs.
247
+
248
+ Args:
249
+ doc (str): The text of the Markdown document.
250
+ pdf_file (str): Path to the PDF file.
251
+ figure_info (Optional[List[Dict]]): A list of dictionaries, where each dictionary
252
+ specifies the information about a figure, such as its caption, page number, and bounding box.
253
+ doc_fig (Dict[str, str]): A dictionary mapping figure ids to LaTeX code.
254
+ minlen (int): The minimum length of a Markdown paragraph.
255
+ min_num_words: The minimum number of words in a Markdown paragraph.
256
+ doc_paragraph_chars: The maximum number of characters in a Markdown paragraph.
257
+ min_score: The minimum score for a Markdown paragraph to be split.
258
+ staircase: Whether to split the document into paragraphs with a staircase pattern.
259
+
260
+ Returns:
261
+ Tuple[List[str], Dict]: The list of Markdown paragraphs and the metadata.
262
+ """
263
+ pdf = pypdf.PdfReader(pdf_file)
264
+ doc_paragraphs_full: List[str] = doc.split("\n")
265
+ doc_paragraph_lengths = [len(p) for p in doc_paragraphs_full if len(p) > 1]
266
+ num_lines = 1 + int(doc_paragraph_chars / np.mean(doc_paragraph_lengths))
267
+ doc_paragraphs_full = [
268
+ unidecode("\n".join(doc_paragraphs_full[i : i + num_lines]))
269
+ for i in range(0, len(doc_paragraphs_full), num_lines)
270
+ ]
271
+ doc_paragraphs: List[str] = []
272
+ doc_paragraph_indices: List[int] = []
273
+ for i, p in enumerate(doc_paragraphs_full):
274
+ if len(p) > 1:
275
+ doc_paragraphs.append(
276
+ re.sub(r"(\[(FOOTNOTE|FIGURE|TABLE).*?END\2\])", "", p)
277
+ )
278
+ doc_paragraph_indices.append(i)
279
+ meta = {"pdffigures": figure_info}
280
+ if len(pdf.pages) > 1:
281
+ pdf_text = get_doc_text(pdf_file, True, True, minlen)
282
+ pdf_content = [
283
+ [unicode_to_latex(q).replace("\n", " ") for q in p if len(q) >= minlen]
284
+ for p in pdf_text
285
+ ]
286
+
287
+ pdf_content = clean_pdf_text(pdf_content)
288
+ if figure_info is not None:
289
+ figure_locations = sorted(
290
+ find_figures(pdf_content, figure_info), key=itemgetter(2), reverse=True
291
+ )
292
+ clean_pdf_content = deepcopy(pdf_content)
293
+ for i, page_content in enumerate(pdf_content):
294
+ len_sentences = np.cumsum([0] + [len(p) for p in page_content])
295
+ for match in figure_locations:
296
+ _, page, start, len_ = match
297
+ if i != page:
298
+ continue
299
+ a, b = (
300
+ get_glob_index(len_sentences, start),
301
+ get_glob_index(len_sentences, start + len_) + 1,
302
+ )
303
+ for j, k in enumerate(range(a, b + 1)):
304
+ if len(clean_pdf_content[i]) == k:
305
+ break
306
+ if j == 0:
307
+ clean_pdf_content[i][k] = clean_pdf_content[i][k][
308
+ : start - len_sentences[k]
309
+ ]
310
+ elif k == b:
311
+ clean_pdf_content[i][k] = clean_pdf_content[i][k][
312
+ start + len_ - len_sentences[k] :
313
+ ]
314
+ else:
315
+ clean_pdf_content[i][k] = ""
316
+ clean_pdf_content[i] = remove_short_seqs(clean_pdf_content[i], 0)
317
+ pdf_content = clean_pdf_content
318
+ paragraphs = flatten(pdf_content)
319
+ num_paragraphs = np.cumsum([0] + [len(page) for page in pdf_content])
320
+ if staircase:
321
+ # train bag of words
322
+ page_target = np.zeros(len(paragraphs))
323
+ page_target[num_paragraphs[1:-1] - 1] = 1
324
+ page_target = np.cumsum(page_target).astype(int)
325
+ model = BagOfWords(paragraphs, target=page_target)
326
+ labels = model(doc_paragraphs)
327
+
328
+ # fit stair case function
329
+ x = np.arange(len(labels))
330
+ stairs = Staircase(len(labels), labels.max() + 1)
331
+ stairs.fit(x, labels)
332
+ boundaries = (stairs.get_boundaries().astype(int)).tolist()
333
+ boundaries.insert(0, 0)
334
+ else:
335
+ boundaries = [0] * (len(pdf.pages))
336
+ splitter = Splitter(doc_paragraphs)
337
+ pages = [(0, 0, 1.0)]
338
+ meta["first_words"] = []
339
+ meta["last_words"] = []
340
+ for i in range(1, len(boundaries)):
341
+ delta = (
342
+ math.ceil(stairs.uncertainty[i - 1]) + 5
343
+ if staircase
344
+ else len(doc_paragraphs)
345
+ )
346
+ words_f = []
347
+ words_l = []
348
+ for p in pdf_content[i]:
349
+ words_f.extend(p.split(" "))
350
+ if len(words_f) >= min_num_words:
351
+ break
352
+ for p in pdf_content[i - 1][::-1]:
353
+ words_l.extend(p.split(" ")[::-1])
354
+ if len(words_l) >= min_num_words:
355
+ words_l = words_l[::-1]
356
+ break
357
+ if len(words_f) < 2:
358
+ pages.append(pages[-1])
359
+ first_words = " ".join(words_f[:min_num_words]).strip()
360
+ last_words = " ".join(words_l[-min_num_words:]).strip()
361
+ meta["first_words"].append(first_words)
362
+ meta["last_words"].append(last_words)
363
+ if len(first_words) < minlen and len(last_words) < minlen:
364
+ pages.append(pages[-1])
365
+ continue
366
+ pages.append(
367
+ splitter.split_first_last(
368
+ boundaries[i],
369
+ first_words,
370
+ last_words,
371
+ delta=delta,
372
+ )
373
+ )
374
+ elif len(pdf.pages) == 1: # single page
375
+ pages = [(0, 0, 1)]
376
+ else:
377
+ return
378
+ pages.append((len(doc_paragraphs), -1, 1.0))
379
+ out = []
380
+ page_scores = {}
381
+ for i in range(len(pages) - 1):
382
+ score = (pages[i][2] + pages[i + 1][2]) * 0.5
383
+ if score >= min_score:
384
+ end = pages[i + 1][0]
385
+ if end >= len(doc_paragraph_indices):
386
+ end = None
387
+ else:
388
+ end = doc_paragraph_indices[pages[i + 1][0]] + 1
389
+ lines = doc_paragraphs_full[doc_paragraph_indices[pages[i][0]] : end]
390
+ if len(lines) > 0:
391
+ lines[0] = lines[0][pages[i][1] :]
392
+ lines[-1] = lines[-1][: pages[i + 1][1]]
393
+ else:
394
+ lines = []
395
+ page_content = "\n".join(lines)
396
+ page_content = remove_pretty_linebreaks(page_content)
397
+ page_scores[i] = score
398
+ out.append(page_content)
399
+
400
+ meta["page_splits"] = pages
401
+ meta["page_scores"] = page_scores
402
+ meta["num_pages"] = len(pdf.pages)
403
+
404
+ # Reintroduce figures, tables and footnotes
405
+ figure_tex = list(doc_fig.keys()), list(doc_fig.values())
406
+ if len(doc_fig) > 0:
407
+ iterator = figure_info.values() if type(figure_info) == dict else [figure_info]
408
+ for figure_list in iterator:
409
+ if not figure_list:
410
+ continue
411
+ for i, f in enumerate(figure_list):
412
+ if "caption" in f:
413
+ fig_string = f["caption"]
414
+ elif "text" in f:
415
+ fig_string = f["text"]
416
+ else:
417
+ continue
418
+ ratios = []
419
+ for tex in figure_tex[1]:
420
+ if f["figType"] == "Table":
421
+ tex = tex.partition(r"\end{table}")[2]
422
+ ratios.append(Levenshtein.ratio(tex, fig_string))
423
+ k = np.argmax(ratios)
424
+ if ratios[k] < 0.8:
425
+ continue
426
+ if f["page"] < len(out) and out[f["page"]] != "":
427
+ out[f["page"]] += "\n\n" + remove_pretty_linebreaks(
428
+ figure_tex[1][k].strip()
429
+ )
430
+
431
+ for i in range(len(out)):
432
+ foot_match = re.findall(r"\[FOOTNOTE(.*?)\]\[ENDFOOTNOTE\]", out[i])
433
+ for match in foot_match:
434
+ out[i] = out[i].replace(
435
+ "[FOOTNOTE%s][ENDFOOTNOTE]" % match,
436
+ doc_fig.get("FOOTNOTE%s" % match, ""),
437
+ )
438
+
439
+ out[i] = re.sub(r"\[(FIGURE|TABLE)(.*?)\](.*?)\[END\1\]", "", out[i])
440
+ return out, meta
441
+
442
+
443
+ if __name__ == "__main__":
444
+ parser = argparse.ArgumentParser()
445
+ parser.add_argument("--md", type=str, help="Markdown file", required=True)
446
+ parser.add_argument("--pdf", type=str, help="PDF File", required=True)
447
+ parser.add_argument("--out", type=str, help="Out dir", required=True)
448
+ parser.add_argument(
449
+ "--figure",
450
+ type=str,
451
+ help="Figure info JSON",
452
+ )
453
+ parser.add_argument("--dpi", type=int, default=96)
454
+ args = parser.parse_args()
455
+ md = open(args.md, "r", encoding="utf-8").read().replace("\xa0", " ")
456
+ pdf = pypdf.PdfReader(args.pdf)
457
+ try:
458
+ fig_info = json.load(open(args.figure, "r", encoding="utf-8"))
459
+ except FileNotFoundError:
460
+ fig_info = None
461
+ pages, meta = split_markdown(md, pdf, fig_info)
462
+ if args.out:
463
+ outpath = os.path.join(args.out, os.path.basename(args.pdf).partition(".")[0])
464
+ os.makedirs(outpath, exist_ok=True)
465
+ found_pages = []
466
+ for i, content in enumerate(pages):
467
+ if content:
468
+ with open(
469
+ os.path.join(
470
+ outpath, "%02d_s=%.2f.mmd" % (i + 1, meta["page_scores"][i])
471
+ ),
472
+ "w",
473
+ encoding="utf-8",
474
+ ) as f:
475
+ f.write(content)
476
+ found_pages.append(i)
477
+ rasterize_paper(pdf, outpath, dpi=args.dpi, pages=found_pages)
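The splitting strategy in `split_markdown` is: fit a TF-IDF bag-of-words classifier (`BagOfWords` above) on the PDF's own text-layer paragraphs, using each paragraph's page index as its label; predict a page label for every markdown paragraph; then fit a monotonic `Staircase` to those noisy labels to locate the page boundaries. A toy illustration of just the classifier half (the sentences and labels below are invented):

```python
# Toy sketch of the BagOfWords page classifier; not real training data.
import numpy as np
from nougat.dataset.split_md_to_pages import BagOfWords

pdf_paragraphs = [
    "we introduce a neural model for converting documents to markup",  # page 0
    "the encoder operates directly on the rendered page image",        # page 0
    "we evaluate the model on a collection of arxiv papers",           # page 1
    "quantitative results and ablations are reported below",           # page 1
]
page_target = np.array([0, 0, 1, 1])

model = BagOfWords(pdf_paragraphs, target=page_target)
# For a markdown paragraph, the predicted label is the page it most plausibly came from.
print(model(["evaluation results on arxiv papers"]))
```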
nougat/dataset/splitter.py ADDED
@@ -0,0 +1,393 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ from difflib import SequenceMatcher
8
+ from operator import itemgetter
9
+ from typing import List, Tuple, Union
10
+ import re
11
+ import numpy as np
12
+ from Levenshtein.StringMatcher import StringMatcher
13
+ import Levenshtein
14
+ from fuzzysearch import find_near_matches
15
+
16
+ math_start_regex = re.compile(r"(?<!\\)\\[\[\(]", re.M)
17
+ math_end_regex = re.compile(r"(?<!\\)\\[\]\)]", re.M)
18
+
19
+
20
+ def reverse(lst: List[str]) -> List[str]:
21
+ """Reverses a list and the strings inside
22
+
23
+ Args:
24
+ lst (List[str]): List to process
25
+
26
+ Returns:
27
+ List[str]: Reversed list
28
+ """
29
+ out = lst[::-1]
30
+ for i in range(len(out)):
31
+ out[i] = out[i][::-1]
32
+ return out
33
+
34
+
35
+ def get_first_last(
36
+ s: str,
37
+ num_words: int = 8,
38
+ delim: str = " ",
39
+ first_only: bool = False,
40
+ last_only: bool = False,
41
+ ) -> Union[Tuple[str, str], str]:
42
+ """
43
+ Get the first and last `num_words` from a string `s`.
44
+
45
+ Args:
46
+ s (str): The string.
47
+ num_words (int): The number of words.
48
+ delim (str): The delimiter between words.
49
+ first_only (bool): Whether to only get the first `num_words`.
50
+ last_only (bool): Whether to only get the last `num_words`.
51
+
52
+ Returns:
53
+ Union[Tuple[str, str], str]: The first and last `num_words` from `s`, or `s` if `num_words` is 0.
54
+ """
55
+ s = s.split(delim)
56
+ if not first_only and not last_only:
57
+ return delim.join(s[:num_words]), delim.join(s[-num_words:])
58
+ elif first_only:
59
+ return delim.join(s[:num_words])
60
+ elif last_only:
61
+ return delim.join(s[-num_words:])
62
+
63
+
64
+ def get_glob_index(
65
+ lengths: List[int], ind: int, return_breakpoints: bool = False
66
+ ) -> int:
67
+ """returns the index where ind is closest and greater than the lengths"""
68
+ breakpoints = np.cumsum(lengths)
69
+ overlap = breakpoints - ind
70
+ overlap[overlap > 0] = -int(1e5)
71
+ indices = overlap.argmax(0)
72
+ if return_breakpoints:
73
+ return indices, breakpoints
74
+ else:
75
+ return indices
76
+
77
+
78
+ # table-header-figure regex
79
+ # thf_regex = re.compile(r"(\[(FOOTNOTE|FIGURE|TABLE).*?END\2\])")
80
+
81
+
82
+ class Splitter:
83
+ _split_locs: List[Tuple[int, int]] = None
84
+
85
+ def __init__(self, paragraphs: List[str]) -> None:
86
+ self.paragraphs = paragraphs
87
+ self.paragraphs_no_space = [self.remove_special_chars(h) for h in paragraphs]
88
+ self._split_locs = [(0, 0)]
89
+ self.paragraphs_rev = reverse(self.paragraphs)
90
+ self.paragraphs_rev_no_space = reverse(self.paragraphs_no_space)
91
+
92
+ @staticmethod
93
+ def remove_special_chars(string: str) -> str:
94
+ # string = thf_regex.sub(r"", string)
95
+ return (
96
+ string.replace("\\ ", "")
97
+ .replace(" ", "")
98
+ .replace("\n", "")
99
+ .replace("*", "")
100
+ .replace("_", "")
101
+ .replace("^", "")
102
+ .replace("\\[", "")
103
+ .replace("\\]", "")
104
+ .replace("\\(", "")
105
+ .replace("\\)", "")
106
+ .replace("\\right", "")
107
+ .replace("\\left", "")
108
+ .replace("\\sum", "X") # old latex unicode encoding issue
109
+ .replace("{", "")
110
+ .replace("}", "")
111
+ .replace("#", "")
112
+ .replace("[REF]", "")
113
+ .replace("[ENDREF]", "")
114
+ .replace("\\varphi", "\\phi") # https://meta.stackexchange.com/a/349360
115
+ .replace("\\quad", "")
116
+ .replace("\\qquad", "")
117
+ .replace("\\hskip", "")
118
+ .replace("\\vskip", "")
119
+ .replace("\\frac", "")
120
+ .replace("\\rm", "")
121
+ .replace("\\,", "")
122
+ .replace("-", "")
123
+ .lower()
124
+ )
125
+
126
+ @staticmethod
127
+ def count_special_chars(string: str, char_ind: int) -> int:
128
+ if len(string) == 0:
129
+ return 0
130
+ add_space_ind = 0
131
+ while True:
132
+ string_ = string[: char_ind + add_space_ind]
133
+ # last_first = string[: char_ind + add_space_ind+]
134
+ add = (
135
+ string_.count(" ")
136
+ + string_.count("\\ ") * 2
137
+ + string_.count("\n")
138
+ + string_.count("*")
139
+ + string_.count("_")
140
+ + string_.count("^")
141
+ + string_.count("\\[") * 2
142
+ + string_.count("\\]") * 2
143
+ + string_.count("\\(") * 2
144
+ + string_.count("\\)") * 2
145
+ + string_.count("\\right") * 6
146
+ + string_.count("\\left") * 5
147
+ + string_.count("\\sum") * 3 # replaced to X that's why not 4
148
+ + string_.count("{")
149
+ + string_.count("}")
150
+ + string_.count("#")
151
+ + string_.count("[REF]") * 5
152
+ + string_.count("[ENDREF]") * 8
153
+ + string_.count("\\varphi") * 3
154
+ + string_.count("\\quad") * 5
155
+ + string_.count("\\qquad") * 6
156
+ + string_.count("\\hskip") * 6
157
+ + string_.count("\\vskip") * 6
158
+ + string_.count("\\frac") * 5
159
+ + string_.count("\\rm") * 3
160
+ + string_.count("\\,") * 2
161
+ + string_.count("-")
162
+ )
163
+ if add == add_space_ind:
164
+ break
165
+ add_space_ind = add
166
+ if len(string) <= char_ind + add_space_ind:
167
+ add_space_ind = max(0, len(string) - 1 - char_ind)
168
+
169
+ # check first chars of rest if they match closing expressions
170
+ while True:
171
+ rest = string[char_ind + add_space_ind :]
172
+ string_ = string[: char_ind + add_space_ind]
173
+ section_title = re.match(r"#+\s?\d*\s*$", string_)
174
+ if rest.startswith("\\]") or rest.startswith("\\)"):
175
+ add_space_ind += 2
176
+ elif (rest.startswith(")") or rest.startswith("]")) and string_.endswith(
177
+ "\\"
178
+ ):
179
+ add_space_ind += 1
180
+ elif (rest.startswith("(") or rest.startswith("[")) and string_.endswith(
181
+ "\\"
182
+ ):
183
+ add_space_ind -= 1
184
+ elif rest.startswith(" "):
185
+ add_space_ind += 1
186
+ elif section_title:
187
+ add_space_ind -= section_title.end() - section_title.start()
188
+ elif (
189
+ re.match(r"^[^\w\s]*_\s", rest)
190
+ or re.match(r"^[^\w\s]*\*\*?\s", rest)
191
+ or re.match(r"^.\n", rest)
192
+ ):
193
+ add_space_ind += 1
194
+ else:
195
+ break
196
+ # check if it starts in a math env and include everything before
197
+ end = math_end_regex.search(rest)
198
+ if end is not None:
199
+ start = math_start_regex.search(rest)
200
+ if start is None or start.start() > end.start():
201
+ inds = [
202
+ m.start()
203
+ for m in math_start_regex.finditer(string_)
204
+ if m.start() < end.start() + len(string_)
205
+ ]
206
+ if len(inds) > 0:
207
+ add_space_ind = inds[-1] - char_ind
208
+ # assert string_[char_ind+add_space_ind]=='\\'
209
+ return add_space_ind
210
+
211
+ def split_first_last(
212
+ self, index: int, first: str, last: str, delta: int = 5
213
+ ) -> Tuple[int, int, float]:
214
+ """Refines a split by looking at both the first words from a new page and the last words from the previous page.
215
+
216
+ Args:
217
+ index (int): paragraph index
218
+ first (str): first words
219
+ last (str): last words
220
+ delta (int, optional): paragraph search radius. Defaults to 5.
221
+
222
+ Returns:
223
+ Tuple[int, int, float]: split prediction
224
+ """
225
+ if first:
226
+ first_split = glob_f, char_f, score_f = self.split(
227
+ index, first, delta=delta
228
+ )
229
+ if last:
230
+ last_split = glob_l, char_l, score_l = self.split(
231
+ index, last, delta=delta, reverse=True
232
+ )
233
+ if first and not last:
234
+ return first_split
235
+ elif not first and last:
236
+ return last_split
237
+ elif not first and not last:
238
+ return index, 0, 0.0
239
+ if char_f == char_l and glob_f == glob_l and (score_f > 0.5 or score_l > 0.5):
240
+ return glob_l, char_l, 1.0
241
+
242
+ # score calculation
243
+ first, last = self.remove_special_chars(first), self.remove_special_chars(last)
244
+ matching = []
245
+ for split in (first_split, last_split):
246
+ first_source = []
247
+ num_chars_first = len(first)
248
+ num_chars_last = len(last)
249
+ last_source = []
250
+ for i, p in enumerate(self.paragraphs[split[0] :]):
251
+ if i == 0:
252
+ p = p[split[1] :]
253
+ first_source.append(self.remove_special_chars(p))
254
+ if sum([len(s) for s in first_source]) >= num_chars_first:
255
+ break
256
+ first_source = "".join(first_source)[:num_chars_first]
257
+ for i, p in enumerate(self.paragraphs[split[0] :: -1]):
258
+ if i == 0:
259
+ p = p[: split[1]]
260
+ last_source.insert(0, self.remove_special_chars(p))
261
+ if sum([len(s) for s in last_source]) >= num_chars_last:
262
+ last_source = last_source
263
+ break
264
+ last_source = "".join(last_source)[-num_chars_last:]
265
+ matching.append(
266
+ [
267
+ Levenshtein.ratio(first, first_source)
268
+ * Levenshtein.ratio(first[:10], first_source[:10]),
269
+ Levenshtein.ratio(last, last_source)
270
+ * Levenshtein.ratio(last[-10:], last_source[-10:]),
271
+ ]
272
+ )
273
+ scores = np.asarray(matching).max(0)
274
+ return (
275
+ (glob_l, char_l, scores[1])
276
+ if scores.argmax()
277
+ else (glob_f, char_f, scores[0])
278
+ )
279
+
280
+ def split(
281
+ self, index: int, string: str, delta: int = 5, reverse: bool = False
282
+ ) -> Tuple[int, int, float]:
283
+ """
284
+ Refine a split prediction. `string` holds the first words of the new page.
285
+ `delta` is the paragraph search radius and doubles as an uncertainty measure.
286
+ Returns the refined paragraph index, the character split position and a score.
287
+ """
288
+ if reverse:
289
+ index = len(self.paragraphs) - 1 - index
290
+ string = string[::-1]
291
+ paragraphs = self.paragraphs_rev
292
+ paragraphs_no_space = self.paragraphs_rev_no_space
293
+ else:
294
+ paragraphs = self.paragraphs
295
+ paragraphs_no_space = self.paragraphs_no_space
296
+
297
+ string_ = self.remove_special_chars(string)
298
+ start_ind = max(0, index - delta)
299
+ search_corpus = paragraphs_no_space[start_ind : index + delta + 1]
300
+ lengths = np.asarray([0] + [len(p) for p in search_corpus])
301
+ corp = "".join(search_corpus)
302
+ if len(corp) == 0:
303
+ self._split_locs.append((index, 0))
304
+ return index, 0, 1
305
+ ind, score = self._find_match(corp, string_)
306
+ indices, breakpoints = get_glob_index(lengths, ind, True)
307
+ global_ind, char_ind = int(start_ind + indices), int(ind - breakpoints[indices])
308
+ self._split_locs.append((global_ind, char_ind))
309
+ if reverse:
310
+ char_ind = len(paragraphs_no_space[global_ind]) - char_ind
311
+ global_ind = len(paragraphs) - global_ind - 1
312
+ add_space_ind = self.count_special_chars(self.paragraphs[global_ind], char_ind)
313
+ return global_ind, char_ind + add_space_ind, score
314
+
315
+ def _find_match(
316
+ self, corp: str, key: str, get_start: bool = True
317
+ ) -> Tuple[int, float]:
318
+ block, score = self._fuzzy(corp, key)
319
+ index = max(0, block[0])
320
+ if not get_start:
321
+ index += block[2]
322
+ return index, score
323
+
324
+ @staticmethod
325
+ def _fuzzy(
326
+ corpus: str, string: str, max_error_rate: float = 0.025
327
+ ) -> Tuple[Tuple[int, int, int], float]:
328
+ max_dist = min(len(string) - 1, int(len(string) * min(0.9, max_error_rate)) + 5)
329
+ matches = find_near_matches(string, corpus, max_l_dist=max_dist)
330
+ if len(matches) > 0 and max_dist > 0:
331
+ match = min(matches, key=lambda x: x.dist)
332
+ block = (match.start, 0, match.end - match.start)
333
+ score = 1 - match.dist / max_dist
334
+ return block, score
335
+ return (0, 0, 0), 0
336
+
337
+ @staticmethod
338
+ def fuzzysearch(
339
+ corpus: str, string: str, max_error_rate: float = 0.025
340
+ ) -> Tuple[Tuple[int, int, int], float]:
341
+ corpus_ = Splitter.remove_special_chars(corpus)
342
+ string_ = Splitter.remove_special_chars(string)
343
+ (start, _, dist), score = Splitter._fuzzy(
344
+ corpus_, string_, max_error_rate=max_error_rate
345
+ )
346
+ end = Splitter.count_special_chars(corpus, start + dist) + start + dist
347
+ start = start + Splitter.count_special_chars(corpus, start)
348
+ return (start, _, end - start), score
349
+
350
+ @staticmethod
351
+ def oldfuzz(corpus, string):
352
+ res = []
353
+ for Matcher in [StringMatcher, SequenceMatcher]:
354
+ m = Matcher(None, corpus, string)
355
+ blocks = m.get_matching_blocks()
356
+ scores = []
357
+ for i, block in enumerate(blocks):
358
+ m2 = Matcher(
359
+ None,
360
+ corpus[block[0] : block[0] + max(block[2], len(string))],
361
+ string,
362
+ )
363
+ r = m2.ratio()
364
+ if r > 0.995:
365
+ return blocks[i], r
366
+ else:
367
+ scores.append(r)
368
+ ind = np.argmax(scores)
369
+ res.append((blocks[ind], scores[ind]))
370
+ return max(res, key=itemgetter(1))
371
+
372
+ def evaluate_split(self, page_num: int, page_content: str) -> float:
373
+ if page_num > len(self._split_locs) or page_num < 1:
374
+ return 0
375
+ page_content = self.remove_special_chars(page_content)
376
+ if page_num == len(self._split_locs):
377
+ start, end = self._split_locs[-1], (-1, -1)
378
+ else:
379
+ start, end = self._split_locs[page_num - 1], self._split_locs[page_num]
380
+ if (end[0] + 1) - start[0] < 0:
381
+ return 0
382
+ doc_content = self.paragraphs_no_space[start[0] : (end[0] + 1) or None]
383
+ if (
384
+ len(doc_content) < 1
385
+ or len(doc_content[0]) < start[1]
386
+ or len(doc_content[-1]) < end[1]
387
+ ):
388
+ return 0
389
+ doc_content[0] = doc_content[0][start[1] :]
390
+ doc_content[-1] = doc_content[-1][: end[1]]
391
+ doc_content = "".join(doc_content)
392
+ match = StringMatcher(None, page_content, doc_content).ratio()
393
+ return match
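A usage sketch for the class above (assuming this repo and its `fuzzysearch`/`Levenshtein` dependencies are installed; the strings are invented for illustration):

from nougat.dataset.splitter import Splitter

corpus = "We evaluate the model on held-out documents from arXiv."
query = "evaluate the model on held out documents"

# normalization used for matching: spaces, markup and hyphens are stripped
print(Splitter.remove_special_chars(query))   # 'evaluatethemodelonheldoutdocuments'

# fuzzy-locate the query inside the corpus; returns a (start, _, length) block and a score
(start, _, length), score = Splitter.fuzzysearch(corpus, query)
print(corpus[start:start + length], round(score, 2))
# should print roughly 'evaluate the model on held-out documents' and a score near 1.0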
nougat/dataset/staircase.py ADDED
@@ -0,0 +1,314 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ from collections import deque
8
+ import operator
9
+ import itertools
10
+ from typing import Optional, List, Tuple
11
+ import numpy as np
12
+ import warnings
13
+
14
+ warnings.filterwarnings("ignore", message="All-NaN slice encountered")
15
+
16
+
17
+ def stair_func(x: np.ndarray, thresholds: np.ndarray) -> np.ndarray:
18
+ return np.heaviside(x[:, None] - np.floor(thresholds)[None, :], 0).sum(1)
19
+
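A minimal, self-contained check of the step function above (the values are chosen for illustration):

import numpy as np

def stair_func(x, thresholds):
    # count how many (floored) thresholds lie at or below each x -> class index
    return np.heaviside(x[:, None] - np.floor(thresholds)[None, :], 0).sum(1)

x = np.array([0.0, 3.0, 6.0])
thresholds = np.array([2.5, 5.5])   # two boundaries -> three classes (pages)
print(stair_func(x, thresholds))    # [0. 1. 2.]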
20
+
21
+ def compute_gini(labels: np.ndarray) -> float:
22
+ N = len(labels)
23
+ if N == 0:
24
+ return 0
25
+ G = N - np.square(np.bincount(labels)).sum() / N
26
+ return G
27
+
28
+
29
+ def compute_binary_gini(labels: np.ndarray) -> float:
30
+ N = len(labels)
31
+ if N == 0:
32
+ return 0
33
+ G = N - labels.sum() ** 2 / N
34
+ return G
35
+
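A quick sanity check of the (unnormalized) Gini impurity used above: a pure label set scores 0, a maximally mixed one scores highest.

import numpy as np

def compute_gini(labels):
    N = len(labels)
    if N == 0:
        return 0
    return N - np.square(np.bincount(labels)).sum() / N

print(compute_gini(np.array([0, 0, 0, 0])))   # 0.0
print(compute_gini(np.array([0, 0, 1, 1])))   # 2.0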
36
+
37
+ def gini_impurity(
38
+ thresholds: np.ndarray,
39
+ data: np.ndarray,
40
+ labels: np.ndarray,
41
+ classes: Optional[List[int]] = None,
42
+ reduction: Optional[str] = "sum",
43
+ padded: bool = True,
44
+ ) -> float:
45
+ """
46
+ Calculate the Gini impurity of a dataset split on a set of thresholds.
47
+
48
+ Args:
49
+ thresholds (np.ndarray): The thresholds to split the data on.
50
+ data (np.ndarray): The data to split.
51
+ labels (np.ndarray): The labels for the data.
52
+ classes (Optional[List[int]]): The classes to consider. If None, all classes are used.
53
+ reduction (Optional[str]): The reduction to apply to the impurity. One of "none", "sum", or "mean".
54
+ padded (bool): Whether the thresholds already include the outer boundaries `[-0.5, data.max() + 0.5]`; if False, they are inserted first.
55
+
56
+ Returns:
57
+ float: The Gini impurity.
58
+ """
59
+ G = []
60
+ if not padded:
61
+ thresholds = np.insert(
62
+ thresholds, [0, len(thresholds)], [-0.5, data.max() + 0.5]
63
+ )
64
+ if classes is None:
65
+ classes = np.arange(len(thresholds) - 1)
66
+ else:
67
+ classes = np.asarray(classes)
68
+ if data.ndim == 1:
69
+ data = np.expand_dims(data, 0)
70
+ masks = np.logical_and(
71
+ data > thresholds[classes, None],
72
+ data <= thresholds[classes + 1, None],
73
+ )
74
+ for i, c in enumerate(classes):
75
+ G.append(compute_binary_gini(np.where(labels[masks[i]] == c, 1, 0)))
76
+
77
+ if reduction is None or reduction == "none":
78
+ return G
79
+ elif reduction == "sum":
80
+ return sum(G)
81
+ elif reduction == "mean":
82
+ return sum(G) / len(G)
83
+ else:
84
+ raise NotImplementedError
85
+
86
+
87
+ def step_impurity(
88
+ thresholds,
89
+ data: np.ndarray,
90
+ labels: np.ndarray,
91
+ classes: Optional[List[int]] = None,
92
+ ) -> float:
93
+ """
94
+ Calculate the step-wise Gini impurity of a dataset split on a set of thresholds.
95
+
96
+ Args:
97
+ thresholds (np.ndarray): The thresholds to split the data on.
98
+ data (np.ndarray): The data to split.
99
+ labels (np.ndarray): The labels for the data.
100
+ classes (Optional[List[int]]): The classes to consider. If None, all classes are used.
101
+
102
+ Returns:
103
+ float: The step-wise Gini impurity.
104
+ """
105
+ G = gini_impurity(thresholds, data, labels, reduction=None, classes=classes)
106
+ out = []
107
+ for i in range(len(G) - 1):
108
+ out.append(G[i] + G[i + 1])
109
+ return out
110
+
111
+
112
+ class PaddedArray:
113
+ """
114
+ A wrapper class for an array that allows for relative indexing.
115
+
116
+ Args:
117
+ array (np.ndarray): The array to wrap.
118
+ range (Optional[Tuple[int, int]]): The range of the array to expose. Defaults to (1, -1).
119
+ """
120
+
121
+ def __init__(
122
+ self, array: np.ndarray, range: Optional[Tuple[int, int]] = (1, -1)
123
+ ) -> None:
124
+ self.array = array
125
+ mi, ma = range
126
+ assert ma <= 0, "relative assignment only"
127
+ self.range = mi, ma
128
+
129
+ def __len__(self):
130
+ return len(self.array) + self.range[1] - self.range[0]
131
+
132
+ def _process_index(self, index):
133
+ if isinstance(index, slice):
134
+ index = slice(
135
+ (index.start or 0) + self.range[0],
136
+ self.range[0] + (len(self) if index.stop is None else index.stop),
137
+ index.step,
138
+ )
139
+ if index.stop > len(self.array):
140
+ raise IndexError
141
+ else:
142
+ index = index + self.range[0]
143
+ if index > len(self):
144
+ raise IndexError
145
+ return index
146
+
147
+ def __getitem__(self, index):
148
+ index = self._process_index(index)
149
+ return self.array[index]
150
+
151
+ def __setitem__(self, index, value):
152
+ self.array[self._process_index(index)] = value
153
+
154
+ def copy(self):
155
+ return PaddedArray(self.array.copy(), self.range)
156
+
157
+ def toarray(self):
158
+ return self.array[self.range[0] : self.range[1]]
159
+
160
+
161
+ class Staircase:
162
+ """
163
+ A class for learning a staircase decision tree.
164
+
165
+ Args:
166
+ domain: The number of points in the domain.
167
+ n_classes: The number of classes.
168
+ """
169
+
170
+ def __init__(self, domain: int, n_classes: int) -> None:
171
+ self.domain = domain
172
+ self.classes = n_classes
173
+ assert domain > 0
174
+ assert n_classes > 0
175
+ self.thresholds = self._back_thres = self._forward_thres = np.linspace(
176
+ domain / n_classes, domain, n_classes - 1, endpoint=False
177
+ )
178
+ self.uncertainty = np.zeros_like(self.thresholds)
179
+
180
+ def statistic_fit(
181
+ self,
182
+ data: np.ndarray,
183
+ labels: np.ndarray,
184
+ ):
185
+ """
186
+ Fit statistical thresholds for anomaly detection.
187
+
188
+ This method fits statistical thresholds for anomaly detection based on input data and labels.
189
+
190
+ Args:
191
+ data (np.ndarray): The input data.
192
+ labels (np.ndarray): The labels corresponding to the data.
193
+
194
+ Note:
195
+ This method modifies the internal state of the object to set statistical thresholds.
196
+ """
197
+ onehot = np.eye(self.classes)[labels.reshape(-1)]
198
+ onehot.reshape(list(labels.shape) + [self.classes])
199
+ k = onehot * data.T.repeat(self.classes, 1)
200
+ k[k == 0] = np.nan
201
+ med = np.nanmedian(k, 0)
202
+ for i in range(len(med)):
203
+ if med[i] != med[i]:
204
+ med[i] = 0 if i == 0 else med[i - 1]
205
+ mad = 5 * np.nan_to_num(
206
+ np.nanmedian(np.absolute(k - np.nanmedian(k, 0)), 0),
207
+ nan=self.domain / self.classes / 2,
208
+ )
209
+ arr = np.vstack(((med - mad)[:-1], (med + mad)[1:]))
210
+ self._forward_thres[:] = arr.max(0)
211
+ self._back_thres[:] = arr.min(0)
212
+
213
+ self._stat_forward = self._forward_thres.copy()
214
+ self._stat_back = self._back_thres.copy()
215
+
216
+ def fit(
217
+ self,
218
+ data: np.ndarray,
219
+ labels: np.ndarray,
220
+ early_stop_after: int = 10,
221
+ fixed: bool = True,
222
+ ) -> None:
223
+ """
224
+ Fit the staircase thresholds to the data.
225
+
226
+ This method refines the class boundaries by minimizing the step-wise Gini impurity of the labels along the data axis, starting from the statistical estimate.
227
+
228
+ Args:
229
+ data (np.ndarray): The input data.
230
+ labels (np.ndarray): The labels corresponding to the data.
231
+ early_stop_after (int, optional): The number of consecutive early stops to consider. Default is 10.
232
+ fixed (bool, optional): Whether to use fixed thresholds. Default is True.
233
+
234
+ Note:
235
+ This method modifies the internal state of the object to set statistical thresholds.
236
+ """
237
+ assert data.ndim == 1
238
+ assert labels.ndim <= 2
239
+ if self.classes == 1:
240
+ self.thresholds = np.array([0.5 + data.max()])
241
+ self.uncertainty = np.zeros_like(self.thresholds)
242
+ if data.ndim == 1:
243
+ data = np.expand_dims(data, 0)
244
+ thresholds = PaddedArray(
245
+ np.insert(
246
+ np.arange(self.domain - self.classes + 1, self.domain) - 1,
247
+ [0, self.classes - 1],
248
+ [-0.5, self.domain + 0.5],
249
+ ).astype(int)
250
+ )
251
+ self._back_thres = thresholds.copy()
252
+ self._forward_thres = thresholds.copy()
253
+ self.statistic_fit(data, labels)
254
+ last = -0.5
255
+ for n in range(self.classes):
256
+ G = np.inf
257
+ Gis = deque([], early_stop_after)
258
+ # forward pass
259
+ if n < self.classes - 1:
260
+ new_forward_n: float = self._forward_thres[n]
261
+ for i in range(
262
+ max(0, self._back_thres[n - 1]) if n - 1 >= 0 else int(last),
263
+ min(self.domain, self._forward_thres[n + 1])
264
+ if n + 2 < self.classes
265
+ else self.domain - 1,
266
+ ):
267
+ thresholds.array[n + 1] = i + 0.5
268
+ Gi = step_impurity(
269
+ thresholds.array, data, labels, classes=[n, n + 1]
270
+ )[0]
271
+ Gis.append(Gi)
272
+ if Gi <= G:
273
+ last = i + 0.5
274
+ new_forward_n = last
275
+ G = Gi
276
+ elif (
277
+ (not fixed or i - last > self.domain / self.classes)
278
+ and len(Gis) == early_stop_after
279
+ and all(
280
+ itertools.starmap(
281
+ operator.ge,
282
+ zip(Gis, itertools.islice(Gis, 1, early_stop_after)),
283
+ )
284
+ )
285
+ ):
286
+ break
287
+ thresholds.array[n + 1] = new_forward_n
288
+ self._forward_thres.array[n + 1] = new_forward_n
289
+ self._back_thres.array[n + 1] = new_forward_n
290
+ G = np.inf
291
+ self._forward_thres = self._forward_thres.toarray().clip(
292
+ min=0, max=self.domain - 1
293
+ )
294
+ self._back_thres = self._back_thres.toarray().clip(min=0, max=self.domain - 1)
295
+ self.thresholds = (self._forward_thres + self._back_thres) / 2
296
+ self.uncertainty = np.abs(self._forward_thres - self._back_thres) / 2
297
+
298
+ @property
299
+ def score(self):
300
+ try:
301
+ return gini_impurity(self.thresholds, self._data, self._labels) / len(
302
+ self._data
303
+ )
304
+ except AttributeError:
305
+ return np.inf
306
+
307
+ def predict(self, x: np.ndarray) -> np.ndarray:
308
+ return stair_func(x, self.get_boundaries())
309
+
310
+ def __call__(self, *args):
311
+ return self.predict(*args)
312
+
313
+ def get_boundaries(self) -> np.ndarray:
314
+ return self.thresholds.astype(int).clip(min=0, max=self.domain - 1) + 0.5
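A rough end-to-end sketch of how the fitter above might be driven with synthetic, monotone page labels (assuming this repo and its dependencies are installed; the numbers are illustrative and the exact fitted boundaries depend on the data):

import numpy as np
from nougat.dataset.staircase import Staircase

# 10 text positions belonging to 3 consecutive pages
data = np.arange(10)
labels = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2])

sc = Staircase(domain=10, n_classes=3)
sc.fit(data, labels)
print(sc.get_boundaries())               # fitted page boundaries, expected near the label changes
print(sc.predict(np.array([1, 4, 8])))   # predicted page index per position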
nougat/dataset/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
nougat/dataset/utils/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ from nougat.dataset.utils.latex_conversion import *
8
+ from nougat.dataset.utils.utils import *
nougat/dataset/utils/latex_conversion.py ADDED
@@ -0,0 +1,146 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import re
8
+ from pylatexenc.latexencode import UnicodeToLatexEncoder
9
+ from pylatexenc.latex2text import LatexNodes2Text
10
+ from unidecode import unidecode
11
+
12
+ syn = [
13
+ ("\\rbrack ", "] "),
14
+ ("\\lbrack ", "[ "),
15
+ ("\\lbrace ", "\\} "),
16
+ ("\\rbrace ", "\\{ "),
17
+ ("\\lnot ", "\\neg "),
18
+ ("\\land ", "\\wedge "),
19
+ ("\\vee ", "\\lor "),
20
+ ("\\doublecup ", "\\Cup "),
21
+ ("\\doublecap ", "\\Cap "),
22
+ ("\\llless ", "\\lll "),
23
+ ("\\gggtr ", "\\ggg "),
24
+ ("\\doteqdot ", "\\Doteq "),
25
+ ("\\ne ", "\\neq "),
26
+ ("\\le ", "\\leq "),
27
+ ("\\ge ", "\\geq "),
28
+ ("\\leftarrow ", "\\gets "),
29
+ ("\\rightarrow ", "\\to "),
30
+ ("\\restriction ", "\\upharpoonright "),
31
+ ("\\owns ", "\\ni "),
32
+ ("\\textlnot ", "\\neg "),
33
+ ("\\textellipsis ", "\\ldots "),
34
+ ("\\textbullet ", "\\bullet "),
35
+ ("\\plusmn ", "\\pm "),
36
+ ("\\texttimes", "\\times"),
37
+ ("\\textmu", "\\mu"),
38
+ ("\\textendash", "-"),
39
+ ("\\textemdash", "---"),
40
+ ("\\>", "\\:"),
41
+ ("\\medspace", "\\:"),
42
+ ("\\thinspace", "\\,"),
43
+ ("\\negthinspace", "\\!"),
44
+ ("\\thickspace", "\\;"),
45
+ ]
46
+ umlaut_mapping = {
47
+ "textasciicircum": "^",
48
+ "ddot": '"',
49
+ "textasciidieresis": '"',
50
+ "textasciicaron": "v ",
51
+ }
52
+ umlaut_keys = "|".join(umlaut_mapping.keys())
53
+ umlaut_regex = re.compile(rf"\s?\\({umlaut_keys})\s(\w)")
54
+ latex_comments = re.compile(r"(?<!\\)%.*\n")
55
+ toascii = UnicodeToLatexEncoder(
56
+ non_ascii_only=True, unknown_char_policy="ignore", unknown_char_warning=False
57
+ )
58
+
59
+
60
+ def remove_style(string: str) -> str:
61
+ return (
62
+ string.replace("\\displaystyle", "")
63
+ .replace("\\scriptstyle", "")
64
+ .replace("\\scriptscriptstyle", "")
65
+ .replace("\\textstyle", "")
66
+ )
67
+
68
+
69
+ def replace_duplicate_definitions(string: str) -> str:
70
+ """In Latex there are many commands that are interchangeable. Use just one of them"""
71
+ for pair in syn:
72
+ string = string.replace(pair[0], pair[1])
73
+ return string
74
+
75
+
76
+ def unicode_to_latex(s: str) -> str:
77
+ s = re.sub(
78
+ r"\s{2,}",
79
+ " ",
80
+ re.sub(
81
+ r"\\ensuremath\s?\{\s?(.+?)\s?\}\s?",
82
+ r" \1 ",
83
+ toascii.unicode_to_latex(s.strip()),
84
+ )
85
+ .replace("}", " ")
86
+ .replace("{", " "),
87
+ )
88
+ s = (
89
+ s.strip()
90
+ .replace(
91
+ "\\textperiodcentered \\textperiodcentered \\textperiodcentered", "\\cdots"
92
+ )
93
+ .replace("\\textperiodcentered", "\\cdot")
94
+ .replace("\\textquoteright", "'")
95
+ .replace("\\textquoteleft", "'")
96
+ .replace("\\textquotedblleft", '"')
97
+ .replace("\\textquotedblright", '"')
98
+ )
99
+ s = umlaut_regex.sub(lambda x: "\\" + umlaut_mapping[x.group(1)] + x.group(2), s)
100
+ s = replace_duplicate_definitions(s)
101
+ s = unidecode(s)
102
+ return s.replace("\u2009", " ")
103
+
104
+
105
+ latex_to_unicode = LatexNodes2Text()
106
+
107
+
108
+ def remove_line_breaks(string: str) -> str:
109
+ string = latex_comments.sub("\n", string)
110
+ return string.replace("\n", " ")
111
+
112
+
113
+ def normalize_tex(math: str, inline: bool) -> str:
114
+ """
115
+ Normalize TeX math expressions.
116
+
117
+ This function takes a TeX math expression and performs various normalization steps to ensure
118
+ consistency and proper formatting.
119
+
120
+ Args:
121
+ math (str): The input TeX math expression.
122
+ inline (bool): Indicates whether the expression should be inline (True) or displayed (False).
123
+
124
+ Returns:
125
+ str: The normalized TeX math expression.
126
+ """
127
+ math = math.strip()
128
+ if not math:
129
+ return ""
130
+ if math.startswith(r"\(") or math.startswith(r"\[") or math.startswith("$$"):
131
+ math = math[2:]
132
+ elif math.startswith("$"):
133
+ math = math[1:]
134
+ if math.endswith(r"\)") or math.endswith(r"\]") or math.endswith("$$"):
135
+ math = math[:-2]
136
+ elif math.endswith("$"):
137
+ math = math[:-1]
138
+ math = math.strip()
139
+ if not math:
140
+ return ""
141
+ math = remove_line_breaks(math.strip())
142
+ math = replace_duplicate_definitions(math)
143
+ math = remove_style(math)
144
+ if inline:
145
+ return rf"\({math}\)"
146
+ return rf"\[{math}\]"
nougat/dataset/utils/pdf_text_extract.py ADDED
@@ -0,0 +1,86 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ from io import StringIO
8
+ from typing import List
9
+ import re
10
+ from pdfminer.converter import TextConverter
11
+ from pdfminer.layout import LAParams
12
+ from pdfminer.pdfdocument import PDFDocument
13
+ from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter
14
+ from pdfminer.pdfpage import PDFPage
15
+ from pdfminer.pdfparser import PDFParser
16
+
17
+
18
+ def replace_ligatures(text: str) -> str:
19
+ ligatures = {
20
+ "ff": "ff",
21
+ "fi": "fi",
22
+ "fl": "fl",
23
+ "ffi": "ffi",
24
+ "ffl": "ffl",
25
+ "ſt": "ft",
26
+ "st": "st",
27
+ # "Ꜳ": "AA",
28
+ # "Æ": "AE",
29
+ "ꜳ": "aa",
30
+ }
31
+ for search, replace in ligatures.items():
32
+ text = text.replace(search, replace)
33
+ return text
34
+
35
+
36
+ def remove_hyphens(text: str) -> str:
37
+ """
38
+
39
+ This fails for:
40
+ * Natural dashes: well-known, self-replication, use-cases, non-semantic,
41
+ Post-processing, Window-wise, viewpoint-dependent
42
+ * Trailing math operands: 2 - 4
43
+ * Names: Lopez-Ferreras, VGG-19, CIFAR-100
44
+ """
45
+ lines = [line.rstrip() for line in text.split("\n")]
46
+
47
+ # Find dashes
48
+ line_numbers = []
49
+ for line_no, line in enumerate(lines[:-1]):
50
+ if line.endswith("-"):
51
+ line_numbers.append(line_no)
52
+
53
+ # Replace
54
+ for line_no in line_numbers:
55
+ lines = dehyphenate(lines, line_no)
56
+ return "\n".join(lines)
57
+
58
+
59
+ def dehyphenate(lines: List[str], line_no: int) -> List[str]:
60
+ next_line = lines[line_no + 1]
61
+ word_suffix = next_line.split(" ")[0]
62
+
63
+ lines[line_no] = lines[line_no][:-1] + word_suffix
64
+ lines[line_no + 1] = lines[line_no + 1][len(word_suffix) :]
65
+ return lines
66
+
67
+
68
+ def get_pages(pdf: str) -> List[str]:
69
+ out = []
70
+ with open(pdf, "rb") as in_file:
71
+ parser = PDFParser(in_file)
72
+ doc = PDFDocument(parser)
73
+ rsrcmgr = PDFResourceManager()
74
+
75
+ for page in PDFPage.create_pages(doc):
76
+ output_string = StringIO()
77
+ device = TextConverter(rsrcmgr, output_string, laparams=LAParams())
78
+ interpreter = PDFPageInterpreter(rsrcmgr, device)
79
+ interpreter.process_page(page)
80
+ out.append(remove_hyphens(replace_ligatures(output_string.getvalue())))
81
+ return out
82
+
83
+
84
+ def get_paragraphs(pdf: str) -> List[List[str]]:
85
+ pages = get_pages(pdf)
86
+ return [re.sub(r"\n{3,}", "\n\n", txt).split("\n\n") for txt in pages]
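The dehyphenation step in isolation (this simply re-uses `dehyphenate` from above on a made-up pair of lines):

def dehyphenate(lines, line_no):
    # move the wrapped word fragment from the next line back onto the hyphenated line
    word_suffix = lines[line_no + 1].split(" ")[0]
    lines[line_no] = lines[line_no][:-1] + word_suffix
    lines[line_no + 1] = lines[line_no + 1][len(word_suffix):]
    return lines

lines = ["The trans-", "former architecture is used."]
print(dehyphenate(lines, 0))
# ['The transformer', ' architecture is used.']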
nougat/dataset/utils/utils.py ADDED
@@ -0,0 +1,20 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import re
8
+
9
+
10
+ def remove_pretty_linebreaks(string: str) -> str:
11
+ """replaces linebreaks with spaces when there would be no
12
+ difference between them in the markdown format
13
+
14
+ Args:
15
+ string (str): String to process
16
+
17
+ Returns:
18
+ str: Formatted string
19
+ """
20
+ return re.sub(r"(?<!\n)\n([^\n\d\*#\[])", r" \1", string).strip()
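A self-contained illustration of the regex above: soft line wraps are joined, while blank lines and list items are preserved.

import re

def remove_pretty_linebreaks(string):
    return re.sub(r"(?<!\n)\n([^\n\d\*#\[])", r" \1", string).strip()

text = "This sentence is wrapped\nacross two lines.\n\n* this list item is kept"
print(remove_pretty_linebreaks(text))
# This sentence is wrapped across two lines.
#
# * this list item is kept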
nougat/metrics.py ADDED
@@ -0,0 +1,117 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import argparse
8
+ from multiprocessing import Pool
9
+ import re
10
+ from pathlib import Path
11
+ from collections import defaultdict
12
+ from typing import List
13
+
14
+ import numpy as np
15
+
16
+ import nltk
17
+ from nltk import edit_distance
18
+ from tqdm import tqdm
19
+
20
+ import orjson
21
+
22
+ inline_reg = re.compile(r"\\\((.*?)(?<!\\)\\\)")
23
+ display_reg = re.compile(r"\\\[(.+?)(?<!\\)\\\]")
24
+ table_reg = re.compile(r"\\begin\{tabular\}(.+?)(?:\\end\{tabular\}|$)", re.S)
25
+
26
+
27
+ def compute_metrics(pred, gt, minlen=4):
28
+ metrics = {}
29
+ if len(pred) < minlen or len(gt) < minlen:
30
+ return metrics
31
+ metrics["edit_dist"] = edit_distance(pred, gt) / max(len(pred), len(gt))
32
+ reference = gt.split()
33
+ hypothesis = pred.split()
34
+ metrics["bleu"] = nltk.translate.bleu([reference], hypothesis)
35
+ try:
36
+ metrics["meteor"] = nltk.translate.meteor([reference], hypothesis)
37
+ except LookupError:
38
+ metrics["meteor"] = np.nan
39
+ reference = set(reference)
40
+ hypothesis = set(hypothesis)
41
+ metrics["precision"] = nltk.scores.precision(reference, hypothesis)
42
+ metrics["recall"] = nltk.scores.recall(reference, hypothesis)
43
+ metrics["f_measure"] = nltk.scores.f_measure(reference, hypothesis)
44
+ return metrics
45
+
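A usage sketch for `compute_metrics` (assuming this repo and its dependencies, including the required nltk data, are installed; the strings are invented):

from nougat.metrics import compute_metrics

pred = "The model reaches state of the art results on this benchmark ."
gt = "The model reaches state-of-the-art results on this benchmark ."
scores = compute_metrics(pred, gt)
print(sorted(scores))   # ['bleu', 'edit_dist', 'f_measure', 'meteor', 'precision', 'recall']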
46
+
47
+ def get_parser():
48
+ parser = argparse.ArgumentParser()
49
+ parser.add_argument("json", type=Path, help="results file")
50
+ parser.add_argument(
51
+ "-N", dest="N", type=int, help="number of samples", default=None
52
+ )
53
+ args = parser.parse_args()
54
+ d = orjson.loads(args.json.read_text(encoding="utf-8"))
55
+ args.pred = d["predictions"]
56
+ args.gt = d["ground_truths"]
57
+ if args.N is not None:
58
+ args.pred = args.pred[: args.N]
59
+ args.gt = args.gt[: args.N]
60
+ return args
61
+
62
+
63
+ def split_text(pages: List[str]):
64
+ """
65
+ Split a list of pages into text, inline math, display math, and table blocks.
66
+
67
+ Args:
68
+ pages: The pages to split.
69
+ """
70
+ text, math, table = [], [], []
71
+ for page in pages:
72
+ for i, reg in enumerate([inline_reg, display_reg, table_reg]):
73
+ matches = "\n".join(reg.findall(page))
74
+ if i == 2:
75
+ table.append(matches)
76
+ elif i == 1:
77
+ math[-1] += matches
78
+ else:
79
+ math.append(matches)
80
+ page = reg.sub("", page)
81
+ text.append(page.strip())
82
+
83
+ return text, math, table
84
+
85
+
86
+ def get_metrics(gt: List[str], pred: List[str], pool: bool = True):
87
+ metrics = defaultdict(list)
88
+ if pool:
89
+ with Pool() as p:
90
+ _metrics = p.starmap(compute_metrics, iterable=zip(pred, gt))
91
+ else:
92
+ _metrics = [compute_metrics(p, g) for p, g in zip(pred, gt)]
93
+ for m in _metrics:
94
+ for key, value in m.items():
95
+ metrics[key].append(value)
96
+ return dict(metrics)
97
+
98
+
99
+ if __name__ == "__main__":
100
+ args = get_parser()
101
+ for name, entries in zip(["gt", "pred"], [args.gt, args.pred]):
102
+ full: Path = args.json.parent / (args.json.stem + "_" + name + "_full.mmd")
103
+ full.write_text("\n\n------------------\n\n".join(entries))
104
+ for i, (gt, pr) in enumerate(zip(split_text(args.gt), split_text(args.pred))):
105
+ sub = ["Text", "Math", "Tables"][i]
106
+ prpath: Path = args.json.parent / (
107
+ args.json.stem + "_pred_" + sub.lower() + ".mmd"
108
+ )
109
+ prpath.write_text("\n\n------------------\n\n".join(pr))
110
+ gtpath: Path = args.json.parent / (
111
+ args.json.stem + "_gt_" + sub.lower() + ".mmd"
112
+ )
113
+ gtpath.write_text("\n\n------------------\n\n".join(gt))
114
+ print("Results for", sub)
115
+
116
+ metrics = get_metrics(gt, pr)
117
+ print({key: sum(values) / len(values) for key, values in metrics.items()})
nougat/model.py ADDED
@@ -0,0 +1,702 @@
1
+ """
2
+ Donut
3
+ Copyright (c) 2022-present NAVER Corp.
4
+ MIT License
5
+ Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ """
7
+ import logging
8
+ import math
9
+ import os
10
+ from typing import List, Optional, Union
11
+ from collections import defaultdict
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+ from PIL import Image
16
+ import cv2
17
+ import timm
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from PIL import ImageOps
22
+ from timm.models.swin_transformer import SwinTransformer
23
+ from torchvision.transforms.functional import resize, rotate
24
+ from transformers import (
25
+ PreTrainedTokenizerFast,
26
+ StoppingCriteria,
27
+ StoppingCriteriaList,
28
+ MBartConfig,
29
+ MBartForCausalLM,
30
+ )
31
+ from transformers.file_utils import ModelOutput
32
+ from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
33
+ from nougat.postprocessing import postprocess
34
+ from nougat.transforms import train_transform, test_transform
35
+
36
+
37
+ class SwinEncoder(nn.Module):
38
+ r"""
39
+ Encoder based on SwinTransformer
40
+ Set the initial weights and configuration with a pretrained SwinTransformer and then
41
+ modify the detailed configurations
42
+
43
+ Args:
44
+ input_size: Input image size (width, height)
45
+ align_long_axis: Whether to rotate image if height is greater than width
46
+ window_size: Window size(=patch size) of SwinTransformer
47
+ encoder_layer: Number of layers of SwinTransformer encoder
48
+ name_or_path: Name of a pretrained model registered on huggingface.co or a local path;
49
+ otherwise, `swin_base_patch4_window12_384` will be set (using `timm`).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ input_size: List[int],
55
+ align_long_axis: bool,
56
+ window_size: int,
57
+ encoder_layer: List[int],
58
+ patch_size: int,
59
+ embed_dim: int,
60
+ num_heads: List[int],
61
+ name_or_path: Union[str, bytes, os.PathLike] = None,
62
+ ):
63
+ super().__init__()
64
+ self.input_size = input_size
65
+ self.align_long_axis = align_long_axis
66
+ self.window_size = window_size
67
+ self.encoder_layer = encoder_layer
68
+ self.patch_size = patch_size
69
+ self.embed_dim = embed_dim
70
+ self.num_heads = num_heads
71
+
72
+ self.model = SwinTransformer(
73
+ img_size=self.input_size,
74
+ depths=self.encoder_layer,
75
+ window_size=self.window_size,
76
+ patch_size=self.patch_size,
77
+ embed_dim=self.embed_dim,
78
+ num_heads=self.num_heads,
79
+ num_classes=0,
80
+ )
81
+
82
+ # weight init with swin
83
+ if not name_or_path:
84
+ swin_state_dict = timm.create_model(
85
+ "swin_base_patch4_window12_384", pretrained=True
86
+ ).state_dict()
87
+ new_swin_state_dict = self.model.state_dict()
88
+ for x in new_swin_state_dict:
89
+ if x.endswith("relative_position_index") or x.endswith("attn_mask"):
90
+ pass
91
+ elif (
92
+ x.endswith("relative_position_bias_table")
93
+ and self.model.layers[0].blocks[0].attn.window_size[0] != 12
94
+ ):
95
+ pos_bias = swin_state_dict[x].unsqueeze(0)[0]
96
+ old_len = int(math.sqrt(len(pos_bias)))
97
+ new_len = int(2 * window_size - 1)
98
+ pos_bias = pos_bias.reshape(1, old_len, old_len, -1).permute(
99
+ 0, 3, 1, 2
100
+ )
101
+ pos_bias = F.interpolate(
102
+ pos_bias,
103
+ size=(new_len, new_len),
104
+ mode="bicubic",
105
+ align_corners=False,
106
+ )
107
+ new_swin_state_dict[x] = (
108
+ pos_bias.permute(0, 2, 3, 1)
109
+ .reshape(1, new_len**2, -1)
110
+ .squeeze(0)
111
+ )
112
+ else:
113
+ new_swin_state_dict[x] = swin_state_dict[x]
114
+ self.model.load_state_dict(new_swin_state_dict)
115
+
116
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
117
+ """
118
+ Args:
119
+ x: (batch_size, num_channels, height, width)
120
+ """
121
+ x = self.model.patch_embed(x)
122
+ x = self.model.pos_drop(x)
123
+ x = self.model.layers(x)
124
+ return x
125
+
126
+ @staticmethod
127
+ def crop_margin(img: Image.Image) -> Image.Image:
128
+ data = np.array(img.convert("L"))
129
+ data = data.astype(np.uint8)
130
+ max_val = data.max()
131
+ min_val = data.min()
132
+ if max_val == min_val:
133
+ return img
134
+ data = (data - min_val) / (max_val - min_val) * 255
135
+ gray = 255 * (data < 200).astype(np.uint8)
136
+
137
+ coords = cv2.findNonZero(gray) # Find all non-zero points (text)
138
+ a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
139
+ return img.crop((a, b, w + a, h + b))
140
+
141
+ @property
142
+ def to_tensor(self):
143
+ if self.training:
144
+ return train_transform
145
+ else:
146
+ return test_transform
147
+
148
+ def prepare_input(
149
+ self, img: Image.Image, random_padding: bool = False
150
+ ) -> torch.Tensor:
151
+ """
152
+ Convert a PIL Image to a tensor of the specified input_size by applying the following steps:
153
+ - resize
154
+ - rotate (if align_long_axis is True and the image's long axis is not aligned with the canvas)
155
+ - pad
156
+ """
157
+ if img is None:
158
+ return
159
+ # crop margins
160
+ try:
161
+ img = self.crop_margin(img.convert("RGB"))
162
+ except OSError:
163
+ # might throw an error for broken files
164
+ return
165
+ if img.height == 0 or img.width == 0:
166
+ return
167
+ if self.align_long_axis and (
168
+ (self.input_size[0] > self.input_size[1] and img.width > img.height)
169
+ or (self.input_size[0] < self.input_size[1] and img.width < img.height)
170
+ ):
171
+ img = rotate(img, angle=-90, expand=True)
172
+ img = resize(img, min(self.input_size))
173
+ img.thumbnail((self.input_size[1], self.input_size[0]))
174
+ delta_width = self.input_size[1] - img.width
175
+ delta_height = self.input_size[0] - img.height
176
+ if random_padding:
177
+ pad_width = np.random.randint(low=0, high=delta_width + 1)
178
+ pad_height = np.random.randint(low=0, high=delta_height + 1)
179
+ else:
180
+ pad_width = delta_width // 2
181
+ pad_height = delta_height // 2
182
+ padding = (
183
+ pad_width,
184
+ pad_height,
185
+ delta_width - pad_width,
186
+ delta_height - pad_height,
187
+ )
188
+ return self.to_tensor(ImageOps.expand(img, padding))
189
+
190
+
191
+ class BARTDecoder(nn.Module):
192
+ """
193
+ Decoder based on Multilingual BART
194
+ Set the initial weights and configuration with a pretrained multilingual BART model,
195
+ and modify the detailed configurations as a Nougat decoder
196
+
197
+ Args:
198
+ decoder_layer:
199
+ Number of layers of BARTDecoder
200
+ max_position_embeddings:
201
+ The maximum sequence length to be trained
202
+ name_or_path:
203
+ Name of a pretrained model registered on huggingface.co or a local path;
204
+ otherwise, `facebook/mbart-large-50` will be set (using `transformers`)
205
+ """
206
+
207
+ def __init__(
208
+ self,
209
+ decoder_layer: int,
210
+ max_position_embeddings: int,
211
+ hidden_dimension: int = 1024,
212
+ name_or_path: Union[str, bytes, os.PathLike] = None,
213
+ ):
214
+ super().__init__()
215
+ self.decoder_layer = decoder_layer
216
+ self.max_position_embeddings = max_position_embeddings
217
+ if not name_or_path:
218
+ tokenizer_file = Path(__file__).parent / "dataset" / "tokenizer.json"
219
+ else:
220
+ tokenizer_file = Path(name_or_path) / "tokenizer.json"
221
+ if not tokenizer_file.exists():
222
+ raise ValueError("Could not find tokenizer file")
223
+ self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_file))
224
+ self.tokenizer.pad_token = "<pad>"
225
+ self.tokenizer.bos_token = "<s>"
226
+ self.tokenizer.eos_token = "</s>"
227
+ self.tokenizer.unk_token = "<unk>"
228
+
229
+ self.model = MBartForCausalLM(
230
+ config=MBartConfig(
231
+ is_decoder=True,
232
+ is_encoder_decoder=False,
233
+ add_cross_attention=True,
234
+ decoder_layers=self.decoder_layer,
235
+ max_position_embeddings=self.max_position_embeddings,
236
+ vocab_size=len(self.tokenizer),
237
+ scale_embedding=True,
238
+ add_final_layer_norm=True,
239
+ d_model=hidden_dimension,
240
+ )
241
+ )
242
+ self.model.config.is_encoder_decoder = True # to get cross-attention
243
+ self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
244
+ self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference
245
+
246
+ if not name_or_path:
247
+ bart_state_dict = MBartForCausalLM.from_pretrained(
248
+ "facebook/mbart-large-50"
249
+ ).state_dict()
250
+ new_bart_state_dict = self.model.state_dict()
251
+ for x in new_bart_state_dict:
252
+ if (
253
+ x.endswith("embed_positions.weight")
254
+ and self.max_position_embeddings != 1024
255
+ ):
256
+ new_bart_state_dict[x] = torch.nn.Parameter(
257
+ self.resize_bart_abs_pos_emb(
258
+ bart_state_dict[x],
259
+ self.max_position_embeddings
260
+ + 2, # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
261
+ )
262
+ )
263
+ elif x.endswith("embed_tokens.weight") or x.endswith("lm_head.weight"):
264
+ new_bart_state_dict[x] = bart_state_dict[x][
265
+ : len(self.tokenizer), :
266
+ ]
267
+ else:
268
+ new_bart_state_dict[x] = bart_state_dict[x]
269
+ self.model.load_state_dict(new_bart_state_dict, strict=False)
270
+
271
+ def add_special_tokens(self, list_of_tokens: List[str]):
272
+ """
273
+ Add special tokens to tokenizer and resize the token embeddings
274
+ """
275
+ newly_added_num = self.tokenizer.add_special_tokens(
276
+ {"additional_special_tokens": sorted(set(list_of_tokens))}
277
+ )
278
+ if newly_added_num > 0:
279
+ self.model.resize_token_embeddings(len(self.tokenizer))
280
+
281
+ def prepare_inputs_for_inference(
282
+ self,
283
+ input_ids: torch.Tensor,
284
+ encoder_outputs: torch.Tensor,
285
+ past=None,
286
+ past_key_values=None,
287
+ use_cache: bool = None,
288
+ attention_mask: torch.Tensor = None,
289
+ ):
290
+ """
291
+ Args:
292
+ input_ids: (batch_size, sequence_length)
293
+
294
+ Returns:
295
+ input_ids: (batch_size, sequence_length)
296
+ attention_mask: (batch_size, sequence_length)
297
+ encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
298
+ """
299
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
300
+ past = past or past_key_values
301
+ if past is not None:
302
+ input_ids = input_ids[:, -1:]
303
+ output = {
304
+ "input_ids": input_ids,
305
+ "attention_mask": attention_mask,
306
+ "past_key_values": past,
307
+ "use_cache": use_cache,
308
+ "encoder_hidden_states": encoder_outputs.last_hidden_state,
309
+ }
310
+ return output
311
+
312
+ def forward(
313
+ self,
314
+ input_ids,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ encoder_hidden_states: Optional[torch.Tensor] = None,
317
+ past_key_values: Optional[torch.Tensor] = None,
318
+ labels: Optional[torch.Tensor] = None,
319
+ use_cache: bool = None,
320
+ output_attentions: Optional[torch.Tensor] = None,
321
+ output_hidden_states: Optional[torch.Tensor] = None,
322
+ return_dict: bool = None,
323
+ ):
324
+ return self.model.forward(
325
+ input_ids,
326
+ attention_mask=attention_mask,
327
+ labels=labels,
328
+ encoder_hidden_states=encoder_hidden_states,
329
+ past_key_values=past_key_values,
330
+ use_cache=use_cache,
331
+ output_attentions=output_attentions,
332
+ output_hidden_states=output_hidden_states,
333
+ return_dict=return_dict,
334
+ )
335
+
336
+ @staticmethod
337
+ def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
338
+ """
339
+ Resize position embeddings
340
+ Truncate if sequence length of MBart backbone is greater than given max_length,
341
+ else interpolate to max_length
342
+ """
343
+ if weight.shape[0] > max_length:
344
+ weight = weight[:max_length, ...]
345
+ else:
346
+ weight = (
347
+ F.interpolate(
348
+ weight.permute(1, 0).unsqueeze(0),
349
+ size=max_length,
350
+ mode="linear",
351
+ align_corners=False,
352
+ )
353
+ .squeeze(0)
354
+ .permute(1, 0)
355
+ )
356
+ return weight
357
+
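A small demonstration of the position-embedding resizing above on a toy tensor (assuming this repo and its dependencies are installed; the shapes are illustrative):

import torch
from nougat.model import BARTDecoder

weight = torch.randn(1026, 8)   # mBART's 1024 positions (+2 offset), toy hidden size 8
longer = BARTDecoder.resize_bart_abs_pos_emb(weight, 4096 + 2)   # interpolated
shorter = BARTDecoder.resize_bart_abs_pos_emb(weight, 512 + 2)   # truncated
print(longer.shape, shorter.shape)   # torch.Size([4098, 8]) torch.Size([514, 8])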
358
+
359
+ class NougatConfig(PretrainedConfig):
360
+ r"""
361
+ This is the configuration class to store the configuration of a [`NougatModel`]. It is used to
362
+ instantiate a Nougat model according to the specified arguments, defining the model architecture
363
+
364
+ Args:
365
+ input_size:
366
+ Input image size (canvas size) of Nougat.encoder, SwinTransformer in this codebase
367
+ align_long_axis:
368
+ Whether to rotate image if height is greater than width
369
+ window_size:
370
+ Window size of Nougat.encoder, SwinTransformer in this codebase
371
+ encoder_layer:
372
+ Depth of each Nougat.encoder Encoder layer, SwinTransformer in this codebase
373
+ decoder_layer:
374
+ Number of hidden layers in the Nougat.decoder, such as BART
375
+ max_position_embeddings:
376
+ Trained max position embeddings in the Nougat decoder,
377
+ if not specified, it will have same value with max_length
378
+ max_length:
379
+ Maximum sequence length (max position embeddings) to train with
380
+ name_or_path:
381
+ Name of a pretrained model registered on huggingface.co or a local path
382
+ """
383
+ model_type = "nougat"
384
+
385
+ def __init__(
386
+ self,
387
+ input_size: List[int] = [896, 672],
388
+ align_long_axis: bool = False,
389
+ window_size: int = 7,
390
+ encoder_layer: List[int] = [2, 2, 14, 2],
391
+ decoder_layer: int = 10,
392
+ max_position_embeddings: int = None,
393
+ max_length: int = 4096,
394
+ name_or_path: Union[str, bytes, os.PathLike] = "",
395
+ patch_size: int = 4,
396
+ embed_dim: int = 128,
397
+ num_heads: List[int] = [4, 8, 16, 32],
398
+ hidden_dimension: int = 1024,
399
+ **kwargs,
400
+ ):
401
+ super().__init__()
402
+ self.input_size = input_size
403
+ self.align_long_axis = align_long_axis
404
+ self.window_size = window_size
405
+ self.encoder_layer = encoder_layer
406
+ self.decoder_layer = decoder_layer
407
+ self.max_position_embeddings = (
408
+ max_length if max_position_embeddings is None else max_position_embeddings
409
+ )
410
+ self.max_length = max_length
411
+ self.name_or_path = name_or_path
412
+ self.patch_size = patch_size
413
+ self.embed_dim = embed_dim
414
+ self.num_heads = num_heads
415
+ self.hidden_dimension = hidden_dimension
416
+
417
+
418
+ class RunningVarTorch:
419
+ def __init__(self, L=15, norm=False):
420
+ self.values = None
421
+ self.L = L
422
+ self.norm = norm
423
+
424
+ def push(self, x: torch.Tensor):
425
+ assert x.dim() == 1
426
+ if self.values is None:
427
+ self.values = x[:, None]
428
+ elif self.values.shape[1] < self.L:
429
+ self.values = torch.cat((self.values, x[:, None]), 1)
430
+ else:
431
+ self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)
432
+
433
+ def variance(self):
434
+ if self.values is None:
435
+ return
436
+ if self.norm:
437
+ return torch.var(self.values, 1) / self.values.shape[1]
438
+ else:
439
+ return torch.var(self.values, 1)
440
+
441
+
442
+ class StoppingCriteriaScores(StoppingCriteria):
443
+ def __init__(self, threshold: float = 0.015, window_size: int = 200):
444
+ super().__init__()
445
+ self.threshold = threshold
446
+ self.vars = RunningVarTorch(norm=True)
447
+ self.varvars = RunningVarTorch(L=window_size)
448
+ self.stop_inds = defaultdict(int)
449
+ self.stopped = defaultdict(bool)
450
+ self.size = 0
451
+ self.window_size = window_size
452
+
453
+ @torch.no_grad()
454
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
455
+ last_scores = scores[-1]
456
+ self.vars.push(last_scores.max(1)[0].float().cpu())
457
+ self.varvars.push(self.vars.variance())
458
+ self.size += 1
459
+ if self.size < self.window_size:
460
+ return False
461
+
462
+ varvar = self.varvars.variance()
463
+ for b in range(len(last_scores)):
464
+ if varvar[b] < self.threshold:
465
+ if self.stop_inds[b] > 0 and not self.stopped[b]:
466
+ self.stopped[b] = self.stop_inds[b] >= self.size
467
+ else:
468
+ self.stop_inds[b] = int(
469
+ min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
470
+ )
471
+ else:
472
+ self.stop_inds[b] = 0
473
+ self.stopped[b] = False
474
+ return all(self.stopped.values()) and len(self.stopped) > 0
475
+
476
+
477
+ def batch(l, b=15):
478
+ subs = []
479
+ for i in range(len(l) - b):
480
+ subs.append(l[i : i + b])
481
+ return subs
482
+
483
+
484
+ def subdiv(l, b=10):
485
+ subs = []
486
+ for i in range(len(l) - b):
487
+ subs.append(l[: i + b])
488
+ return subs
489
+
490
+
491
+ class NougatModel(PreTrainedModel):
492
+ r"""
493
+ Nougat: Neural Optical UnderstandinG for Academic documents.
494
+ The encoder converts an image of an academic document into a series of embeddings.
495
+ Then, the decoder generates a sequence of tokens based on encoder's output.
496
+ This sequence can be translated into a structured markup language format.
497
+ """
498
+ config_class = NougatConfig
499
+ base_model_prefix = "nougat"
500
+
501
+ def __init__(self, config: NougatConfig):
502
+ super().__init__(config)
503
+ self.config = config
504
+ self.encoder = SwinEncoder(
505
+ input_size=self.config.input_size,
506
+ align_long_axis=self.config.align_long_axis,
507
+ window_size=self.config.window_size,
508
+ encoder_layer=self.config.encoder_layer,
509
+ name_or_path=self.config.name_or_path,
510
+ patch_size=self.config.patch_size,
511
+ embed_dim=self.config.embed_dim,
512
+ num_heads=self.config.num_heads,
513
+ )
514
+ self.decoder = BARTDecoder(
515
+ max_position_embeddings=self.config.max_position_embeddings,
516
+ decoder_layer=self.config.decoder_layer,
517
+ name_or_path=self.config.name_or_path,
518
+ hidden_dimension=self.config.hidden_dimension,
519
+ )
520
+
521
+ def forward(
522
+ self,
523
+ image_tensors: torch.Tensor,
524
+ decoder_input_ids: torch.Tensor,
525
+ attention_mask: Optional[torch.Tensor] = None,
526
+ ):
527
+ """
528
+ Calculate a loss given an input image and a desired token sequence,
529
+ the model will be trained in a teacher-forcing manner
530
+
531
+ Args:
532
+ image_tensors: (batch_size, num_channels, height, width)
533
+ decoder_input_ids: (batch_size, sequence_length, embedding_dim)
534
+ """
535
+ encoder_outputs = self.encoder(image_tensors)
536
+ decoder_outputs = self.decoder(
537
+ input_ids=decoder_input_ids[:, :-1].contiguous(),
538
+ encoder_hidden_states=encoder_outputs,
539
+ attention_mask=attention_mask[:, :-1],
540
+ labels=decoder_input_ids[:, 1:].contiguous(),
541
+ )
542
+ return decoder_outputs
543
+
544
+ def _init_weights(self, *args, **kwargs):
545
+ return
546
+
547
+ def inference(
548
+ self,
549
+ image: Image.Image = None,
550
+ image_tensors: Optional[torch.Tensor] = None,
551
+ return_attentions: bool = False,
552
+ early_stopping: bool = True,
553
+ ):
554
+ """
555
+ Generate a token sequence in an auto-regressive manner.
556
+
557
+ Args:
558
+ image: input document image (PIL.Image)
559
+ image_tensors: (1, num_channels, height, width)
560
+ convert prompt to tensor if image_tensor is not fed
561
+ """
562
+ output = {
563
+ "predictions": list(),
564
+ "sequences": list(),
565
+ "repeats": list(),
566
+ "repetitions": list(),
567
+ }
568
+ if image is None and image_tensors is None:
569
+ logging.warn("Image not found")
570
+ return output
571
+
572
+ if image_tensors is None:
573
+ image_tensors = self.encoder.prepare_input(image).unsqueeze(0)
574
+
575
+ if self.device.type != "mps":
576
+ image_tensors = image_tensors.to(next(self.parameters()).dtype)
577
+
578
+ image_tensors = image_tensors.to(self.device)
579
+
580
+ last_hidden_state = self.encoder(image_tensors)
581
+
582
+ encoder_outputs = ModelOutput(
583
+ last_hidden_state=last_hidden_state, attentions=None
584
+ )
585
+
586
+ if len(encoder_outputs.last_hidden_state.size()) == 1:
587
+ encoder_outputs.last_hidden_state = (
588
+ encoder_outputs.last_hidden_state.unsqueeze(0)
589
+ )
590
+
591
+ # get decoder output
592
+ decoder_output = self.decoder.model.generate(
593
+ encoder_outputs=encoder_outputs,
594
+ min_length=1,
595
+ max_length=self.config.max_length,
596
+ pad_token_id=self.decoder.tokenizer.pad_token_id,
597
+ eos_token_id=self.decoder.tokenizer.eos_token_id,
598
+ use_cache=True,
599
+ bad_words_ids=[
600
+ [self.decoder.tokenizer.unk_token_id],
601
+ ],
602
+ return_dict_in_generate=True,
603
+ output_scores=True,
604
+ output_attentions=return_attentions,
605
+ do_sample=False,
606
+ stopping_criteria=StoppingCriteriaList(
607
+ [StoppingCriteriaScores()] if early_stopping else []
608
+ ),
609
+ )
610
+ output["repetitions"] = decoder_output.sequences.clone()
611
+ output["sequences"] = decoder_output.sequences.clone()
612
+ batch_size = len(decoder_output.sequences)
613
+
614
+ logits = torch.stack(decoder_output.scores, 1).cpu().max(-1)
615
+ values = logits.values
616
+ indices = logits.indices
617
+
618
+ for b in range(batch_size):
619
+ mask = indices[b] != self.decoder.tokenizer.pad_token_id
620
+ N = mask.sum().item()
621
+ var = np.array(
622
+ [np.var(s) / len(s) for s in batch(values[b, mask].float().numpy())]
623
+ )
624
+ if len(var) < 10:
625
+ output["repeats"].append(None)
626
+ continue
627
+ varvar = np.array([np.var(v) for v in subdiv(var[::-1])][::-1])
628
+ minlen = 120
629
+ if (
630
+ indices[b] == self.decoder.tokenizer.eos_token_id
631
+ ).any() and N + 1 < indices.shape[1]:
632
+ # there is an end to the generation, likely no repetitions
633
+ output["repeats"].append(None)
634
+ continue
635
+ small_var = np.where(varvar < 0.045)[0]
636
+ if early_stopping and len(small_var) > 1:
637
+ if np.all(np.diff(small_var) < 2):
638
+ idx = int(min(max(small_var[0], 1) * 1.08 + minlen, 4095))
639
+ if idx / N > 0.9: # at most last bit
640
+ output["repeats"].append(None)
641
+ continue
642
+ elif small_var[0] < 30:
643
+ idx = 0
644
+ logging.warn("Found repetitions in sample %i" % b)
645
+ output["repeats"].append(idx)
646
+ output["sequences"][b, idx:] = self.decoder.tokenizer.pad_token_id
647
+ output["repetitions"][b, :idx] = self.decoder.tokenizer.pad_token_id
648
+ else:
649
+ output["repeats"].append(None)
650
+ else:
651
+ output["repeats"].append(None)
652
+ output["repetitions"] = self.decoder.tokenizer.batch_decode(
653
+ output["repetitions"], skip_special_tokens=True
654
+ )
655
+ output["predictions"] = postprocess(
656
+ self.decoder.tokenizer.batch_decode(
657
+ output["sequences"], skip_special_tokens=True
658
+ ),
659
+ markdown_fix=False,
660
+ )
661
+
662
+ if return_attentions:
663
+ output["attentions"] = {
664
+ "self_attentions": decoder_output.decoder_attentions,
665
+ "cross_attentions": decoder_output.cross_attentions,
666
+ }
667
+
668
+ return output
669
+
670
+ @classmethod
671
+ def from_pretrained(
672
+ cls,
673
+ model_path: Union[str, bytes, os.PathLike],
674
+ *model_args,
675
+ **kwargs,
676
+ ):
677
+ r"""
678
+ Instantiate a pretrained nougat model from a pre-trained model configuration
679
+
680
+ Args:
681
+ model_path:
682
+ Name of a pretrained model registered on huggingface.co or a local path.
683
+ """
684
+ model = super(NougatModel, cls).from_pretrained(
685
+ model_path, *model_args, **kwargs
686
+ )
687
+
688
+ # truncate or interpolate position embeddings of decoder
689
+ max_length = kwargs.get("max_length", model.config.max_position_embeddings)
690
+ if (
691
+ max_length != model.config.max_position_embeddings
692
+ ): # if the max_length of the trained model differs from the max_length you want to use
693
+ model.decoder.model.model.decoder.embed_positions.weight = torch.nn.Parameter(
694
+ model.decoder.resize_bart_abs_pos_emb(
695
+ model.decoder.model.model.decoder.embed_positions.weight,
696
+ max_length
697
+ + 2, # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
698
+ )
699
+ )
700
+ model.config.max_position_embeddings = max_length
701
+
702
+ return model
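A minimal usage sketch for the model above (the checkpoint directory and the page image are placeholders; the single page is batched by hand before calling `inference`):

```python
# Sketch only: checkpoint path and page image are placeholders.
import torch
from PIL import Image
from nougat import NougatModel

model = NougatModel.from_pretrained("path/to/nougat-checkpoint").eval()

page = Image.open("page_1.png")  # a rasterized PDF page
image_tensor = model.encoder.prepare_input(page, random_padding=False)

with torch.no_grad():
    out = model.inference(image_tensors=image_tensor.unsqueeze(0))
print(out["predictions"][0])  # postprocessed markup for the page
```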
nougat/postprocessing.py ADDED
@@ -0,0 +1,508 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ from typing import Union, List
8
+ import re
9
+ import os
10
+ import numpy as np
11
+ from nltk.corpus import words
12
+ from multiprocessing import Pool
13
+ from functools import partial
14
+ from Levenshtein import ratio
15
+
16
+
17
+ reference_pattern = re.compile(r"^\* \[\d+\]", flags=re.M)
18
+
19
+
20
+ def markdown_compatible(s: str) -> str:
21
+ """
22
+ Make text compatible with Markdown formatting.
23
+
24
+ This function makes various text formatting adjustments to make it compatible with Markdown.
25
+
26
+ Args:
27
+ s (str): The input text to be made Markdown-compatible.
28
+
29
+ Returns:
30
+ str: The Markdown-compatible text.
31
+ """
32
+ # equation tag
33
+ s = re.sub(
34
+ r"^\(([\d.]+[a-zA-Z]?)\) \\\[(.+?)\\\]$", r"\[\2 \\tag{\1}\]", s, flags=re.M
35
+ )
36
+ s = re.sub(
37
+ r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\)$", r"\[\1 \\tag{\2}\]", s, flags=re.M
38
+ )
39
+ s = re.sub(
40
+ r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\) (\\\[.+?\\\])$",
41
+ r"\[\1 \\tag{\2}\] \3",
42
+ s,
43
+ flags=re.M,
44
+ ) # multi line
45
+ s = s.replace(r"\. ", ". ")
46
+ # bold formatting
47
+ s = s.replace(r"\bm{", r"\mathbf{").replace(r"{\\bm ", r"\mathbf{")
48
+ # s = s.replace(r"\it{", r"\mathit{").replace(r"{\\it ", r"\mathit{") # not needed
49
+ s = re.sub(r"\\mbox{ ?\\boldmath\$(.*?)\$}", r"\\mathbf{\1}", s)
50
+ # s=re.sub(r'\\begin{table}(.+?)\\end{table}\nTable \d+: (.+?)\n',r'\\begin{table}\1\n\\caption{\2}\n\\end{table}\n',s,flags=re.S)
51
+ # s=re.sub(r'###### Abstract\n(.*?)\n\n',r'\\begin{abstract}\n\1\n\\end{abstract}\n\n',s,flags=re.S)
52
+ # urls
53
+ s = re.sub(
54
+ r"((?:http|ftp|https):\/\/(?:[\w_-]+(?:(?:\.[\w_-]+)+))(?:[\w.,@?^=%&:\/~+#-]*[\w@?^=%&\/~+#-]))",
55
+ r"[\1](\1)",
56
+ s,
57
+ )
58
+ # algorithms
59
+ s = re.sub(r"```\s*(.+?)\s*```", r"```\n\1\n```", s, flags=re.S)
60
+ # lists
61
+
62
+ return s
63
+
64
+
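For illustration, the equation-tag rewrite above turns a leading equation number into a `\tag{}` (the input string is made up):

```python
from nougat.postprocessing import markdown_compatible

raw = r"(3.1) \[E=mc^{2}\]"     # model output with a leading equation number
print(markdown_compatible(raw))  # -> \[E=mc^{2} \tag{3.1}\]
```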
65
+ def find_next_punctuation(s: str, start_inx=0):
66
+ """
67
+ Find the index of the next punctuation mark
68
+
69
+ Args:
70
+ s: String to examine
71
+ start_inx: Index where to start
72
+ """
73
+
74
+ for i in range(start_inx, len(s)):
75
+ if s[i] in [".", "?", "!", "\n"]:
76
+ return i
77
+
78
+ return None
79
+
80
+
81
+ def find_last_punctuation(s: str, start_inx=0):
82
+ """
83
+ Find the index of the last punctuation mark before start_inx
84
+
85
+ Args:
86
+ s: String to examine
87
+ start_inx: Index where to look before
88
+ """
89
+
90
+ for i in range(start_inx - 1, 0, -1):
91
+ if s[i] in [".", "?", "!", "\n"]:
92
+ return i
93
+
94
+ return None
95
+
96
+
97
+ def truncate_repetitions(s: str, min_len=30):
98
+ """
99
+ Attempt to truncate repeating segments in the input string.
100
+
101
+ This function looks for the longest repeating substring at the end of the input string and truncates
102
+ it to appear only once. To be considered for removal, repetitions need to be continuous.
103
+
104
+ Args:
105
+ s (str): The input raw prediction to be truncated.
106
+ min_len (int): The minimum length of the repeating segment.
107
+
108
+ Returns:
109
+ str: The input string with repeated segments truncated.
110
+ """
111
+ s_lower = s.lower()
112
+ s_len = len(s_lower)
113
+
114
+ if s_len < 2 * min_len:
115
+ return s
116
+
117
+ # try to find a length at which the tail is repeating
118
+ max_rep_len = None
119
+ for rep_len in range(min_len, int(s_len / 2)):
120
+ # check if there is a repetition at the end
121
+ same = True
122
+ for i in range(0, rep_len):
123
+ if s_lower[s_len - rep_len - i - 1] != s_lower[s_len - i - 1]:
124
+ same = False
125
+ break
126
+
127
+ if same:
128
+ max_rep_len = rep_len
129
+
130
+ if max_rep_len is None:
131
+ return s
132
+
133
+ lcs = s_lower[-max_rep_len:]
134
+
135
+ # remove all but the last repetition
136
+ st = s
137
+ st_lower = s_lower
138
+ while st_lower.endswith(lcs):
139
+ st = st[:-max_rep_len]
140
+ st_lower = st_lower[:-max_rep_len]
141
+
142
+ # this is the tail with the repetitions
143
+ repeating_tail = s_lower[len(st_lower) :]
144
+
145
+ # add until next punctuation and make sure last sentence is not repeating
146
+ st_lower_out = st_lower
147
+ while True:
148
+ sentence_end = find_next_punctuation(s_lower, len(st_lower_out))
149
+ sentence_start = find_last_punctuation(s_lower, len(st_lower_out))
150
+ if sentence_end and sentence_start:
151
+ sentence = s_lower[sentence_start:sentence_end]
152
+ st_lower_out = s_lower[: sentence_end + 1]
153
+ if sentence in repeating_tail:
154
+ break
155
+ else:
156
+ break
157
+
158
+ s_out = s[: len(st_lower_out)]
159
+
160
+ return s_out
161
+
162
+
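A hedged illustration of `truncate_repetitions` on a made-up degenerate generation:

```python
from nougat.postprocessing import truncate_repetitions

# A degenerate generation that loops on the same sentence.
looped = "Results are shown in Table 2. " + "The method converges quickly. " * 20
trimmed = truncate_repetitions(looped)
# The repeating tail is cut back to a single occurrence.
print(len(looped), len(trimmed))
```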
163
+ def close_envs(s: str) -> str:
164
+ """checks if table envs are opened but not closed. Appends the closing statements and returns the new string"""
165
+ envs = ("bmatrix", "pmatrix", "matrix", "tabular", "table")
166
+ for env in envs:
167
+ begins, ends = s.count(r"\begin{%s}" % env), s.count(r"\end{%s}" % env)
168
+ if begins > ends:
169
+ s += (r"\end{%s}" % env) * (begins - ends)
170
+ return s
171
+
172
+
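For example, a truncated table is closed like this (made-up input):

```python
from nougat.postprocessing import close_envs

truncated = r"\begin{table} \begin{tabular}{ll} a & b \\"
print(close_envs(truncated))
# appends the missing \end{tabular} and \end{table}
```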
173
+ def remove_numbers(lines):
174
+ def _clean(s):
175
+ return re.sub(r"(?:[\d_]|\*\*)", "", s).strip()
176
+
177
+ if type(lines) is str:
178
+ return _clean(lines)
179
+ out = []
180
+ for l in lines:
181
+ out.append(_clean(l))
182
+ return out
183
+
184
+
185
+ def get_slices(lines, clean_lines):
186
+ """
187
+ Get slices of text based on specific criteria within the lines.
188
+
189
+ This function finds runs of consecutive near-duplicate lines (typically repeated reference entries) and returns their index ranges.
190
+
191
+ Args:
192
+ lines (list of str): The list of lines containing the text.
193
+ clean_lines (list of str): A cleaned version of the text (without numbers).
194
+
195
+ Returns:
196
+ list of tuple: A list of tuples representing the start and end indices of text slices.
197
+ """
198
+ inds = np.zeros(len(lines))
199
+ for i in range(len(lines) - 1):
200
+ j = i + 1
201
+ while not clean_lines[j] and j < len(lines) - 1:
202
+ j += 1
203
+ if (
204
+ len(clean_lines[i]) < 200
205
+ and len(clean_lines[i]) > 3
206
+ and len(clean_lines[j]) < 200
207
+ and len(clean_lines[j]) > 3
208
+ and not clean_lines[i].startswith("[MISSING_PAGE")
209
+ and (
210
+ clean_lines[i] == clean_lines[j]
211
+ or ratio(clean_lines[i], clean_lines[j]) > 0.9
212
+ )
213
+ ):
214
+ inds[i:j] = 1
215
+ ids = np.where(inds)[0]
216
+ slices = []
217
+ if len(ids) == 0:
218
+ return slices
219
+ j0 = 0
220
+ for j, x in enumerate(np.diff(ids) > 3):
221
+ if x:
222
+ slices.append((ids[j0], ids[j] + 2))
223
+ j0 = j + 1
224
+ slices.append((ids[j0], ids[-1] + 2))
225
+ return [sli for sli in slices if sli[1] - sli[0] > 15]
226
+
227
+
228
+ def remove_slice_from_lines(lines, clean_text, sli) -> str:
229
+ """
230
+ Remove a slice of text from the lines based on specific criteria.
231
+
232
+ This function expands the given slice to cover the surrounding repeated block and returns the text that should be removed.
233
+
234
+ Args:
235
+ lines (list of str): The list of lines containing the text.
236
+ clean_text (list of str): A cleaned version of the text (without numbers).
237
+ sli (tuple): A tuple representing the start and end indices of the slice to be removed.
238
+
239
+ Returns:
240
+ str: The removed slice of text as a single string.
241
+ """
242
+ base = clean_text[sli[0]]
243
+ section = list(sli)
244
+ check_start_flag = False
245
+ # backwards pass
246
+ for i in range(max(0, sli[0] - 1), max(0, sli[0] - 5), -1):
247
+ if not lines[i]:
248
+ continue
249
+ if lines[i] == "## References":
250
+ section[0] = i
251
+ break
252
+ elif ratio(base, remove_numbers(lines[i])) < 0.9:
253
+ section[0] = i + 1
254
+ potential_ref = remove_numbers(lines[max(0, i - 1)].partition("* [")[-1])
255
+ if (
256
+ len(potential_ref) >= 0.75 * len(base)
257
+ and ratio(base, potential_ref) < 0.9
258
+ ):
259
+ section[0] = i
260
+ check_start_flag = True
261
+ break
262
+ # forward pass
263
+ for i in range(min(len(lines), sli[1]), min(len(lines), sli[1] + 5)):
264
+ if ratio(base, remove_numbers(lines[i])) < 0.9:
265
+ section[1] = i
266
+ break
267
+ if len(lines) <= section[1]:
268
+ section[1] = len(lines) - 1
269
+ to_delete = "\n".join(lines[section[0] : section[1] + 1])
270
+ # cut off next page content
271
+ itera, iterb = enumerate(lines[section[1] - 1]), enumerate(lines[section[1]])
272
+ while True:
273
+ try:
274
+ (ia, a) = next(itera)
275
+ while a.isnumeric():
276
+ (ia, a) = next(itera)
277
+ (ib, b) = next(iterb)
278
+ while b.isnumeric():
279
+ (ib, b) = next(iterb)
280
+ if a != b:
281
+ break
282
+ except StopIteration:
283
+ break
284
+ if check_start_flag and "* [" in to_delete:
285
+ to_delete = "* [" + to_delete.partition("* [")[-1]
286
+ try:
287
+ delta = len(lines[section[1]]) - ib - 1
288
+ if delta > 0:
289
+ to_delete = to_delete[:-delta]
290
+ except UnboundLocalError:
291
+ pass
292
+
293
+ return to_delete.strip()
294
+
295
+
296
+ def remove_hallucinated_references(text: str) -> str:
297
+ """
298
+ Remove hallucinated or missing references from the text.
299
+
300
+ This function identifies and removes references that are marked as missing or hallucinated
301
+ from the input text.
302
+
303
+ Args:
304
+ text (str): The input text containing references.
305
+
306
+ Returns:
307
+ str: The text with hallucinated references removed.
308
+ """
309
+ lines = text.split("\n")
310
+ if len(lines) == 0:
311
+ return ""
312
+ clean_lines = remove_numbers(lines)
313
+ slices = get_slices(lines, clean_lines)
314
+ to_delete = []
315
+ for sli in slices:
316
+ to_delete.append(remove_slice_from_lines(lines, clean_lines, sli))
317
+ for to_delete in reversed(to_delete):
318
+ text = text.replace(to_delete, "\n\n[MISSING_PAGE_POST]\n\n")
319
+ text = re.sub(
320
+ r"## References\n+\[MISSING_PAGE_POST(:\d+)?\]",
321
+ "\n\n[MISSING_PAGE_POST\\1]",
322
+ text,
323
+ )
324
+ return text
325
+
326
+
327
+ def postprocess_single(generation: str, markdown_fix: bool = True) -> str:
328
+ """
329
+ Postprocess a single generated text.
330
+
331
+ Args:
332
+ generation (str): The generated text to be postprocessed.
333
+ markdown_fix (bool, optional): Whether to perform Markdown formatting fixes. Default is True.
334
+
335
+ Returns:
336
+ str: The postprocessed text.
337
+ """
338
+ generation = re.sub(
339
+ r"(?:\n|^)#+ \d*\W? ?(.{100,})", r"\n\1", generation
340
+ ) # overly long section titles are most likely not real titles
341
+ generation = generation.strip()
342
+ generation = generation.replace("\n* [leftmargin=*]\n", "\n")
343
+ generation = re.sub(
344
+ r"^#+ (?:\.?(?:\d|[ixv])+)*\s*(?:$|\n\s*)", "", generation, flags=re.M
345
+ )
346
+ # most likely hallucinated titles
347
+ lines = generation.split("\n")
348
+ if (
349
+ lines[-1].startswith("#")
350
+ and lines[-1].lstrip("#").startswith(" ")
351
+ and len(lines) > 1
352
+ ):
353
+ print("INFO: likely hallucinated title at the end of the page: " + lines[-1])
354
+ generation = "\n".join(lines[:-1])
355
+ # obvious repetition detection
356
+ generation = truncate_repetitions(generation)
357
+ # Reference corrections
358
+ generation = remove_hallucinated_references(generation)
359
+ generation = re.sub(
360
+ r"^\* \[\d+\](\s?[A-W]\.+\s?){10,}.*$", "", generation, flags=re.M
361
+ )
362
+ generation = re.sub(r"^(\* \[\d+\])\[\](.*)$", r"\1\2", generation, flags=re.M)
363
+ generation = re.sub(r"(^\w\n\n|\n\n\w$)", "", generation)
364
+ # pmc math artifact correction
365
+ generation = re.sub(
366
+ r"([\s.,()])_([a-zA-Z0-9])__([a-zA-Z0-9]){1,3}_([\s.,:()])",
367
+ r"\1\(\2_{\3}\)\4",
368
+ generation,
369
+ )
370
+ generation = re.sub(
371
+ r"([\s.,\d])_([a-zA-Z0-9])_([\s.,\d;])", r"\1\(\2\)\3", generation
372
+ )
373
+ # footnote mistakes
374
+ generation = re.sub(
375
+ r"(\nFootnote .*?:) (?:footnotetext|thanks):\W*(.*(?:\n\n|$))",
376
+ r"\1 \2",
377
+ generation,
378
+ )
379
+ # TODO Come up with footnote formatting inside a table
380
+ generation = re.sub(r"\[FOOTNOTE:.+?\](.*?)\[ENDFOOTNOTE\]", "", generation)
381
+ # itemize post processing
382
+ for match in reversed(
383
+ list(
384
+ re.finditer(
385
+ r"(?:^)(-|\*)?(?!-|\*) ?((?:\d|[ixv])+ )?.+? (-|\*) (((?:\d|[ixv])+)\.(\d|[ixv]) )?.*(?:$)",
386
+ generation,
387
+ flags=re.I | re.M,
388
+ )
389
+ )
390
+ ):
391
+ start, stop = match.span()
392
+ delim = match.group(3) + " "
393
+ splits = match.group(0).split(delim)
394
+ replacement = ""
395
+ if match.group(1) is not None:
396
+ splits = splits[1:]
397
+ delim1 = match.group(1) + " "
398
+ else:
399
+ delim1 = ""
400
+ # too many false positives
401
+ continue
402
+ pre, post = generation[:start], generation[stop:]
403
+ for i, item in enumerate(splits):
404
+ level = 0
405
+ potential_numeral, _, rest = item.strip().partition(" ")
406
+ if not rest:
407
+ continue
408
+ if re.match(
409
+ r"^[\dixv]+((?:\.[\dixv])?)+$", potential_numeral, flags=re.I | re.M
410
+ ):
411
+ level = potential_numeral.count(".")
412
+
413
+ replacement += (
414
+ ("\n" if i > 0 else "")
415
+ + ("\t" * level)
416
+ + (delim if i > 0 or start == 0 else delim1)
417
+ + item.strip()
418
+ )
419
+ if post == "":
420
+ post = "\n"
421
+ generation = pre + replacement + post
422
+
423
+ if generation.endswith((".", "}")):
424
+ generation += "\n\n"
425
+ if re.match(r"[A-Z0-9,;:]$", generation):
426
+ # add a space in case there is a trailing comma or word ending
427
+ generation += " "
428
+ elif generation.startswith(("#", "**", "\\begin")):
429
+ generation = "\n\n" + generation
430
+ elif generation.split("\n")[-1].startswith(("#", "Figure", "Table")):
431
+ generation = generation + "\n\n"
432
+ else:
433
+ try:
434
+ last_word = generation.split(" ")[-1]
435
+ if last_word in words.words():
436
+ generation += " "
437
+ except LookupError:
438
+ # add space just in case. Will split words but better than concatenating them
439
+ generation += " "
440
+ # download for the next time
441
+ import nltk
442
+
443
+ nltk.download("words")
444
+ # table corrections
445
+ # remove obvious wrong tables
446
+ for l in generation.split("\n"):
447
+ if (
448
+ l.count("\\begin{tabular}") > 15
449
+ or l.count("\\multicolumn") > 60
450
+ or l.count("&") > 400
451
+ ):
452
+ generation = generation.replace(l, "")
453
+ # whitespace corrections
454
+ generation = generation.replace(
455
+ "\\begin{table} \\begin{tabular}", "\\begin{table}\n\\begin{tabular}"
456
+ )
457
+ generation = generation.replace(
458
+ "\\end{tabular} \\end{table}", "\\end{tabular}\n\\end{table}"
459
+ )
460
+ generation = generation.replace("\\end{table} Tab", "\\end{table}\nTab")
461
+ generation = re.sub(r"(^.+)\\begin{tab", r"\1\n\\begin{tab", generation, flags=re.M)
462
+
463
+ generation = generation.replace(
464
+ r"\begin{tabular}{l l} & \\ \end{tabular}", ""
465
+ ).replace("\\begin{tabular}{}\n\n\\end{tabular}", "")
466
+ generation = generation.replace("\\begin{array}[]{", "\\begin{array}{")
467
+ generation = re.sub(
468
+ r"\\begin{tabular}{([clr ]){2,}}\s*[& ]*\s*(\\\\)? \\end{tabular}",
469
+ "",
470
+ generation,
471
+ )
472
+ generation = re.sub(r"(\*\*S\. A\. B\.\*\*\n+){2,}", "", generation)
473
+ generation = re.sub(r"^#+( [\[\d\w])?$", "", generation, flags=re.M)
474
+ generation = re.sub(r"^\.\s*$", "", generation, flags=re.M)
475
+ generation = re.sub(r"\n{3,}", "\n\n", generation)
476
+ if markdown_fix:
477
+ return markdown_compatible(generation)
478
+ else:
479
+ return generation
480
+
481
+
482
+ def postprocess(
483
+ generation: Union[str, List[str]], markdown_fix: bool = True
484
+ ) -> Union[str, List[str]]:
485
+ """
486
+ Postprocess generated text or a list of generated texts.
487
+
488
+ This function can be used to perform postprocessing on generated text, such as fixing Markdown formatting.
489
+
490
+ Args:
491
+ generation (Union[str, List[str]]): The generated text or a list of generated texts.
492
+ markdown_fix (bool, optional): Whether to perform Markdown formatting fixes. Default is True.
493
+
494
+ Returns:
495
+ Union[str, List[str]]: The postprocessed text or list of postprocessed texts.
496
+ """
497
+ if type(generation) == list:
498
+ if os.environ.get("NOUGAT_MULTIPROCESSING"):
499
+ with Pool(int(os.environ.get("NOUGAT_MULTIPROCESSING"))) as p:
500
+ return p.map(
501
+ partial(postprocess_single, markdown_fix=markdown_fix), generation
502
+ )
503
+ else:
504
+ return [
505
+ postprocess_single(s, markdown_fix=markdown_fix) for s in generation
506
+ ]
507
+ else:
508
+ return postprocess_single(generation, markdown_fix=markdown_fix)
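A short sketch of calling `postprocess` on a list of page predictions; the `NOUGAT_MULTIPROCESSING` value is an arbitrary worker count chosen for illustration:

```python
import os
from nougat.postprocessing import postprocess

if __name__ == "__main__":
    os.environ["NOUGAT_MULTIPROCESSING"] = "4"  # optional: number of worker processes
    pages = [
        "# 1 Introduction\nSome generated text.",
        r"(2.1) \[a^{2}+b^{2}=c^{2}\]",
    ]
    cleaned = postprocess(pages, markdown_fix=True)
```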
nougat/transforms.py ADDED
@@ -0,0 +1,173 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ # Implements image augmentation
8
+
9
+ import albumentations as alb
10
+ from albumentations.pytorch import ToTensorV2
11
+ import cv2
12
+ import numpy as np
13
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
14
+
15
+
16
+ def alb_wrapper(transform):
17
+ def f(im):
18
+ return transform(image=np.asarray(im))["image"]
19
+
20
+ return f
21
+
22
+
23
+ class Erosion(alb.ImageOnlyTransform):
24
+ """
25
+ Apply erosion operation to an image.
26
+
27
+ Erosion is a morphological operation that shrinks the white regions in a binary image.
28
+
29
+ Args:
30
+ scale (int or tuple/list of int): The scale or range for the size of the erosion kernel.
31
+ If an integer is provided, a square kernel of that size will be used.
32
+ If a tuple or list is provided, it should contain two integers representing the minimum
33
+ and maximum sizes for the erosion kernel.
34
+ always_apply (bool, optional): Whether to always apply this transformation. Default is False.
35
+ p (float, optional): The probability of applying this transformation. Default is 0.5.
36
+
37
+ Returns:
38
+ numpy.ndarray: The transformed image.
39
+ """
40
+
41
+ def __init__(self, scale, always_apply=False, p=0.5):
42
+ super().__init__(always_apply=always_apply, p=p)
43
+ if type(scale) is tuple or type(scale) is list:
44
+ assert len(scale) == 2
45
+ self.scale = scale
46
+ else:
47
+ self.scale = (scale, scale)
48
+
49
+ def apply(self, img, **params):
50
+ kernel = cv2.getStructuringElement(
51
+ cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
52
+ )
53
+ img = cv2.erode(img, kernel, iterations=1)
54
+ return img
55
+
56
+
57
+ class Dilation(alb.ImageOnlyTransform):
58
+ """
59
+ Apply dilation operation to an image.
60
+
61
+ Dilation is a morphological operation that expands the white regions in a binary image.
62
+
63
+ Args:
64
+ scale (int or tuple/list of int): The scale or range for the size of the dilation kernel.
65
+ If an integer is provided, a square kernel of that size will be used.
66
+ If a tuple or list is provided, it should contain two integers representing the minimum
67
+ and maximum sizes for the dilation kernel.
68
+ always_apply (bool, optional): Whether to always apply this transformation. Default is False.
69
+ p (float, optional): The probability of applying this transformation. Default is 0.5.
70
+
71
+ Returns:
72
+ numpy.ndarray: The transformed image.
73
+ """
74
+
75
+ def __init__(self, scale, always_apply=False, p=0.5):
76
+ super().__init__(always_apply=always_apply, p=p)
77
+ if type(scale) is tuple or type(scale) is list:
78
+ assert len(scale) == 2
79
+ self.scale = scale
80
+ else:
81
+ self.scale = (scale, scale)
82
+
83
+ def apply(self, img, **params):
84
+ kernel = cv2.getStructuringElement(
85
+ cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
86
+ )
87
+ img = cv2.dilate(img, kernel, iterations=1)
88
+ return img
89
+
90
+
91
+ class Bitmap(alb.ImageOnlyTransform):
92
+ """
93
+ Apply a bitmap-style transformation to an image.
94
+
95
+ This transformation replaces all pixel values below a certain threshold with a specified value.
96
+
97
+ Args:
98
+ value (int, optional): The value to replace pixels below the threshold with. Default is 0.
99
+ lower (int, optional): The threshold value below which pixels will be replaced. Default is 200.
100
+ always_apply (bool, optional): Whether to always apply this transformation. Default is False.
101
+ p (float, optional): The probability of applying this transformation. Default is 0.5.
102
+
103
+ Returns:
104
+ numpy.ndarray: The transformed image.
105
+ """
106
+
107
+ def __init__(self, value=0, lower=200, always_apply=False, p=0.5):
108
+ super().__init__(always_apply=always_apply, p=p)
109
+ self.lower = lower
110
+ self.value = value
111
+
112
+ def apply(self, img, **params):
113
+ img = img.copy()
114
+ img[img < self.lower] = self.value
115
+ return img
116
+
117
+
118
+ train_transform = alb_wrapper(
119
+ alb.Compose(
120
+ [
121
+ Bitmap(p=0.05),
122
+ alb.OneOf([Erosion((2, 3)), Dilation((2, 3))], p=0.02),
123
+ alb.Affine(shear={"x": (0, 3), "y": (-3, 0)}, cval=(255, 255, 255), p=0.03),
124
+ alb.ShiftScaleRotate(
125
+ shift_limit_x=(0, 0.04),
126
+ shift_limit_y=(0, 0.03),
127
+ scale_limit=(-0.15, 0.03),
128
+ rotate_limit=2,
129
+ border_mode=0,
130
+ interpolation=2,
131
+ value=(255, 255, 255),
132
+ p=0.03,
133
+ ),
134
+ alb.GridDistortion(
135
+ distort_limit=0.05,
136
+ border_mode=0,
137
+ interpolation=2,
138
+ value=(255, 255, 255),
139
+ p=0.04,
140
+ ),
141
+ alb.Compose(
142
+ [
143
+ alb.Affine(
144
+ translate_px=(0, 5), always_apply=True, cval=(255, 255, 255)
145
+ ),
146
+ alb.ElasticTransform(
147
+ p=1,
148
+ alpha=50,
149
+ sigma=120 * 0.1,
150
+ alpha_affine=120 * 0.01,
151
+ border_mode=0,
152
+ value=(255, 255, 255),
153
+ ),
154
+ ],
155
+ p=0.04,
156
+ ),
157
+ alb.RandomBrightnessContrast(0.1, 0.1, True, p=0.03),
158
+ alb.ImageCompression(95, p=0.07),
159
+ alb.GaussNoise(20, p=0.08),
160
+ alb.GaussianBlur((3, 3), p=0.03),
161
+ alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
162
+ ToTensorV2(),
163
+ ]
164
+ )
165
+ )
166
+ test_transform = alb_wrapper(
167
+ alb.Compose(
168
+ [
169
+ alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
170
+ ToTensorV2(),
171
+ ]
172
+ )
173
+ )
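A usage sketch for the transforms above; the blank image is a stand-in for a rasterized page:

```python
import numpy as np
from PIL import Image
from nougat.transforms import train_transform, test_transform

page = Image.fromarray(np.full((896, 672, 3), 255, dtype=np.uint8))  # stand-in page

augmented = train_transform(page)    # random augmentations + normalization + tensor
clean = test_transform(page)         # normalization + tensor only
print(augmented.shape, clean.shape)  # torch.Size([3, 896, 672]) for both
```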
nougat/utils/__init__.py ADDED
File without changes
nougat/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes).
 
nougat/utils/__pycache__/checkpoint.cpython-310.pyc ADDED
Binary file (3.84 kB).
 
nougat/utils/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (8.96 kB).
 
nougat/utils/checkpoint.py ADDED
@@ -0,0 +1,119 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ from typing import Optional
8
+ import requests
9
+ import os
10
+ import tqdm
11
+ import io
12
+ from pathlib import Path
13
+ import torch
14
+
15
+ BASE_URL = "https://github.com/facebookresearch/nougat/releases/download"
16
+ MODEL_TAG = "0.1.0-small"
17
+
18
+
19
+ # source: https://stackoverflow.com/a/71459251
20
+ def download_as_bytes_with_progress(url: str, name: str = None) -> bytes:
21
+ """
22
+ Download a file from a URL and return the contents as bytes, with progress bar.
23
+
24
+ Args:
25
+ url: The URL of the file to download.
26
+ name: Label shown in the progress bar. If None, the URL is used.
27
+
28
+ Returns:
29
+ bytes: The contents of the file.
30
+ """
31
+ resp = requests.get(url, stream=True, allow_redirects=True)
32
+ total = int(resp.headers.get("content-length", 0))
33
+ bio = io.BytesIO()
34
+ if name is None:
35
+ name = url
36
+ with tqdm.tqdm(
37
+ desc=name,
38
+ total=total,
39
+ unit="b",
40
+ unit_scale=True,
41
+ unit_divisor=1024,
42
+ ) as bar:
43
+ for chunk in resp.iter_content(chunk_size=65536):
44
+ bar.update(len(chunk))
45
+ bio.write(chunk)
46
+ return bio.getvalue()
47
+
48
+
49
+ def download_checkpoint(checkpoint: Path, model_tag: str = MODEL_TAG):
50
+ """
51
+ Download the Nougat model checkpoint.
52
+
53
+ This function downloads the Nougat model checkpoint from GitHub.
54
+
55
+ Args:
56
+ checkpoint (Path): The path to the checkpoint.
57
+ model_tag (str): The model tag to download. Default is "0.1.0-small".
58
+ """
59
+ print("downloading nougat checkpoint version", model_tag, "to path", checkpoint)
60
+ files = [
61
+ "config.json",
62
+ "pytorch_model.bin",
63
+ "special_tokens_map.json",
64
+ "tokenizer.json",
65
+ "tokenizer_config.json",
66
+ ]
67
+ for file in files:
68
+ download_url = f"{BASE_URL}/{model_tag}/{file}"
69
+ binary_file = download_as_bytes_with_progress(download_url, file)
70
+ if len(binary_file) > 15: # sanity check
71
+ (checkpoint / file).write_bytes(binary_file)
72
+
73
+
74
+ def torch_hub(model_tag: Optional[str] = MODEL_TAG) -> Path:
75
+ old_path = Path(torch.hub.get_dir() + "/nougat")
76
+ if model_tag is None:
77
+ model_tag = MODEL_TAG
78
+ hub_path = old_path.with_name(f"nougat-{model_tag}")
79
+ if old_path.exists():
80
+ # move to new format
81
+ old_path.rename(old_path.with_name("nougat-0.1.0-small"))
82
+ return hub_path
83
+
84
+
85
+ def get_checkpoint(
86
+ checkpoint_path: Optional[os.PathLike] = None,
87
+ model_tag: str = MODEL_TAG,
88
+ download: bool = True,
89
+ ) -> Path:
90
+ """
91
+ Get the path to the Nougat model checkpoint.
92
+
93
+ This function retrieves the path to the Nougat model checkpoint. If the checkpoint does not
94
+ exist or is empty, it can optionally download the checkpoint.
95
+
96
+ Args:
97
+ checkpoint_path (Optional[os.PathLike]): The path to the checkpoint. If not provided,
98
+ it will check the "NOUGAT_CHECKPOINT" environment variable or use the default location.
99
+ Default is None.
100
+ model_tag (str): The model tag to download. Default is "0.1.0-small".
101
+ download (bool): Whether to download the checkpoint if it doesn't exist or is empty.
102
+ Default is True.
103
+
104
+ Returns:
105
+ Path: The path to the Nougat model checkpoint.
106
+ """
107
+ checkpoint = Path(
108
+ checkpoint_path or os.environ.get("NOUGAT_CHECKPOINT", torch_hub(model_tag))
109
+ )
110
+ if checkpoint.exists() and checkpoint.is_file():
111
+ checkpoint = checkpoint.parent
112
+ if download and (not checkpoint.exists() or len(os.listdir(checkpoint)) < 5):
113
+ checkpoint.mkdir(parents=True, exist_ok=True)
114
+ download_checkpoint(checkpoint, model_tag=model_tag or MODEL_TAG)
115
+ return checkpoint
116
+
117
+
118
+ if __name__ == "__main__":
119
+ get_checkpoint()
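A sketch of resolving the checkpoint; the explicit local path is a placeholder:

```python
from nougat.utils.checkpoint import get_checkpoint

# Download (if necessary) and return the default small checkpoint directory;
# the NOUGAT_CHECKPOINT environment variable takes precedence if set.
checkpoint = get_checkpoint()

# Or point at an existing local directory without downloading.
local = get_checkpoint("path/to/checkpoint", download=False)
```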
nougat/utils/dataset.py ADDED
@@ -0,0 +1,280 @@
1
+ """
2
+ Donut
3
+ Copyright (c) 2022-present NAVER Corp.
4
+ MIT License
5
+ Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ """
7
+ import logging
8
+ import os
9
+ from math import prod
10
+ from pathlib import Path
11
+ from functools import partial
12
+ import random
13
+ from typing import Dict, Tuple, Callable
14
+ from PIL import Image, UnidentifiedImageError
15
+ from typing import List, Optional
16
+
17
+ import torch
18
+ import pypdf
19
+ import orjson
20
+ from torch.utils.data import Dataset
21
+ from transformers.modeling_utils import PreTrainedModel
22
+ from nougat.dataset.rasterize import rasterize_paper
23
+
24
+
25
+ class ImageDataset(torch.utils.data.Dataset):
26
+ """
27
+ Dataset for processing a list of images using a preparation function.
28
+
29
+ This dataset takes a list of image paths and applies a preparation function to each image.
30
+
31
+ Args:
32
+ img_list (list): List of image paths.
33
+ prepare (Callable): A preparation function to process the images.
34
+
35
+ Attributes:
36
+ img_list (list): List of image paths.
37
+ prepare (Callable): The preparation function.
38
+ """
39
+
40
+ def __init__(self, img_list, prepare: Callable):
41
+ super().__init__()
42
+ self.img_list = img_list
43
+ self.prepare = prepare
44
+
45
+ def __len__(self):
46
+ return len(self.img_list)
47
+
48
+ @staticmethod
49
+ def ignore_none_collate(batch):
50
+ if batch is None:
51
+ return
52
+ try:
53
+ batch = [x for x in batch if x is not None and x[0] is not None]
54
+ if len(batch) == 0:
55
+ return
56
+ return torch.utils.data.dataloader.default_collate(batch)
57
+ except AttributeError:
58
+ pass
59
+
60
+ def __getitem__(self, idx):
61
+ try:
62
+ img = Image.open(self.img_list[idx])
63
+ return self.prepare(img)
64
+ except Exception as e:
65
+ logging.error(e)
66
+
67
+
68
+ class LazyDataset(Dataset):
69
+ """
70
+ Lazy loading dataset for processing PDF documents.
71
+
72
+ This dataset allows lazy loading of PDF documents and provides access to processed images
73
+ using a specified preparation function.
74
+
75
+ Args:
76
+ pdf (str): Path to the PDF document.
77
+ prepare (Callable): A preparation function to process the images.
78
+
79
+ Attributes:
80
+ name (str): Name of the PDF document.
81
+ """
82
+
83
+ def __init__(self, pdf, prepare: Callable, pages: Optional[List[int]] = None):
84
+ super().__init__()
85
+ self.prepare = prepare
86
+ self.name = str(pdf)
87
+ self.init_fn = partial(rasterize_paper, pdf, pages=pages)
88
+ self.dataset = None
89
+ self.size = len(pypdf.PdfReader(pdf).pages) if pages is None else len(pages)
90
+
91
+ def __len__(self):
92
+ return self.size
93
+
94
+ def __getitem__(self, i):
95
+ if i == 0 or self.dataset is None:
96
+ self.dataset = ImageDataset(self.init_fn(), self.prepare)
97
+ if i <= self.size and i >= 0:
98
+ return self.dataset[i], self.name if i == self.size - 1 else ""
99
+ else:
100
+ raise IndexError
101
+
102
+ @staticmethod
103
+ def ignore_none_collate(batch):
104
+ if batch is None:
105
+ return None, None
106
+ try:
107
+ _batch = []
108
+ for i, x in enumerate(batch):
109
+ image, name = x
110
+ if image is not None:
111
+ _batch.append(x)
112
+ elif name:
113
+ if i > 0:
114
+ _batch[-1] = (_batch[-1][0], name)
115
+ elif len(batch) > 1:
116
+ _batch.append((batch[1][0] * 0, name))
117
+ if len(_batch) == 0:
118
+ return None, None
119
+ return torch.utils.data.dataloader.default_collate(_batch)
120
+ except AttributeError:
121
+ pass
122
+ return None, None
123
+
124
+
125
+ class SciPDFDataset(Dataset):
126
+ """
127
+ Custom dataset for scientific PDF data.
128
+
129
+ This dataset loads data from JSONL files and provides access to images, ground truth,
130
+ and metadata.
131
+
132
+ Args:
133
+ path_to_index (str): Path to the index file.
134
+ split (str, optional): Split of the dataset (e.g., "train", "test"). Default is "train".
135
+ root_name (str, optional): Root directory name. Default is an empty string.
136
+ template (str, optional): Template for split naming. Default is "%s".
137
+
138
+ Attributes:
139
+ empty_sample: Placeholder for empty samples.
140
+ """
141
+
142
+ empty_sample = None
143
+
144
+ def __init__(
145
+ self,
146
+ path_to_index: str,
147
+ split: str = "train",
148
+ root_name="",
149
+ template="%s",
150
+ ) -> None:
151
+ super().__init__()
152
+ self.path_to_index = Path(path_to_index)
153
+ self.root_name = root_name
154
+ self.path_to_root = self.path_to_index.parent
155
+ if not split in self.path_to_index.stem:
156
+ pti = self.path_to_root / (template % split + ".jsonl")
157
+ if pti.exists():
158
+ self.path_to_index = pti
159
+ else:
160
+ raise ValueError(f'Dataset file for split "{split}" not found: {pti}')
161
+ self.dataset_file = None # multiprocessing: the file handle is opened lazily in __getitem__
162
+ # load seek map
163
+ seek_path = self.path_to_root / (self.path_to_index.stem + ".seek.map")
164
+ if seek_path.exists():
165
+ self.seek_map = orjson.loads(seek_path.open().read())
166
+ else:
167
+ raise ValueError(
168
+ 'No "%s" found in %s' % (seek_path.name, str(self.path_to_root))
169
+ )
170
+ self.dataset_length = len(self.seek_map)
171
+
172
+ def __len__(self) -> int:
173
+ return self.dataset_length
174
+
175
+ def __getitem__(self, index: int) -> Dict:
176
+ position = self.seek_map[index]
177
+ if self.dataset_file is None:
178
+ self.dataset_file = self.path_to_index.open()
179
+ self.dataset_file.seek(position)
180
+ line = self.dataset_file.readline()
181
+ try:
182
+ data: Dict = orjson.loads(line)
183
+ except Exception as e:
184
+ logging.info(
185
+ "JSONL for sample %i could not be loaded at position %i: %s\n%s",
186
+ index,
187
+ position,
188
+ str(e),
189
+ line,
190
+ )
191
+ return self.empty_sample
192
+ img_path: Path = self.path_to_root / self.root_name / data.pop("image")
193
+ if not img_path.exists():
194
+ logging.info("Sample %s could not be found.", img_path)
195
+ return self.empty_sample
196
+ try:
197
+ img = Image.open(img_path)
198
+ except UnidentifiedImageError:
199
+ logging.info("Image %s could not be opened.", img_path)
200
+ return self.empty_sample
201
+ return {"image": img, "ground_truth": data.pop("markdown"), "meta": data}
202
+
203
+ def __iter__(self):
204
+ for i in range(self.dataset_length):
205
+ yield self[i]
206
+
207
+
208
+ class NougatDataset(Dataset):
209
+ """
210
+ Args:
211
+ dataset_path: the path to the jsonl file
212
+ """
213
+
214
+ def __init__(
215
+ self,
216
+ dataset_path: str,
217
+ nougat_model: PreTrainedModel,
218
+ max_length: int,
219
+ split: str = "train",
220
+ root_name: str = "arxiv",
221
+ ):
222
+ super().__init__()
223
+ self.nougat_model = nougat_model
224
+ self.max_length = max_length
225
+ self.split = split
226
+ self.perturb = "NOUGAT_PERTURB" in os.environ and os.environ["NOUGAT_PERTURB"]
227
+ # TODO improve naming conventions
228
+ template = "%s"
229
+ self.dataset = SciPDFDataset(
230
+ dataset_path, split=self.split, template=template, root_name=root_name
231
+ )
232
+ self.dataset_length = len(self.dataset)
233
+
234
+ def __len__(self) -> int:
235
+ return self.dataset_length
236
+
237
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
238
+ """
239
+ Load image from image_path of given dataset_path and convert into input_tensor and labels.
240
+ Convert gt data into input_ids (tokenized string)
241
+
242
+ Returns:
243
+ input_tensor : preprocessed image
244
+ input_ids : tokenized gt_data
245
+ """
246
+ sample = self.dataset[idx]
247
+ if sample is None:
248
+ # if sample is broken choose another randomly
249
+ return self[random.randint(0, self.dataset_length - 1)]
250
+ if sample is None or sample["image"] is None or prod(sample["image"].size) == 0:
251
+ input_tensor = None
252
+ else:
253
+ input_tensor = self.nougat_model.encoder.prepare_input(
254
+ sample["image"], random_padding=self.split == "train"
255
+ )
256
+
257
+ tokenizer_out = self.nougat_model.decoder.tokenizer(
258
+ sample["ground_truth"],
259
+ max_length=self.max_length,
260
+ padding="max_length",
261
+ return_token_type_ids=False,
262
+ truncation=True,
263
+ return_tensors="pt",
264
+ )
265
+ input_ids = tokenizer_out["input_ids"].squeeze(0)
266
+ attention_mask = tokenizer_out["attention_mask"].squeeze(0)
267
+ # randomly perturb ground truth tokens
268
+ if self.split == "train" and self.perturb:
269
+ # check if we perturb tokens
270
+ unpadded_length = attention_mask.sum()
271
+ while random.random() < 0.1:
272
+ try:
273
+ pos = random.randint(1, unpadded_length - 2)
274
+ token = random.randint(
275
+ 23, len(self.nougat_model.decoder.tokenizer) - 1
276
+ )
277
+ input_ids[pos] = token
278
+ except ValueError:
279
+ break
280
+ return input_tensor, input_ids, attention_mask
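A sketch of wiring `LazyDataset` into a `DataLoader`, mirroring `predict.py` below; the PDF path is a placeholder:

```python
from functools import partial
import torch
from nougat import NougatModel
from nougat.utils.checkpoint import get_checkpoint
from nougat.utils.dataset import LazyDataset

model = NougatModel.from_pretrained(get_checkpoint()).eval()
dataset = LazyDataset(
    "paper.pdf",  # placeholder input PDF
    partial(model.encoder.prepare_input, random_padding=False),
)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=LazyDataset.ignore_none_collate,
)
for images, is_last_page in loader:
    # is_last_page[j] holds the PDF path on the document's final page, "" otherwise.
    pass
```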
nougat/utils/device.py ADDED
@@ -0,0 +1,38 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import torch
8
+ import logging
9
+
10
+
11
+ def default_batch_size():
12
+ if torch.cuda.is_available():
13
+ batch_size = int(
14
+ torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1000 * 0.3
15
+ )
16
+ if batch_size == 0:
17
+ logging.warning("GPU VRAM is too small. Computing on CPU.")
18
+ elif torch.backends.mps.is_available():
19
+ # I don't know if there's an equivalent API so heuristically choosing bs=4
20
+ batch_size = 4
21
+ else:
22
+ # don't know what a good value is here. Would not recommend to run on CPU
23
+ batch_size = 1
24
+ logging.warning("No GPU found. Conversion on CPU is very slow.")
25
+ return batch_size
26
+
27
+
28
+ def move_to_device(model, bf16: bool = True, cuda: bool = True):
29
+ try:
30
+ if torch.backends.mps.is_available():
31
+ return model.to("mps")
32
+ except AttributeError:
33
+ pass
34
+ if bf16:
35
+ model = model.to(torch.bfloat16)
36
+ if cuda and torch.cuda.is_available():
37
+ model = model.to("cuda")
38
+ return model
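A sketch of the device helpers in use:

```python
from nougat import NougatModel
from nougat.utils.checkpoint import get_checkpoint
from nougat.utils.device import default_batch_size, move_to_device

model = NougatModel.from_pretrained(get_checkpoint())
model = move_to_device(model, bf16=True, cuda=True)  # prefers MPS, else bf16 + CUDA
batch_size = max(1, default_batch_size())  # heuristic; 0 means the GPU VRAM is too small
```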
predict.py ADDED
@@ -0,0 +1,172 @@
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+
4
+ This source code is licensed under the MIT license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+ import sys
8
+ from pathlib import Path
9
+ import logging
10
+ import re
11
+ import argparse
12
+ import re
13
+ from functools import partial
14
+ import torch
15
+ from torch.utils.data import ConcatDataset
16
+ from tqdm import tqdm
17
+ from nougat import NougatModel
18
+ from nougat.utils.dataset import LazyDataset
19
+ from nougat.utils.checkpoint import get_checkpoint
20
+ from nougat.postprocessing import markdown_compatible
21
+ import fitz
22
+
23
+ logging.basicConfig(level=logging.INFO)
24
+
25
+ if torch.cuda.is_available():
26
+ BATCH_SIZE = int(
27
+ torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1000 * 0.3
28
+ )
29
+ if BATCH_SIZE == 0:
30
+ logging.warning("GPU VRAM is too small. Computing on CPU.")
31
+ else:
32
+ # don't know what a good value is here. Would not recommend to run on CPU
33
+ BATCH_SIZE = 1
34
+ logging.warning("No GPU found. Conversion on CPU is very slow.")
35
+
36
+
37
+ def get_args():
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument(
40
+ "--batchsize",
41
+ "-b",
42
+ type=int,
43
+ default=BATCH_SIZE,
44
+ help="Batch size to use.",
45
+ )
46
+ parser.add_argument(
47
+ "--checkpoint",
48
+ "-c",
49
+ type=Path,
50
+ default=None,
51
+ help="Path to checkpoint directory.",
52
+ )
53
+ parser.add_argument("--out", "-o", type=Path, help="Output directory.")
54
+ parser.add_argument(
55
+ "--recompute",
56
+ action="store_true",
57
+ help="Recompute already computed PDF, discarding previous predictions.",
58
+ )
59
+ parser.add_argument(
60
+ "--markdown",
61
+ action="store_true",
62
+ help="Add postprocessing step for markdown compatibility.",
63
+ )
64
+ parser.add_argument("pdf", nargs="+", type=Path, help="PDF(s) to process.")
65
+ args = parser.parse_args()
66
+ if args.checkpoint is None or not args.checkpoint.exists():
67
+ args.checkpoint = get_checkpoint(args.checkpoint)
68
+ if args.out is None:
69
+ logging.warning("No output directory. Output will be printed to console.")
70
+ else:
71
+ if not args.out.exists():
72
+ logging.info("Output directory does not exist. Creating output directory.")
73
+ args.out.mkdir(parents=True)
74
+ if not args.out.is_dir():
75
+ logging.error("Output has to be directory.")
76
+ sys.exit(1)
77
+ if len(args.pdf) == 1 and not args.pdf[0].suffix == ".pdf":
78
+ # input is a list of pdfs
79
+ try:
80
+ args.pdf = [
81
+ Path(l) for l in open(args.pdf[0]).read().split("\n") if len(l) > 0
82
+ ]
83
+ except:
84
+ pass
85
+ return args
86
+
87
+
88
+ def main():
89
+ args = get_args()
90
+ model = NougatModel.from_pretrained(args.checkpoint).to(torch.bfloat16)
91
+ if args.batchsize > 0:
92
+ if torch.cuda.is_available():
93
+ model.to("cuda")
94
+ else:
95
+ # set batch size to 1. Need to check if there are benefits for CPU conversion for >1
96
+ args.batchsize = 1
97
+ model.eval()
98
+ datasets = []
99
+ for pdf in args.pdf:
100
+ if not pdf.exists():
101
+ continue
102
+ if args.out:
103
+ out_path = args.out / pdf.with_suffix(".mmd").name
104
+ if out_path.exists() and not args.recompute:
105
+ logging.info(
106
+ f"Skipping {pdf.name}, already computed. Run with --recompute to convert again."
107
+ )
108
+ continue
109
+ try:
110
+ dataset = LazyDataset(
111
+ pdf, partial(model.encoder.prepare_input, random_padding=False)
112
+ )
113
+ except fitz.fitz.FileDataError:
114
+ logging.info(f"Could not load file {str(pdf)}.")
115
+ continue
116
+ datasets.append(dataset)
117
+ if len(datasets) == 0:
118
+ return
119
+ dataloader = torch.utils.data.DataLoader(
120
+ ConcatDataset(datasets),
121
+ batch_size=args.batchsize,
122
+ shuffle=False,
123
+ collate_fn=LazyDataset.ignore_none_collate,
124
+ )
125
+
126
+ predictions = []
127
+ file_index = 0
128
+ page_num = 0
129
+ for i, (sample, is_last_page) in enumerate(tqdm(dataloader)):
130
+ model_output = model.inference(image_tensors=sample)
131
+ # check if model output is faulty
132
+ for j, output in enumerate(model_output["predictions"]):
133
+ if page_num == 0:
134
+ logging.info(
135
+ "Processing file %s with %i pages"
136
+ % (datasets[file_index].name, datasets[file_index].size)
137
+ )
138
+ page_num += 1
139
+ if output.strip() == "[MISSING_PAGE_POST]":
140
+ # uncaught repetitions -- most likely empty page
141
+ predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n")
142
+ elif model_output["repeats"][j] is not None:
143
+ if model_output["repeats"][j] > 0:
144
+ # If we end up here, it means the output is most likely not complete and was truncated.
145
+ logging.warning(f"Skipping page {page_num} due to repetitions.")
146
+ predictions.append(f"\n\n[MISSING_PAGE_FAIL:{page_num}]\n\n")
147
+ else:
148
+ # If we end up here, it means the document page is too different from the training domain.
149
+ # This can happen e.g. for cover pages.
150
+ predictions.append(
151
+ f"\n\n[MISSING_PAGE_EMPTY:{i*args.batchsize+j+1}]\n\n"
152
+ )
153
+ else:
154
+ if args.markdown:
155
+ output = markdown_compatible(output)
156
+ predictions.append(output)
157
+ if is_last_page[j]:
158
+ out = "".join(predictions).strip()
159
+ out = re.sub(r"\n{3,}", "\n\n", out).strip()
160
+ if args.out:
161
+ out_path = args.out / Path(is_last_page[j]).with_suffix(".mmd").name
162
+ out_path.parent.mkdir(parents=True, exist_ok=True)
163
+ out_path.write_text(out, encoding="utf-8")
164
+ else:
165
+ print(out, "\n\n")
166
+ predictions = []
167
+ page_num = 0
168
+ file_index += 1
169
+
170
+
171
+ if __name__ == "__main__":
172
+ main()
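The script is normally run from the command line (e.g. `python predict.py paper.pdf --out out --markdown`); a sketch of driving it programmatically, with placeholder paths:

```python
import sys
import predict

# Equivalent to: python predict.py paper.pdf --out out --markdown
sys.argv = ["predict.py", "paper.pdf", "--out", "out", "--markdown"]
predict.main()
```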