Vik Paruchuri committed on
Commit
fc65ff4
·
1 Parent(s): a5c1c2e

Load models externally

Browse files
README.md CHANGED
@@ -2,7 +2,7 @@
2
 
3
  This project converts PDF to Markdown, balancing speed with quality:
4
 
5
- - Equations will be detected and converted to Latex. This is not 100% accurate.
6
  - All headers/footers/other artifacts will be removed.
7
 
8
 
@@ -10,4 +10,9 @@ This project converts PDF to Markdown, balancing speed with quality:
10
  ## Install
11
 
12
  - `poetry install`
13
- - Set `TESSDATA_PREFIX`
 
 
 
 
 
 
2
 
3
  This project converts PDF to Markdown, balancing speed with quality:
4
 
5
+ - Equations will be detected and converted to Latex when possible.
6
  - All headers/footers/other artifacts will be removed.
7
 
8
 
 
10
  ## Install
11
 
12
  - `poetry install`
13
+ - Set `TESSDATA_PREFIX`
14
+
15
+
16
+ ## Usage
17
+
18
+ Can work with CPU, MPS, or GPU
marker/code.py CHANGED
@@ -4,7 +4,7 @@ from typing import List
4
  import fitz as pymupdf
5
 
6
 
7
- def is_code_linelen(lines, thresh=50):
8
  # Decide based on chars per newline threshold
9
  total_alnum_chars = sum(len(re.findall(r'\w', line.prelim_text)) for line in lines)
10
  total_newlines = len(lines) - 1
@@ -16,7 +16,20 @@ def is_code_linelen(lines, thresh=50):
16
  return ratio < thresh
17
 
18
 
 
 
 
 
 
19
  def identify_code_blocks(blocks: List[Page]):
 
 
 
 
 
 
 
 
20
  for page in blocks:
21
  try:
22
  common_height = page.get_line_height_stats().most_common(1)[0][0]
@@ -31,19 +44,30 @@ def identify_code_blocks(blocks: List[Page]):
31
  continue
32
 
33
  is_code = []
 
34
  for line in block.lines:
35
  fonts = [span.font for span in line.spans]
36
- monospace_font = any([font for font in fonts if "mono" in font.lower() or "prop" in font.lower()])
37
  line_height = line.bbox[3] - line.bbox[1]
38
  line_start = line.bbox[0]
39
- if line_height <= common_height and line_start > common_start and monospace_font:
40
  is_code.append(True)
41
  else:
42
  is_code.append(False)
 
43
  is_code = [
44
- sum(is_code) > len(block.lines) / 1.5,
45
- len(block.lines) > 4,
46
- is_code_linelen(block.lines)
 
 
 
 
 
 
 
 
 
47
  ]
48
 
49
  if all(is_code):
@@ -54,7 +78,8 @@ def indent_blocks(blocks: List[Page]):
54
  span_counter = 0
55
  for page in blocks:
56
  for block in page.blocks:
57
- if block.most_common_block_type() != "Code":
 
58
  continue
59
 
60
  lines = []
 
4
  import fitz as pymupdf
5
 
6
 
7
+ def is_code_linelen(lines, thresh=70):
8
  # Decide based on chars per newline threshold
9
  total_alnum_chars = sum(len(re.findall(r'\w', line.prelim_text)) for line in lines)
10
  total_newlines = len(lines) - 1
 
16
  return ratio < thresh
17
 
18
 
19
+ def comment_count(lines):
20
+ pattern = re.compile(r"^(//|#|'|--|/\*|'''|\"\"\"|--\[\[|<!--|%|%{|\(\*)")
21
+ return sum([1 for line in lines if pattern.match(line)])
22
+
23
+
24
  def identify_code_blocks(blocks: List[Page]):
25
+ font_info = None
26
+ for p in blocks:
27
+ stats = p.get_font_stats()
28
+ if font_info is None:
29
+ font_info = stats
30
+ else:
31
+ font_info += stats
32
+ most_common_font = font_info.most_common(1)[0][0]
33
  for page in blocks:
34
  try:
35
  common_height = page.get_line_height_stats().most_common(1)[0][0]
 
44
  continue
45
 
46
  is_code = []
47
+ line_fonts = []
48
  for line in block.lines:
49
  fonts = [span.font for span in line.spans]
50
+ line_fonts += fonts
51
  line_height = line.bbox[3] - line.bbox[1]
52
  line_start = line.bbox[0]
53
+ if line_start > common_start:
54
  is_code.append(True)
55
  else:
56
  is_code.append(False)
57
+ comment_lines = comment_count([line.prelim_text for line in block.lines])
58
  is_code = [
59
+ len(block.lines) > 2,
60
+ sum([f != most_common_font for f in line_fonts]) > len(line_fonts) // 1.5, # More than roughly 2/3 of the span fonts are not the most common font, since code usually uses a different font from the main body text
61
+ (
62
+ sum(is_code) > len(block.lines) * .2
63
+ or
64
+ comment_lines > len(block.lines) * .1
65
+ ), # More than 20% of lines are indented, or more than 10% of lines are comments
66
+ (
67
+ is_code_linelen(block.lines)
68
+ or
69
+ comment_lines > len(block.lines) * .1
70
+ ), # Under the is_code_linelen threshold (default 70 chars per newline), or more than 10% of lines are comments
71
  ]
72
 
73
  if all(is_code):
 
78
  span_counter = 0
79
  for page in blocks:
80
  for block in page.blocks:
81
+ block_types = [span.block_type for line in block.lines for span in line.spans]
82
+ if "Code" not in block_types:
83
  continue
84
 
85
  lines = []
marker/equations.py CHANGED
@@ -13,7 +13,7 @@ from marker.schema import Page, Span, Line, Block, BlockType
13
  from nougat.utils.device import move_to_device
14
 
15
 
16
- def load_model():
17
  ckpt = get_checkpoint(None, model_tag="0.1.0-small")
18
  nougat_model = NougatModel.from_pretrained(ckpt)
19
  if settings.TORCH_DEVICE != "cpu":
@@ -23,12 +23,6 @@ def load_model():
23
  return nougat_model
24
 
25
 
26
- nougat_model = load_model()
27
- MODEL_MAX = nougat_model.config.max_length
28
-
29
- NOUGAT_HALLUCINATION_WORDS = ["[MISSING_PAGE_POST]", "## References\n", "**Figure Captions**\n", "Footnote", "\par\par\par", "## Chapter", "Fig."]
30
-
31
-
32
  def contains_equation(text):
33
  # Define a regular expression pattern to look for operators and symbols commonly found in equations
34
  pattern = re.compile(r'[=\^\√∑∏∫∂∆π≈≠≤≥∞∩∪∈∉∀∃∅∇λμσαβγδεζηθφχψω]')
@@ -66,18 +60,18 @@ def mask_bbox(png_image, bbox, selected_bboxes):
66
  return result
67
 
68
 
69
- def get_nougat_text(page, old_text, bbox, selected_bboxes, save_id, max_length=MODEL_MAX):
70
  pix = page.get_pixmap(dpi=settings.DPI, clip=bbox)
71
  png = pix.pil_tobytes(format="PNG")
72
  png_image = Image.open(io.BytesIO(png))
73
  png_image = mask_bbox(png_image, bbox, selected_bboxes)
74
 
75
- nougat_model.config.max_length = min(max_length, MODEL_MAX)
76
  output = nougat_model.inference(image=png_image)
77
  return output["predictions"][0]
78
 
79
 
80
- def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]]):
81
  span_id = 0
82
  new_blocks = []
83
  for pnum, page in enumerate(blocks):
@@ -126,10 +120,10 @@ def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]
126
  # This prevents hallucinations from running on for a long time
127
  max_tokens = len(block_text) + 50
128
  max_char_length = 2 * len(block_text) + 100
129
- nougat_text = get_nougat_text(doc[pnum], block_text, bbox, selected_bboxes, f"{pnum}_{i}", max_length=max_tokens)
130
  conditions = [
131
  len(nougat_text) > 0,
132
- not any([word in nougat_text for word in NOUGAT_HALLUCINATION_WORDS]),
133
  len(nougat_text) < max_char_length, # Reduce hallucinations
134
  len(nougat_text) >= len(block_text) * .8
135
  ]
 
13
  from nougat.utils.device import move_to_device
14
 
15
 
16
+ def load_nougat_model():
17
  ckpt = get_checkpoint(None, model_tag="0.1.0-small")
18
  nougat_model = NougatModel.from_pretrained(ckpt)
19
  if settings.TORCH_DEVICE != "cpu":
 
23
  return nougat_model
24
 
25
 
 
 
 
 
 
 
26
  def contains_equation(text):
27
  # Define a regular expression pattern to look for operators and symbols commonly found in equations
28
  pattern = re.compile(r'[=\^\√∑∏∫∂∆π≈≠≤≥∞∩∪∈∉∀∃∅∇λμσαβγδεζηθφχψω]')
 
60
  return result
61
 
62
 
63
+ def get_nougat_text(page, bbox, selected_bboxes, nougat_model, max_length=settings.NOUGAT_MODEL_MAX):
64
  pix = page.get_pixmap(dpi=settings.DPI, clip=bbox)
65
  png = pix.pil_tobytes(format="PNG")
66
  png_image = Image.open(io.BytesIO(png))
67
  png_image = mask_bbox(png_image, bbox, selected_bboxes)
68
 
69
+ nougat_model.config.max_length = min(max_length, settings.NOUGAT_MODEL_MAX)
70
  output = nougat_model.inference(image=png_image)
71
  return output["predictions"][0]
72
 
73
 
74
+ def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]], nougat_model):
75
  span_id = 0
76
  new_blocks = []
77
  for pnum, page in enumerate(blocks):
 
120
  # This prevents hallucinations from running on for a long time
121
  max_tokens = len(block_text) + 50
122
  max_char_length = 2 * len(block_text) + 100
123
+ nougat_text = get_nougat_text(doc[pnum], bbox, selected_bboxes, nougat_model, max_length=max_tokens)
124
  conditions = [
125
  len(nougat_text) > 0,
126
+ not any([word in nougat_text for word in settings.NOUGAT_HALLUCINATION_WORDS]),
127
  len(nougat_text) < max_char_length, # Reduce hallucinations
128
  len(nougat_text) >= len(block_text) * .8
129
  ]
marker/markdown.py CHANGED
@@ -69,7 +69,7 @@ def block_surround(text, block_type):
69
  case "List-item":
70
  pass
71
  case "Code":
72
- text = "```\n" + text + "\n```\n"
73
  case _:
74
  pass
75
  return text
 
69
  case "List-item":
70
  pass
71
  case "Code":
72
+ text = "\n" + text + "\n"
73
  case _:
74
  pass
75
  return text
marker/segmentation.py CHANGED
@@ -16,11 +16,8 @@ processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", app
16
 
17
  CHUNK_KEYS = ["input_ids", "attention_mask", "bbox", "offset_mapping"]
18
  NO_CHUNK_KEYS = ["pixel_values"]
19
- MODEL_MAX_LEN = 512
20
- CHUNK_OVERLAP = 128
21
 
22
-
23
- def load_model():
24
  model = LayoutLMv3ForTokenClassification.from_pretrained("Kwan0/layoutlmv3-base-finetune-DocLayNet-100k").to(settings.TORCH_DEVICE)
25
  model.config.id2label = {
26
  0: "Caption",
@@ -40,19 +37,16 @@ def load_model():
40
  return model
41
 
42
 
43
- layoutlm_model = load_model()
44
-
45
-
46
- def detect_all_block_types(doc, blocks: List[Page]):
47
  block_types = []
48
  for pnum, page in enumerate(doc):
49
  page_blocks = blocks[pnum]
50
- predictions = detect_page_block_types(page, page_blocks)
51
  block_types.append(predictions)
52
  return block_types
53
 
54
 
55
- def detect_page_block_types(page, page_blocks: Page):
56
  page_box = page.bound()
57
  pwidth = page_box[2] - page_box[0]
58
  pheight = page_box[3] - page_box[1]
@@ -66,7 +60,7 @@ def detect_page_block_types(page, page_blocks: Page):
66
  boxes = [s.bbox for s in lines]
67
  text = [s.prelim_text for s in lines]
68
 
69
- predictions = make_predictions(rgb_image, text, boxes, pwidth, pheight)
70
  return predictions
71
 
72
 
@@ -85,10 +79,10 @@ def get_provisional_boxes(pred, box, is_subword, start_idx=0):
85
  return prov_predictions, prov_boxes
86
 
87
 
88
- def make_predictions(rgb_image, text, boxes, pwidth, pheight) -> List[BlockType]:
89
  # Normalize boxes for model (scale to 1000x1000)
90
  boxes = [normalize_box(box, pwidth, pheight) for box in boxes]
91
- encoding = processor(rgb_image, text=text, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True, stride=CHUNK_OVERLAP, padding="max_length", max_length=MODEL_MAX_LEN, return_overflowing_tokens=True)
92
  offset_mapping = encoding.pop('offset_mapping')
93
  overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
94
 
@@ -108,7 +102,7 @@ def make_predictions(rgb_image, text, boxes, pwidth, pheight) -> List[BlockType]
108
  predictions = logits.argmax(-1).squeeze().tolist()
109
  token_boxes = encoding.bbox.squeeze().tolist()
110
 
111
- if len(token_boxes) == MODEL_MAX_LEN:
112
  predictions = [predictions]
113
  token_boxes = [token_boxes]
114
 
@@ -118,7 +112,7 @@ def make_predictions(rgb_image, text, boxes, pwidth, pheight) -> List[BlockType]
118
  is_subword = np.array(mapped.squeeze().tolist())[:, 0] != 0
119
  overlap_adjust = 0
120
  if i > 0:
121
- overlap_adjust = 1 + CHUNK_OVERLAP - sum(is_subword[:1 + CHUNK_OVERLAP])
122
 
123
  prov_predictions, prov_boxes = get_provisional_boxes(pred, box, is_subword, overlap_adjust)
124
 
@@ -135,5 +129,23 @@ def make_predictions(rgb_image, text, boxes, pwidth, pheight) -> List[BlockType]
135
  if len(predicted_block_types) == 0 or unnorm_box != predicted_block_types[-1].bbox:
136
  predicted_block_types.append(block_type)
137
 
138
- return predicted_block_types
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
 
16
 
17
  CHUNK_KEYS = ["input_ids", "attention_mask", "bbox", "offset_mapping"]
18
  NO_CHUNK_KEYS = ["pixel_values"]
 
 
19
 
20
+ def load_layout_model():
 
21
  model = LayoutLMv3ForTokenClassification.from_pretrained("Kwan0/layoutlmv3-base-finetune-DocLayNet-100k").to(settings.TORCH_DEVICE)
22
  model.config.id2label = {
23
  0: "Caption",
 
37
  return model
38
 
39
 
40
+ def detect_all_block_types(doc, blocks: List[Page], layoutlm_model):
 
 
 
41
  block_types = []
42
  for pnum, page in enumerate(doc):
43
  page_blocks = blocks[pnum]
44
+ predictions = detect_page_block_types(page, page_blocks, layoutlm_model)
45
  block_types.append(predictions)
46
  return block_types
47
 
48
 
49
+ def detect_page_block_types(page, page_blocks: Page, layoutlm_model):
50
  page_box = page.bound()
51
  pwidth = page_box[2] - page_box[0]
52
  pheight = page_box[3] - page_box[1]
 
60
  boxes = [s.bbox for s in lines]
61
  text = [s.prelim_text for s in lines]
62
 
63
+ predictions = make_predictions(rgb_image, text, boxes, pwidth, pheight, layoutlm_model)
64
  return predictions
65
 
66
 
 
79
  return prov_predictions, prov_boxes
80
 
81
 
82
+ def make_predictions(rgb_image, text, boxes, pwidth, pheight, layoutlm_model) -> List[BlockType]:
83
  # Normalize boxes for model (scale to 1000x1000)
84
  boxes = [normalize_box(box, pwidth, pheight) for box in boxes]
85
+ encoding = processor(rgb_image, text=text, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True, stride=settings.LAYOUT_CHUNK_OVERLAP, padding="max_length", max_length=settings.LAYOUT_MODEL_MAX, return_overflowing_tokens=True)
86
  offset_mapping = encoding.pop('offset_mapping')
87
  overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
88
 
 
102
  predictions = logits.argmax(-1).squeeze().tolist()
103
  token_boxes = encoding.bbox.squeeze().tolist()
104
 
105
+ if len(token_boxes) == settings.LAYOUT_MODEL_MAX:
106
  predictions = [predictions]
107
  token_boxes = [token_boxes]
108
 
 
112
  is_subword = np.array(mapped.squeeze().tolist())[:, 0] != 0
113
  overlap_adjust = 0
114
  if i > 0:
115
+ overlap_adjust = 1 + settings.LAYOUT_CHUNK_OVERLAP - sum(is_subword[:1 + settings.LAYOUT_CHUNK_OVERLAP])
116
 
117
  prov_predictions, prov_boxes = get_provisional_boxes(pred, box, is_subword, overlap_adjust)
118
 
 
129
  if len(predicted_block_types) == 0 or unnorm_box != predicted_block_types[-1].bbox:
130
  predicted_block_types.append(block_type)
131
 
132
+ # Align bboxes
133
+ # This will search both lists to find matching bboxes
134
+ # This will align both sets of bboxes by index
135
+ # If there are duplicate bboxes, it may result in issues
136
+ aligned_blocks = []
137
+ for i in range(len(boxes)):
138
+ unnorm_box = unnormalize_box(boxes[i], pwidth, pheight)
139
+ appended = False
140
+ for j in range(len(predicted_block_types)):
141
+ if unnorm_box == predicted_block_types[j].bbox:
142
+ aligned_blocks.append(predicted_block_types[j])
143
+ appended = True
144
+ break
145
+ if not appended:
146
+ aligned_blocks.append(BlockType(
147
+ block_type="Text",
148
+ bbox=unnorm_box
149
+ ))
150
+ return aligned_blocks
151
 
marker/settings.py CHANGED
@@ -12,6 +12,11 @@ class Settings(BaseSettings):
12
  TORCH_DEVICE: str = "cpu"
13
  TESSDATA_PREFIX: str = ""
14
  BAD_SPAN_TYPES: List[str] = ["Caption", "Footnote", "Page-footer", "Page-header", "Picture"]
 
 
 
 
 
15
 
16
  class Config:
17
  env_file = find_dotenv("local.env")
 
12
  TORCH_DEVICE: str = "cpu"
13
  TESSDATA_PREFIX: str = ""
14
  BAD_SPAN_TYPES: List[str] = ["Caption", "Footnote", "Page-footer", "Page-header", "Picture"]
15
+ NOUGAT_MODEL_MAX: int = 1024 # Max inference length for nougat
16
+ NOUGAT_HALLUCINATION_WORDS: List[str] = ["[MISSING_PAGE_POST]", "## References\n", "**Figure Captions**\n", "Footnote",
17
+ "\par\par\par", "## Chapter", "Fig."]
18
+ LAYOUT_MODEL_MAX: int = 512
19
+ LAYOUT_CHUNK_OVERLAP: int = 128
20
 
21
  class Config:
22
  env_file = find_dotenv("local.env")
parse.py CHANGED
@@ -1,8 +1,10 @@
 
 
1
  import fitz as pymupdf
2
  from marker.extract_text import get_text_blocks
3
  from marker.headers import categorize_blocks, filter_header_footer
4
- from marker.equations import replace_equations
5
- from marker.segmentation import detect_all_block_types
6
  from marker.code import identify_code_blocks, indent_blocks
7
  from marker.markdown import merge_spans, merge_lines, get_full_text
8
  from marker.schema import Page, BlockType
@@ -17,11 +19,17 @@ def annotate_spans(blocks: List[Page], block_types: List[BlockType]):
17
 
18
 
19
  if __name__ == "__main__":
20
- fname = "test_data/thinkpython.pdf"
 
 
 
 
 
21
  doc = pymupdf.open(fname)
22
  blocks, toc = get_text_blocks(doc)
23
 
24
- block_types = detect_all_block_types(doc, blocks)
 
25
 
26
  filtered = deepcopy(blocks)
27
  annotate_spans(filtered, block_types)
@@ -38,12 +46,13 @@ if __name__ == "__main__":
38
  block.filter_spans(bad_span_ids)
39
  block.filter_bad_span_types(block_types[page.pnum])
40
 
41
- filtered = replace_equations(doc, filtered, block_types)
 
42
 
43
  # Copy to avoid changing original data
44
  merged_lines = merge_spans(filtered)
45
  text_blocks = merge_lines(merged_lines, filtered)
46
  full_text = get_full_text(text_blocks)
47
 
48
- with open("test_data/thinkpython.md", "w+") as f:
49
  f.write(full_text)
 
1
+ import argparse
2
+
3
  import fitz as pymupdf
4
  from marker.extract_text import get_text_blocks
5
  from marker.headers import categorize_blocks, filter_header_footer
6
+ from marker.equations import replace_equations, load_nougat_model
7
+ from marker.segmentation import detect_all_block_types, load_layout_model
8
  from marker.code import identify_code_blocks, indent_blocks
9
  from marker.markdown import merge_spans, merge_lines, get_full_text
10
  from marker.schema import Page, BlockType
 
19
 
20
 
21
  if __name__ == "__main__":
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("filename", help="PDF file to parse")
24
+ parser.add_argument("output", help="Output file name")
25
+ args = parser.parse_args()
26
+
27
+ fname = args.filename
28
  doc = pymupdf.open(fname)
29
  blocks, toc = get_text_blocks(doc)
30
 
31
+ layoutlm_model = load_layout_model()
32
+ block_types = detect_all_block_types(doc, blocks, layoutlm_model)
33
 
34
  filtered = deepcopy(blocks)
35
  annotate_spans(filtered, block_types)
 
46
  block.filter_spans(bad_span_ids)
47
  block.filter_bad_span_types(block_types[page.pnum])
48
 
49
+ nougat_model = load_nougat_model()
50
+ filtered = replace_equations(doc, filtered, block_types, nougat_model)
51
 
52
  # Copy to avoid changing original data
53
  merged_lines = merge_spans(filtered)
54
  text_blocks = merge_lines(merged_lines, filtered)
55
  full_text = get_full_text(text_blocks)
56
 
57
+ with open(args.output, "w+") as f:
58
  f.write(full_text)