Commit e332feb · 1 Parent(s): 0f59b51
Vik Paruchuri committed

Merge plus form processor

marker/config/parser.py CHANGED
@@ -34,45 +34,39 @@ class ConfigParser:
         fn = click.option("--disable_multiprocessing", is_flag=True, default=False, help="Disable multiprocessing.")(fn)
         fn = click.option("--paginate_output", is_flag=True, default=False, help="Paginate output.")(fn)
         fn = click.option("--disable_image_extraction", is_flag=True, default=False, help="Disable image extraction.")(fn)
-        fn = click.option("--high_quality", is_flag=True, default=False, help="Enable high quality processing with Gemini.")(fn)
+        fn = click.option("--high_quality", is_flag=True, default=False, help="Enable high quality processing with LLMs.")(fn)
         return fn

     def generate_config_dict(self) -> Dict[str, any]:
         config = {}
         output_dir = self.cli_options.get("output_dir", settings.OUTPUT_DIR)
         for k, v in self.cli_options.items():
+            if not v:
+                continue
+
             match k:
                 case "debug":
-                    if v:
-                        config["debug_pdf_images"] = True
-                        config["debug_layout_images"] = True
-                        config["debug_json"] = True
-                        config["debug_data_folder"] = output_dir
+                    config["debug_pdf_images"] = True
+                    config["debug_layout_images"] = True
+                    config["debug_json"] = True
+                    config["debug_data_folder"] = output_dir
                 case "page_range":
-                    if v:
-                        config["page_range"] = parse_range_str(v)
+                    config["page_range"] = parse_range_str(v)
                 case "force_ocr":
-                    if v:
-                        config["force_ocr"] = True
+                    config["force_ocr"] = True
                 case "languages":
-                    if v:
-                        config["languages"] = v.split(",")
+                    config["languages"] = v.split(",")
                 case "config_json":
-                    if v:
-                        with open(v, "r") as f:
-                            config.update(json.load(f))
+                    with open(v, "r") as f:
+                        config.update(json.load(f))
                 case "disable_multiprocessing":
-                    if v:
-                        config["pdftext_workers"] = 1
+                    config["pdftext_workers"] = 1
                 case "paginate_output":
-                    if v:
-                        config["paginate_output"] = True
+                    config["paginate_output"] = True
                 case "disable_image_extraction":
-                    if v:
-                        config["extract_images"] = False
+                    config["extract_images"] = False
                 case "high_quality":
-                    if v:
-                        config["high_quality"] = True
+                    config["high_quality"] = True
         return config

     def get_renderer(self):

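For orientation, here is a minimal sketch of how the refactored `generate_config_dict` behaves, assuming `ConfigParser` is constructed from a plain dict of CLI options (its constructor is not part of this diff):

```python
# Hypothetical usage; ConfigParser(cli_options) is an assumption, not shown in this diff.
from marker.config.parser import ConfigParser

parser = ConfigParser({"force_ocr": True, "languages": "en,fr", "high_quality": True, "debug": False})
config = parser.generate_config_dict()
# Falsy options are now skipped up front by `if not v: continue`, so the result is:
# {"force_ocr": True, "languages": ["en", "fr"], "high_quality": True}
print(config)
```
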
marker/converters/pdf.py CHANGED
@@ -1,5 +1,5 @@
 import os
-os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning
+os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning

 import inspect
 from collections import defaultdict
@@ -17,6 +17,8 @@ from marker.processors.debug import DebugProcessor
 from marker.processors.document_toc import DocumentTOCProcessor
 from marker.processors.equation import EquationProcessor
 from marker.processors.footnote import FootnoteProcessor
+from marker.processors.llm.highqualityformprocessor import HighQualityFormProcessor
+from marker.processors.llm.highqualitytableprocessor import HighQualityTableProcessor
 from marker.processors.high_quality_text import HighQualityTextProcessor
 from marker.processors.ignoretext import IgnoreTextProcessor
 from marker.processors.line_numbers import LineNumbersProcessor
@@ -68,6 +70,8 @@ class PdfConverter(BaseConverter):
         PageHeaderProcessor,
         SectionHeaderProcessor,
         TableProcessor,
+        HighQualityTableProcessor,
+        HighQualityFormProcessor,
         TextProcessor,
         HighQualityTextProcessor,
         DebugProcessor,

marker/llm.py ADDED
@@ -0,0 +1,55 @@
+import json
+import time
+
+import PIL
+import google.generativeai as genai
+from google.ai.generativelanguage_v1beta.types import content
+from google.api_core.exceptions import ResourceExhausted
+
+
+class GoogleModel:
+    def __init__(self, api_key: str, model_name: str):
+        if api_key is None:
+            raise ValueError("Google API key is not set")
+
+        self.api_key = api_key
+        self.model_name = model_name
+        self.model = self.configure_google_model()
+
+    def configure_google_model(self):
+        genai.configure(api_key=self.api_key)
+        return genai.GenerativeModel(self.model_name)
+
+    def generate_response(
+        self,
+        prompt: str,
+        image: PIL.Image.Image,
+        response_schema: content.Schema,
+        max_retries: int = 3,
+        timeout: int = 60
+    ):
+        tries = 0
+        while tries < max_retries:
+            try:
+                responses = self.model.generate_content(
+                    [prompt, image],
+                    stream=False,
+                    generation_config={
+                        "temperature": 0,
+                        "response_schema": response_schema,
+                        "response_mime_type": "application/json",
+                    },
+                    request_options={'timeout': timeout}
+                )
+                output = responses.candidates[0].content.parts[0].text
+                return json.loads(output)
+            except ResourceExhausted as e:
+                tries += 1
+                wait_time = tries * 3
+                print(f"ResourceExhausted: {e}. Retrying in {wait_time} seconds... (Attempt {tries}/{max_retries})")
+                time.sleep(wait_time)
+            except Exception as e:
+                print(e)
+                break
+
+        return {}

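For context, a minimal usage sketch of the new `GoogleModel` wrapper, assuming a valid API key and a PIL image of a cropped block (the prompt, file name, and schema here are illustrative, modeled on the processors below):

```python
# Hypothetical usage sketch; key, prompt, and image path are placeholders.
from PIL import Image
from google.ai.generativelanguage_v1beta.types import content
from marker.llm import GoogleModel

model = GoogleModel(api_key="your-google-api-key", model_name="gemini-1.5-flash")
schema = content.Schema(
    type=content.Type.OBJECT,
    required=["corrected_markdown"],
    properties={"corrected_markdown": content.Schema(type=content.Type.STRING)},
)
result = model.generate_response("Correct this table...", Image.open("block.png"), schema)
# On success, `result` is the parsed JSON dict; after repeated ResourceExhausted
# errors or any other exception, it falls back to {}.
```
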
marker/processors/llm/__init__.py ADDED
File without changes
marker/processors/llm/highqualityformprocessor.py ADDED
@@ -0,0 +1,151 @@
+import markdown2
+
+from marker.llm import GoogleModel
+from marker.processors import BaseProcessor
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Optional
+
+from google.ai.generativelanguage_v1beta.types import content
+from tqdm import tqdm
+from tabled.formats import markdown_format
+
+from marker.schema import BlockTypes
+from marker.schema.blocks import Block
+from marker.schema.document import Document
+from marker.schema.groups.page import PageGroup
+from marker.settings import settings
+
+
+class HighQualityFormProcessor(BaseProcessor):
+    """
+    A processor for converting form blocks in a document to markdown.
+    Attributes:
+        google_api_key (str):
+            The Google API key to use for the Gemini model.
+            Default is None.
+        model_name (str):
+            The name of the Gemini model to use.
+            Default is "gemini-1.5-flash".
+        max_retries (int):
+            The maximum number of retries to use for the Gemini model.
+            Default is 3.
+        max_concurrency (int):
+            The maximum number of concurrent requests to make to the Gemini model.
+            Default is 3.
+        timeout (int):
+            The timeout for requests to the Gemini model.
+        gemini_rewriting_prompt (str):
+            The prompt to use for rewriting text.
+            Default is a string containing the Gemini rewriting prompt.
+    """
+
+    block_types = (BlockTypes.Form,)
+    google_api_key: Optional[str] = settings.GOOGLE_API_KEY
+    model_name: str = "gemini-1.5-flash"
+    high_quality: bool = False
+    max_retries: int = 3
+    max_concurrency: int = 3
+    timeout: int = 60
+
+    gemini_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
+You will receive an image of a text block and a markdown representation of the form in the image.
+Your task is to correct any errors in the markdown representation, and format it properly.
+Values and labels should appear in markdown tables, with the labels on the left side, and values on the right. The headers should be "Labels" and "Values". Other text in the form can appear between the tables.
+**Instructions:**
+1. Carefully examine the provided form block image.
+2. Analyze the markdown representation of the form.
+3. If the markdown representation is largely correct, then write "No corrections needed."
+4. If the markdown representation contains errors, generate the corrected markdown representation.
+5. Output only either the corrected markdown representation or "No corrections needed."
+**Example:**
+Input:
+```markdown
+| Label 1 | Label 2 | Label 3 |
+|----------|----------|----------|
+| Value 1 | Value 2 | Value 3 |
+```
+Output:
+```markdown
+| Labels | Values |
+|--------|--------|
+| Label 1 | Value 1 |
+| Label 2 | Value 2 |
+| Label 3 | Value 3 |
+```
+**Input:**
+"""
+
+    def __init__(self, config=None):
+        super().__init__(config)
+
+        self.model = None
+        if not self.high_quality:
+            return
+
+        self.model = GoogleModel(self.google_api_key, self.model_name)
+
+    def __call__(self, document: Document):
+        if not self.high_quality or self.model is None:
+            return
+
+        self.rewrite_blocks(document)
+
+    def rewrite_blocks(self, document: Document):
+        pbar = tqdm(desc="High quality form processor")
+        with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
+            for future in as_completed([
+                executor.submit(self.process_rewriting, page, block)
+                for page in document.pages
+                for block in page.contained_blocks(document, self.block_types)
+            ]):
+                future.result() # Raise exceptions if any occurred
+                pbar.update(1)
+
+        pbar.close()
+
+    def process_rewriting(self, page: PageGroup, block: Block):
+        cells = block.cells
+        if cells is None:
+            # Happens if table/form processors didn't run
+            return
+
+        prompt = self.gemini_rewriting_prompt + '```markdown\n`' + markdown_format(cells) + '`\n```\n'
+        image = self.extract_image(page, block)
+        response_schema = content.Schema(
+            type=content.Type.OBJECT,
+            enum=[],
+            required=["corrected_markdown"],
+            properties={
+                "corrected_markdown": content.Schema(
+                    type=content.Type.STRING
+                )
+            },
+        )
+
+        response = self.model.generate_response(prompt, image, response_schema)
+
+        if not response or "corrected_markdown" not in response:
+            return
+
+        corrected_markdown = response["corrected_markdown"]
+
+        # The original table is okay
+        if "no corrections" in corrected_markdown.lower():
+            return
+
+        orig_cell_text = "".join([cell.text for cell in cells])
+
+        # Potentially a partial response
+        if len(corrected_markdown) < len(orig_cell_text) * .5:
+            return
+
+        # Convert LLM markdown to html
+        block.html = markdown2.markdown(corrected_markdown)
+
+    def extract_image(self, page: PageGroup, image_block: Block, expand: float = 0.01):
+        page_img = page.lowres_image
+        image_box = image_block.polygon\
+            .rescale(page.polygon.size, page_img.size)\
+            .expand(expand, expand)
+        cropped = page_img.crop(image_box.bbox)
+        return cropped

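To make the hand-off concrete, here is a small illustrative sketch of what happens once a corrected markdown string comes back from the model (the response is hard-coded here):

```python
# Illustrative only: mirrors the tail of process_rewriting with a canned response.
import markdown2

corrected_markdown = "| Labels | Values |\n|--------|--------|\n| Name | Jane Doe |"
html = markdown2.markdown(corrected_markdown)  # convert the LLM's markdown to an HTML fragment
# The processor stores this on block.html; Form.assemble_html (changed below)
# then returns it instead of re-rendering block.cells.
print(html)
```
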
marker/processors/llm/highqualitytableprocessor.py ADDED
@@ -0,0 +1,188 @@
+from tabled.schema import SpanTableCell
+
+from marker.llm import GoogleModel
+from marker.processors import BaseProcessor
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Optional, List
+
+from google.ai.generativelanguage_v1beta.types import content
+from tqdm import tqdm
+from tabled.formats import markdown_format
+
+from marker.schema import BlockTypes
+from marker.schema.blocks import Block
+from marker.schema.document import Document
+from marker.schema.groups.page import PageGroup
+from marker.schema.polygon import PolygonBox
+from marker.settings import settings
+
+
+class HighQualityTableProcessor(BaseProcessor):
+    """
+    A processor for converting table blocks in a document to markdown.
+    Attributes:
+        google_api_key (str):
+            The Google API key to use for the Gemini model.
+            Default is None.
+        model_name (str):
+            The name of the Gemini model to use.
+            Default is "gemini-1.5-flash".
+        max_retries (int):
+            The maximum number of retries to use for the Gemini model.
+            Default is 3.
+        max_concurrency (int):
+            The maximum number of concurrent requests to make to the Gemini model.
+            Default is 3.
+        timeout (int):
+            The timeout for requests to the Gemini model.
+        gemini_rewriting_prompt (str):
+            The prompt to use for rewriting text.
+            Default is a string containing the Gemini rewriting prompt.
+    """
+
+    block_types = (BlockTypes.Table,)
+    google_api_key: Optional[str] = settings.GOOGLE_API_KEY
+    model_name: str = "gemini-1.5-flash"
+    high_quality: bool = False
+    max_retries: int = 3
+    max_concurrency: int = 3
+    timeout: int = 60
+
+    gemini_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
+You will receive an image of a text block and a markdown representation of the table in the image.
+Your task is to correct any errors in the markdown representation. The markdown representation should be as faithful to the original table as possible.
+**Instructions:**
+1. Carefully examine the provided text block image.
+2. Analyze the markdown representation of the table.
+3. If the markdown representation is largely correct, then write "No corrections needed."
+4. If the markdown representation contains errors, generate the corrected markdown representation.
+5. Output only either the corrected markdown representation or "No corrections needed."
+**Example:**
+Input:
+```markdown
+| Column 1 | Column 2 | Column 3 |
+|----------|----------|----------|
+| Value 1 | Value 2 | Value 3 |
+```
+Output:
+```markdown
+No corrections needed.
+```
+**Input:**
+"""
+
+    def __init__(self, config=None):
+        super().__init__(config)
+
+        self.model = None
+        if not self.high_quality:
+            return
+
+        self.model = GoogleModel(self.google_api_key, self.model_name)
+
+    def __call__(self, document: Document):
+        if not self.high_quality or self.model is None:
+            return
+
+        self.rewrite_blocks(document)
+
+    def rewrite_blocks(self, document: Document):
+        pbar = tqdm(desc="High quality table processor")
+        with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
+            for future in as_completed([
+                executor.submit(self.process_rewriting, page, block)
+                for page in document.pages
+                for block in page.contained_blocks(document, self.block_types)
+            ]):
+                future.result() # Raise exceptions if any occurred
+                pbar.update(1)
+
+        pbar.close()
+
+    def process_rewriting(self, page: PageGroup, block: Block):
+        cells = block.cells
+        if cells is None:
+            # Happens if table/form processors didn't run
+            return
+
+        prompt = self.gemini_rewriting_prompt + '```markdown\n`' + markdown_format(cells) + '`\n```\n'
+        image = self.extract_image(page, block)
+        response_schema = content.Schema(
+            type=content.Type.OBJECT,
+            enum=[],
+            required=["corrected_markdown"],
+            properties={
+                "corrected_markdown": content.Schema(
+                    type=content.Type.STRING
+                )
+            },
+        )
+
+        response = self.model.generate_response(prompt, image, response_schema)
+
+        if not response or "corrected_markdown" not in response:
+            return
+
+        corrected_markdown = response["corrected_markdown"]
+
+        # The original table is okay
+        if "no corrections" in corrected_markdown.lower():
+            return
+
+        parsed_cells = self.parse_markdown_table(corrected_markdown, block)
+        if len(parsed_cells) <= 1:
+            return
+
+        parsed_cell_text = "".join([cell.text for cell in parsed_cells])
+        orig_cell_text = "".join([cell.text for cell in cells])
+
+        # Potentially a partial response
+        if len(parsed_cell_text) < len(orig_cell_text) * .5:
+            return
+
+
+        block.cells = parsed_cells
+
+    def extract_image(self, page: PageGroup, image_block: Block, expand: float = 0.01):
+        page_img = page.lowres_image
+        image_box = image_block.polygon\
+            .rescale(page.polygon.size, page_img.size)\
+            .expand(expand, expand)
+        cropped = page_img.crop(image_box.bbox)
+        return cropped
+
+    def parse_markdown_table(self, markdown_text: str, block: Block) -> List[SpanTableCell]:
+        lines = [line.strip() for line in markdown_text.splitlines() if line.strip()]
+
+        # Remove separator row for headers
+        lines = [line for line in lines if not line.replace('|', ' ').replace('-', ' ').isspace()]
+
+        rows = []
+        for line in lines:
+            # Remove leading/trailing pipes and split by remaining pipes
+            cells = line.strip('|').split('|')
+            # Clean whitespace from each cell
+            cells = [cell.strip() for cell in cells]
+            rows.append(cells)
+
+        cells = []
+        for i, row in enumerate(rows):
+            for j, cell in enumerate(row):
+                cell_bbox = [
+                    block.polygon.bbox[0] + j,
+                    block.polygon.bbox[1] + i,
+                    block.polygon.bbox[0] + j + 1,
+                    block.polygon.bbox[1] + i + 1
+                ]
+                cell_polygon = PolygonBox.from_bbox(cell_bbox)
+                cells.append(
+                    SpanTableCell(
+                        text=cell,
+                        row_ids=[i],
+                        col_ids=[j],
+                        bbox=cell_polygon.bbox
+                    )
+                )


+        return cells

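A short, self-contained illustration of the row/column splitting that `parse_markdown_table` applies to a corrected table (the same string operations as above, with the placeholder cell bboxes omitted):

```python
# Illustrative only; reproduces the parsing steps on a hard-coded markdown table.
corrected_markdown = """
| Column 1 | Column 2 |
|----------|----------|
| A        | B        |
"""
lines = [line.strip() for line in corrected_markdown.splitlines() if line.strip()]
# Drop the header separator row, exactly as the processor does
lines = [line for line in lines if not line.replace('|', ' ').replace('-', ' ').isspace()]
rows = [[cell.strip() for cell in line.strip('|').split('|')] for line in lines]
print(rows)  # [['Column 1', 'Column 2'], ['A', 'B']]
```

Each parsed cell then becomes a `SpanTableCell` whose bbox is a 1x1 placeholder offset from the block polygon, so downstream formatting still has per-cell geometry to work with.
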
marker/schema/blocks/form.py CHANGED
@@ -10,6 +10,11 @@ from marker.schema.blocks import Block
 class Form(Block):
     block_type: str = BlockTypes.Form
     cells: List[SpanTableCell] | None = None
+    html: str | None = None

     def assemble_html(self, child_blocks, parent_structure=None):
+        # Some processors convert the form to html
+        if self.html is not None:
+            return self.html
+
         return str(html_format(self.cells))

marker/settings.py CHANGED
@@ -18,6 +18,9 @@ class Settings(BaseSettings):
     OUTPUT_ENCODING: str = "utf-8"
     OUTPUT_IMAGE_FORMAT: str = "JPEG"

+    # LLM
+    GOOGLE_API_KEY: Optional[str] = None
+
     # General models
     TORCH_DEVICE: Optional[str] = None # Note: MPS device does not work for text detection, and will default to CPU
     GOOGLE_API_KEY: Optional[str] = None

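Since `Settings` extends pydantic's `BaseSettings`, the new key can presumably be supplied through the environment (assuming no env prefix or other custom settings config is in play); a minimal sketch:

```python
# Hypothetical; BaseSettings normally reads matching environment variables.
import os
os.environ["GOOGLE_API_KEY"] = "your-google-api-key"  # placeholder value

from marker.settings import settings  # imported after the env var is set
print(settings.GOOGLE_API_KEY)
```
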
pyproject.toml CHANGED
@@ -40,6 +40,7 @@ tabled-pdf = "~0.2.0"
 markdownify = "^0.13.1"
 click = "^8.1.7"
 google-generativeai = "^0.8.3"
+markdown2 = "^2.5.2"

 [tool.poetry.group.dev.dependencies]
 jupyter = "^1.0.0"
  jupyter = "^1.0.0"