import json
from sparrow_parse.vllm.inference_factory import InferenceFactory
from sparrow_parse.helpers.pdf_optimizer import PDFOptimizer
from sparrow_parse.helpers.image_optimizer import ImageOptimizer
from sparrow_parse.processors.table_structure_processor import TableDetector
from rich import print
import os
import tempfile
import shutil
class VLLMExtractor(object):
    def __init__(self):
        pass

    def run_inference(self, model_inference_instance, input_data, tables_only=False,
                      generic_query=False, crop_size=None, debug_dir=None, debug=False, mode=None):
        """
        Main entry point for processing input data using a model inference instance.
        Handles generic queries, PDFs, and table extraction.
        """
        if generic_query:
            input_data[0]["text_input"] = "retrieve document data. return response in JSON format"

        if debug:
            print("Input data:", input_data)

        file_path = input_data[0]["file_path"]
        if self.is_pdf(file_path):
            return self._process_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode)
        return self._process_non_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir)

    def _process_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode):
        """
        Handles processing and inference for PDF files, including page splitting and optional table extraction.
        """
        pdf_optimizer = PDFOptimizer()
        num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(input_data[0]["file_path"],
                                                                             debug_dir, convert_to_images=True)
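        # split_pdf_to_pages renders each page to an image in temp_dir (convert_to_images=True);
        # temp_dir is removed below once inference over all pages has completed.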
        results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir)

        # Clean up temporary directory
        shutil.rmtree(temp_dir, ignore_errors=True)

        return results, num_pages

    def _process_non_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir):
        """
        Handles processing and inference for non-PDF files, with optional table extraction.
        """
        file_path = input_data[0]["file_path"]

        if tables_only:
            return self._extract_tables(model_inference_instance, file_path, input_data, debug, debug_dir), 1
        else:
            temp_dir = tempfile.mkdtemp()

            if crop_size:
                if debug:
                    print(f"Cropping image borders by {crop_size} pixels.")
                image_optimizer = ImageOptimizer()
                cropped_file_path = image_optimizer.crop_image_borders(file_path, temp_dir, debug_dir, crop_size)
                input_data[0]["file_path"] = cropped_file_path
                file_path = input_data[0]["file_path"]

            input_data[0]["file_path"] = [file_path]
            results = model_inference_instance.inference(input_data)

            shutil.rmtree(temp_dir, ignore_errors=True)
            return results, 1

    def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir):
        """
        Processes individual pages split from a PDF and handles table extraction or inference.

        Args:
            model_inference_instance: The model inference object.
            output_files: List of file paths for the split PDF pages.
            input_data: Input data for inference.
            tables_only: Whether to only process tables.
            crop_size: Size for cropping image borders.
            debug: Debug flag for logging.
            debug_dir: Directory for saving debug information.

        Returns:
            List of results from the processing or inference.
        """
        results_array = []

        if tables_only:
            if debug:
                print(f"Processing {len(output_files)} pages for table extraction.")
            # Process each page individually for table extraction
            for i, file_path in enumerate(output_files):
                tables_result = self._extract_tables(
                    model_inference_instance, file_path, input_data, debug, debug_dir, page_index=i
                )
                # _extract_tables returns a list containing a single JSON string, so unpack it
                results_array.extend(tables_result)
        else:
            if debug:
                print(f"Processing {len(output_files)} pages for inference at once.")

            temp_dir = tempfile.mkdtemp()
            cropped_files = []

            if crop_size:
                if debug:
                    print(f"Cropping image borders by {crop_size} pixels from {len(output_files)} images.")
                image_optimizer = ImageOptimizer()

                # Crop each file in the output_files array
                for file_path in output_files:
                    cropped_file_path = image_optimizer.crop_image_borders(
                        file_path,
                        temp_dir,
                        debug_dir,
                        crop_size
                    )
                    cropped_files.append(cropped_file_path)

                # Use the cropped files for inference
                input_data[0]["file_path"] = cropped_files
            else:
                # If no cropping is needed, use the original files directly
                input_data[0]["file_path"] = output_files

            # Process all files at once
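            # Assumption (inferred from the per-page result loop in __main__ below): the backend
            # returns one result entry per image listed in input_data[0]["file_path"].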
            results = model_inference_instance.inference(input_data)
            results_array.extend(results)

            # Clean up temporary directory
            shutil.rmtree(temp_dir, ignore_errors=True)

        return results_array

    def _extract_tables(self, model_inference_instance, file_path, input_data, debug, debug_dir, page_index=None):
        """
        Detects and processes tables from an input file.
        """
        table_detector = TableDetector()
        cropped_tables = table_detector.detect_tables(file_path, local=False, debug_dir=debug_dir, debug=debug)
        results_array = []
        temp_dir = tempfile.mkdtemp()

        for i, table in enumerate(cropped_tables):
            table_index = f"page_{page_index + 1}_table_{i + 1}" if page_index is not None else f"table_{i + 1}"
            print(f"Processing {table_index} for document {file_path}")

            output_filename = os.path.join(temp_dir, f"{table_index}.jpg")
            table.save(output_filename, "JPEG")

            input_data[0]["file_path"] = [output_filename]
            result = self._run_model_inference(model_inference_instance, input_data)
            results_array.append(result)

        shutil.rmtree(temp_dir, ignore_errors=True)

        # Merge results_array elements into a single JSON structure
        merged_results = {"page_tables": results_array}

        # Format the merged results as a JSON string with indentation
        formatted_results = json.dumps(merged_results, indent=4)

        # Return the formatted JSON string wrapped in a list
        return [formatted_results]

    @staticmethod
    def _run_model_inference(model_inference_instance, input_data):
        """
        Runs model inference and handles JSON decoding.
        """
        result = model_inference_instance.inference(input_data)[0]
        try:
            return json.loads(result) if isinstance(result, str) else result
        except json.JSONDecodeError:
            return {"message": "Invalid JSON format in LLM output", "valid": "false"}

    @staticmethod
    def is_pdf(file_path):
        """Checks if a file is a PDF based on its extension."""
        return file_path.lower().endswith('.pdf')


if __name__ == "__main__":
    # run locally: python -m sparrow_parse.extractors.vllm_extractor
    extractor = VLLMExtractor()

    # # export HF_TOKEN="hf_"
    # config = {
    #     "method": "mlx",  # Could be 'huggingface', 'mlx' or 'local_gpu'
    #     "model_name": "mlx-community/Qwen2.5-VL-7B-Instruct-8bit",
    #     # "hf_space": "katanaml/sparrow-qwen2-vl-7b",
    #     # "hf_token": os.getenv('HF_TOKEN'),
    #     # Additional fields for local GPU inference
    #     # "device": "cuda", "model_path": "model.pth"
    # }
    #
    # # Use the factory to get the correct instance
    # factory = InferenceFactory(config)
    # model_inference_instance = factory.get_inference_instance()
    #
    # input_data = [
    #     {
    #         "file_path": "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.png",
    #         "text_input": "retrieve document data. return response in JSON format"
    #     }
    # ]
    #
    # # Now you can run inference without knowing which implementation is used
    # results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
    #                                                    generic_query=False,
    #                                                    crop_size=0,
    #                                                    debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/",
    #                                                    debug=True,
    #                                                    mode=None)
    #
    # for i, result in enumerate(results_array):
    #     print(f"Result for page {i + 1}:", result)
    # print(f"Number of pages: {num_pages}")