# sparrow_parse/extractors/vllm_extractor.py
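"""
Visual LLM extractor for Sparrow: runs document inference over PDFs and images
through a pluggable inference backend, with optional table-only extraction and
image border cropping.
"""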
import json
import os
import shutil
import tempfile

from rich import print  # rich's print replaces the built-in for formatted output

from sparrow_parse.helpers.image_optimizer import ImageOptimizer
from sparrow_parse.helpers.pdf_optimizer import PDFOptimizer
from sparrow_parse.processors.table_structure_processor import TableDetector
from sparrow_parse.vllm.inference_factory import InferenceFactory


class VLLMExtractor:
    def __init__(self):
        pass

    def run_inference(self, model_inference_instance, input_data, tables_only=False,
                      generic_query=False, crop_size=None, debug_dir=None, debug=False, mode=None):
        """
        Main entry point for processing input data with a model inference instance.
        Dispatches PDFs to page-by-page processing and other files straight to
        inference, with optional generic querying and table-only extraction.
        """
if generic_query:
input_data[0]["text_input"] = "retrieve document data. return response in JSON format"
if debug:
print("Input data:", input_data)
file_path = input_data[0]["file_path"]
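        # Route PDFs through page splitting; everything else goes straight to inference.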
if self.is_pdf(file_path):
return self._process_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode)
return self._process_non_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir)

    def _process_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode):
        """
        Handles processing and inference for PDF files, including page splitting
        and optional table extraction. The mode argument is accepted but currently unused.
        """
        pdf_optimizer = PDFOptimizer()
        num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(
            input_data[0]["file_path"], debug_dir, convert_to_images=True
        )
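        # The splitter returns one image per page plus a temp directory, cleaned up below.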
results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir)
# Clean up temporary directory
shutil.rmtree(temp_dir, ignore_errors=True)
return results, num_pages

    def _process_non_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir):
"""
Handles processing and inference for non-PDF files, with optional table extraction.
"""
        file_path = input_data[0]["file_path"]
        if tables_only:
            return self._extract_tables(model_inference_instance, file_path, input_data, debug, debug_dir), 1

        temp_dir = tempfile.mkdtemp()
        if crop_size:
            if debug:
                print(f"Cropping image borders by {crop_size} pixels.")
            image_optimizer = ImageOptimizer()
            file_path = image_optimizer.crop_image_borders(file_path, temp_dir, debug_dir, crop_size)
        # The inference backends expect file_path to be a list of paths.
        input_data[0]["file_path"] = [file_path]
        results = model_inference_instance.inference(input_data)
        shutil.rmtree(temp_dir, ignore_errors=True)
        return results, 1

    def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir):
"""
Processes individual pages (PDF split) and handles table extraction or inference.
Args:
model_inference_instance: The model inference object.
output_files: List of file paths for the split PDF pages.
input_data: Input data for inference.
tables_only: Whether to only process tables.
crop_size: Size for cropping image borders.
debug: Debug flag for logging.
debug_dir: Directory for saving debug information.
Returns:
List of results from the processing or inference.
"""
results_array = []
if tables_only:
if debug:
print(f"Processing {len(output_files)} pages for table extraction.")
# Process each page individually for table extraction
for i, file_path in enumerate(output_files):
tables_result = self._extract_tables(
model_inference_instance, file_path, input_data, debug, debug_dir, page_index=i
)
                # _extract_tables returns a one-element list containing a JSON string
                results_array.extend(tables_result)
else:
if debug:
print(f"Processing {len(output_files)} pages for inference at once.")
temp_dir = tempfile.mkdtemp()
cropped_files = []
if crop_size:
if debug:
print(f"Cropping image borders by {crop_size} pixels from {len(output_files)} images.")
image_optimizer = ImageOptimizer()
# Process each file in the output_files array
for file_path in output_files:
cropped_file_path = image_optimizer.crop_image_borders(
file_path,
temp_dir,
debug_dir,
crop_size
)
cropped_files.append(cropped_file_path)
# Use the cropped files for inference
input_data[0]["file_path"] = cropped_files
else:
# If no cropping needed, use original files directly
input_data[0]["file_path"] = output_files
# Process all files at once
results = model_inference_instance.inference(input_data)
results_array.extend(results)
# Clean up temporary directory
shutil.rmtree(temp_dir, ignore_errors=True)
return results_array

    def _extract_tables(self, model_inference_instance, file_path, input_data, debug, debug_dir, page_index=None):
"""
Detects and processes tables from an input file.
"""
table_detector = TableDetector()
cropped_tables = table_detector.detect_tables(file_path, local=False, debug_dir=debug_dir, debug=debug)
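        # detect_tables returns one cropped image per table found in the document.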
results_array = []
temp_dir = tempfile.mkdtemp()
for i, table in enumerate(cropped_tables):
table_index = f"page_{page_index + 1}_table_{i + 1}" if page_index is not None else f"table_{i + 1}"
print(f"Processing {table_index} for document {file_path}")
output_filename = os.path.join(temp_dir, f"{table_index}.jpg")
table.save(output_filename, "JPEG")
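            # Point the shared input_data entry at this table crop for inference.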
input_data[0]["file_path"] = [output_filename]
result = self._run_model_inference(model_inference_instance, input_data)
results_array.append(result)
shutil.rmtree(temp_dir, ignore_errors=True)
        # Merge per-table results into a single JSON structure and return it
        # as a one-element list containing the formatted JSON string.
        merged_results = {"page_tables": results_array}
        formatted_results = json.dumps(merged_results, indent=4)
        return [formatted_results]

    @staticmethod
def _run_model_inference(model_inference_instance, input_data):
"""
Runs model inference and handles JSON decoding.
"""
result = model_inference_instance.inference(input_data)[0]
try:
return json.loads(result) if isinstance(result, str) else result
except json.JSONDecodeError:
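            # Fall back to a structured error payload when the model output is not valid JSON.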
return {"message": "Invalid JSON format in LLM output", "valid": "false"}

    @staticmethod
def is_pdf(file_path):
"""Checks if a file is a PDF based on its extension."""
return file_path.lower().endswith('.pdf')
if __name__ == "__main__":
# run locally: python -m sparrow_parse.extractors.vllm_extractor
extractor = VLLMExtractor()
# # export HF_TOKEN="hf_"
# config = {
# "method": "mlx", # Could be 'huggingface', 'mlx' or 'local_gpu'
# "model_name": "mlx-community/Qwen2.5-VL-7B-Instruct-8bit",
# # "hf_space": "katanaml/sparrow-qwen2-vl-7b",
# # "hf_token": os.getenv('HF_TOKEN'),
# # Additional fields for local GPU inference
# # "device": "cuda", "model_path": "model.pth"
# }
#
# # Use the factory to get the correct instance
# factory = InferenceFactory(config)
# model_inference_instance = factory.get_inference_instance()
#
# input_data = [
# {
# "file_path": "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.png",
# "text_input": "retrieve document data. return response in JSON format"
# }
# ]
#
# # Now you can run inference without knowing which implementation is used
# results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
# generic_query=False,
# crop_size=0,
# debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/",
# debug=True,
# mode=None)
#
# for i, result in enumerate(results_array):
# print(f"Result for page {i + 1}:", result)
# print(f"Number of pages: {num_pages}")