"""VLLM-based document extractor for Sparrow Parse."""
import json
from sparrow_parse.vllm.inference_factory import InferenceFactory
from sparrow_parse.helpers.pdf_optimizer import PDFOptimizer
from sparrow_parse.helpers.image_optimizer import ImageOptimizer
from sparrow_parse.processors.table_structure_processor import TableDetector
from rich import print
import os
import tempfile
import shutil
class VLLMExtractor(object):
    """
    Orchestrates document data extraction through a model inference backend.

    Supports PDF inputs (split into per-page images), plain image inputs,
    optional border cropping, and table-only extraction via table detection.
    """

    def __init__(self):
        pass

    def run_inference(self, model_inference_instance, input_data, tables_only=False,
                      generic_query=False, crop_size=None, debug_dir=None, debug=False, mode=None):
        """
        Main entry point for processing input data using a model inference instance.
        Handles generic queries, PDFs, and table extraction.

        Args:
            model_inference_instance: Backend object exposing an ``inference(input_data)`` method.
            input_data: List with one dict holding at least "file_path" and "text_input".
            tables_only: If True, only detected tables are processed.
            generic_query: If True, overrides "text_input" with a generic JSON-retrieval prompt.
            crop_size: Optional pixel count to crop from image borders.
            debug_dir: Directory for saving debug artifacts.
            debug: Enables debug logging.
            mode: Passed through to PDF processing (currently unused downstream).

        Returns:
            Tuple of (results list, number of pages processed).
        """
        if generic_query:
            input_data[0]["text_input"] = "retrieve document data. return response in JSON format"

        if debug:
            print("Input data:", input_data)

        file_path = input_data[0]["file_path"]
        if self.is_pdf(file_path):
            return self._process_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode)
        return self._process_non_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir)

    def _process_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode):
        """
        Handles processing and inference for PDF files, including page splitting and optional table extraction.
        """
        pdf_optimizer = PDFOptimizer()
        num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(input_data[0]["file_path"],
                                                                             debug_dir, convert_to_images=True)
        results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir)

        # Clean up temporary directory holding the per-page images
        shutil.rmtree(temp_dir, ignore_errors=True)
        return results, num_pages

    def _process_non_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir):
        """
        Handles processing and inference for non-PDF files, with optional table extraction.
        """
        file_path = input_data[0]["file_path"]

        if tables_only:
            return self._extract_tables(model_inference_instance, file_path, input_data, debug, debug_dir), 1
        else:
            temp_dir = tempfile.mkdtemp()

            if crop_size:
                if debug:
                    print(f"Cropping image borders by {crop_size} pixels.")
                image_optimizer = ImageOptimizer()
                cropped_file_path = image_optimizer.crop_image_borders(file_path, temp_dir, debug_dir, crop_size)
                input_data[0]["file_path"] = cropped_file_path

            # The inference backend expects "file_path" as a list of paths
            file_path = input_data[0]["file_path"]
            input_data[0]["file_path"] = [file_path]
            results = model_inference_instance.inference(input_data)

            shutil.rmtree(temp_dir, ignore_errors=True)
            return results, 1

    def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir):
        """
        Processes individual pages (PDF split) and handles table extraction or inference.

        Args:
            model_inference_instance: The model inference object.
            output_files: List of file paths for the split PDF pages.
            input_data: Input data for inference.
            tables_only: Whether to only process tables.
            crop_size: Size for cropping image borders.
            debug: Debug flag for logging.
            debug_dir: Directory for saving debug information.

        Returns:
            List of results from the processing or inference.
        """
        results_array = []

        if tables_only:
            if debug:
                print(f"Processing {len(output_files)} pages for table extraction.")
            # Process each page individually for table extraction
            for i, file_path in enumerate(output_files):
                tables_result = self._extract_tables(
                    model_inference_instance, file_path, input_data, debug, debug_dir, page_index=i
                )
                # Since _extract_tables returns a list with one JSON string, unpack it
                results_array.extend(tables_result)
        else:
            if debug:
                print(f"Processing {len(output_files)} pages for inference at once.")

            temp_dir = tempfile.mkdtemp()
            cropped_files = []

            if crop_size:
                if debug:
                    print(f"Cropping image borders by {crop_size} pixels from {len(output_files)} images.")
                image_optimizer = ImageOptimizer()

                # Crop every page image into the temp dir before inference
                for file_path in output_files:
                    cropped_file_path = image_optimizer.crop_image_borders(
                        file_path,
                        temp_dir,
                        debug_dir,
                        crop_size
                    )
                    cropped_files.append(cropped_file_path)

                # Use the cropped files for inference
                input_data[0]["file_path"] = cropped_files
            else:
                # If no cropping needed, use original files directly
                input_data[0]["file_path"] = output_files

            # Process all files in a single inference call
            results = model_inference_instance.inference(input_data)
            results_array.extend(results)

            # Clean up temporary directory
            shutil.rmtree(temp_dir, ignore_errors=True)

        return results_array

    def _extract_tables(self, model_inference_instance, file_path, input_data, debug, debug_dir, page_index=None):
        """
        Detects and processes tables from an input file.

        Returns a single-element list containing a JSON string of the form
        {"page_tables": [...]} with one entry per detected table.
        """
        table_detector = TableDetector()
        cropped_tables = table_detector.detect_tables(file_path, local=False, debug_dir=debug_dir, debug=debug)
        results_array = []
        temp_dir = tempfile.mkdtemp()

        for i, table in enumerate(cropped_tables):
            table_index = f"page_{page_index + 1}_table_{i + 1}" if page_index is not None else f"table_{i + 1}"
            print(f"Processing {table_index} for document {file_path}")

            output_filename = os.path.join(temp_dir, f"{table_index}.jpg")
            table.save(output_filename, "JPEG")

            input_data[0]["file_path"] = [output_filename]
            result = self._run_model_inference(model_inference_instance, input_data)
            results_array.append(result)

        shutil.rmtree(temp_dir, ignore_errors=True)

        # Merge results_array elements into a single JSON structure
        merged_results = {"page_tables": results_array}
        # Format the merged results as a JSON string with indentation
        formatted_results = json.dumps(merged_results, indent=4)
        # Return the formatted JSON string wrapped in a list
        return [formatted_results]

    @staticmethod
    def _run_model_inference(model_inference_instance, input_data):
        """
        Runs model inference and handles JSON decoding.

        NOTE: must be a @staticmethod — it takes no ``self`` but is invoked as
        ``self._run_model_inference(...)``; without the decorator that call
        would pass three arguments to a two-parameter function and raise.
        """
        result = model_inference_instance.inference(input_data)[0]
        try:
            return json.loads(result) if isinstance(result, str) else result
        except json.JSONDecodeError:
            return {"message": "Invalid JSON format in LLM output", "valid": "false"}

    @staticmethod
    def is_pdf(file_path):
        """Checks if a file is a PDF based on its extension (case-insensitive).

        NOTE: must be a @staticmethod — it is invoked as ``self.is_pdf(...)``
        and would otherwise receive the instance as ``file_path``.
        """
        return file_path.lower().endswith('.pdf')
if __name__ == "__main__":
    # Local run: python -m sparrow_parse.extractors.vllm_extractor
    extractor = VLLMExtractor()

    # Example usage (left commented out on purpose — requires model weights
    # and local file paths):
    #
    # # export HF_TOKEN="hf_"
    # config = {
    #     "method": "mlx",  # Could be 'huggingface', 'mlx' or 'local_gpu'
    #     "model_name": "mlx-community/Qwen2.5-VL-7B-Instruct-8bit",
    #     # "hf_space": "katanaml/sparrow-qwen2-vl-7b",
    #     # "hf_token": os.getenv('HF_TOKEN'),
    #     # Additional fields for local GPU inference
    #     # "device": "cuda", "model_path": "model.pth"
    # }
    #
    # # Use the factory to get the correct backend implementation
    # factory = InferenceFactory(config)
    # model_inference_instance = factory.get_inference_instance()
    #
    # input_data = [
    #     {
    #         "file_path": "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.png",
    #         "text_input": "retrieve document data. return response in JSON format"
    #     }
    # ]
    #
    # # Inference runs without knowing which backend implementation is used
    # results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
    #                                                    generic_query=False,
    #                                                    crop_size=0,
    #                                                    debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/",
    #                                                    debug=True,
    #                                                    mode=None)
    #
    # for i, result in enumerate(results_array):
    #     print(f"Result for page {i + 1}:", result)
    # print(f"Number of pages: {num_pages}")