Spaces:
Configuration error
Configuration error
| import json | |
| from sparrow_parse.vllm.inference_factory import InferenceFactory | |
| from sparrow_parse.helpers.pdf_optimizer import PDFOptimizer | |
| from sparrow_parse.helpers.image_optimizer import ImageOptimizer | |
| from sparrow_parse.processors.table_structure_processor import TableDetector | |
| from rich import print | |
| import os | |
| import tempfile | |
| import shutil | |
class VLLMExtractor:
    """
    Runs a model inference instance over a document (PDF or single image).

    PDFs are split into per-page images first; other files are passed to the
    model directly. Both paths support optional border cropping and a
    tables-only mode in which detected tables are cropped out and sent to the
    model one by one.

    Note: the ``input_data`` list passed to these methods is modified in place
    (``text_input`` and ``file_path`` of the first entry are overwritten).
    """

    def run_inference(self, model_inference_instance, input_data, tables_only=False,
                      generic_query=False, crop_size=None, debug_dir=None, debug=False, mode=None):
        """
        Main entry point for processing input data using a model inference instance.
        Handles generic queries, PDFs, and table extraction.

        Args:
            model_inference_instance: Object exposing ``inference(input_data)``.
            input_data: List whose first item is a dict with at least a
                ``"file_path"`` key; it is mutated in place.
            tables_only: When true, only detected tables are sent to the model.
            generic_query: When true, ``text_input`` is replaced by a generic
                "retrieve everything as JSON" prompt.
            crop_size: Border size passed to the image cropper; falsy disables cropping.
            debug_dir: Directory handed to helpers for debug output.
            debug: Enables diagnostic printing.
            mode: Forwarded to the PDF path, which currently does not use it.

        Returns:
            Tuple ``(results, num_pages)``.
        """
        if generic_query:
            input_data[0]["text_input"] = "retrieve document data. return response in JSON format"

        if debug:
            print("Input data:", input_data)

        file_path = input_data[0]["file_path"]
        if self.is_pdf(file_path):
            return self._process_pdf(model_inference_instance, input_data, tables_only,
                                     crop_size, debug, debug_dir, mode)
        return self._process_non_pdf(model_inference_instance, input_data, tables_only,
                                     crop_size, debug, debug_dir)

    def _process_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode):
        """
        Handles processing and inference for PDF files, including page splitting and optional table extraction.

        Returns:
            Tuple ``(results, num_pages)`` where ``num_pages`` is reported by the PDF splitter.
        """
        pdf_optimizer = PDFOptimizer()
        num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(
            input_data[0]["file_path"], debug_dir, convert_to_images=True
        )
        try:
            results = self._process_pages(model_inference_instance, output_files, input_data,
                                          tables_only, crop_size, debug, debug_dir)
        finally:
            # Remove the split pages even if inference fails, so nothing leaks on disk.
            shutil.rmtree(temp_dir, ignore_errors=True)
        return results, num_pages

    def _process_non_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir):
        """
        Handles processing and inference for non-PDF files, with optional table extraction.

        Returns:
            Tuple ``(results, 1)``; a non-PDF input always counts as one page.
        """
        file_path = input_data[0]["file_path"]
        if tables_only:
            return self._extract_tables(model_inference_instance, file_path, input_data, debug, debug_dir), 1

        # A scratch directory is only needed to hold the cropped copy.
        temp_dir = tempfile.mkdtemp() if crop_size else None
        try:
            if crop_size:
                if debug:
                    print(f"Cropping image borders by {crop_size} pixels.")
                image_optimizer = ImageOptimizer()
                file_path = image_optimizer.crop_image_borders(file_path, temp_dir, debug_dir, crop_size)

            # The model is always given a list of paths, even for a single file.
            input_data[0]["file_path"] = [file_path]
            results = model_inference_instance.inference(input_data)
        finally:
            if temp_dir is not None:
                shutil.rmtree(temp_dir, ignore_errors=True)
        return results, 1

    def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir):
        """
        Processes individual pages (PDF split) and handles table extraction or inference.

        Args:
            model_inference_instance: The model inference object.
            output_files: List of file paths for the split PDF pages.
            input_data: Input data for inference.
            tables_only: Whether to only process tables.
            crop_size: Size for cropping image borders.
            debug: Debug flag for logging.
            debug_dir: Directory for saving debug information.

        Returns:
            List of results from the processing or inference.
        """
        results_array = []

        if tables_only:
            if debug:
                print(f"Processing {len(output_files)} pages for table extraction.")
            # Tables are extracted page by page; each call yields a one-element
            # list holding that page's JSON string.
            for i, file_path in enumerate(output_files):
                tables_result = self._extract_tables(
                    model_inference_instance, file_path, input_data, debug, debug_dir, page_index=i
                )
                results_array.extend(tables_result)
            return results_array

        if debug:
            print(f"Processing {len(output_files)} pages for inference at once.")

        # A scratch directory is only needed to hold cropped copies.
        temp_dir = tempfile.mkdtemp() if crop_size else None
        try:
            if crop_size:
                if debug:
                    print(f"Cropping image borders by {crop_size} pixels from {len(output_files)} images.")
                image_optimizer = ImageOptimizer()
                input_data[0]["file_path"] = [
                    image_optimizer.crop_image_borders(file_path, temp_dir, debug_dir, crop_size)
                    for file_path in output_files
                ]
            else:
                # If no cropping needed, use original files directly
                input_data[0]["file_path"] = output_files

            # All pages are submitted to the model in a single call.
            results = model_inference_instance.inference(input_data)
            results_array.extend(results)
        finally:
            if temp_dir is not None:
                shutil.rmtree(temp_dir, ignore_errors=True)

        return results_array

    def _extract_tables(self, model_inference_instance, file_path, input_data, debug, debug_dir, page_index=None):
        """
        Detects and processes tables from an input file.

        Each detected table is saved as a JPEG in a temporary directory and
        sent to the model separately.

        Args:
            page_index: Zero-based page number used to label tables; ``None``
                for single-image inputs.

        Returns:
            One-element list containing a JSON string of the form
            ``{"page_tables": [...]}``.
        """
        table_detector = TableDetector()
        cropped_tables = table_detector.detect_tables(file_path, local=False, debug_dir=debug_dir, debug=debug)
        if cropped_tables is None:
            # Defensive: treat "nothing returned" as "no tables" instead of failing in the loop.
            cropped_tables = []

        results_array = []
        temp_dir = tempfile.mkdtemp()
        try:
            for i, table in enumerate(cropped_tables):
                if page_index is not None:
                    table_index = f"page_{page_index + 1}_table_{i + 1}"
                else:
                    table_index = f"table_{i + 1}"
                print(f"Processing {table_index} for document {file_path}")

                output_filename = os.path.join(temp_dir, f"{table_index}.jpg")
                table.save(output_filename, "JPEG")

                input_data[0]["file_path"] = [output_filename]
                results_array.append(self._run_model_inference(model_inference_instance, input_data))
        finally:
            # Remove the cropped table images even if saving or inference fails.
            shutil.rmtree(temp_dir, ignore_errors=True)

        # Merge per-table results into one indented JSON document, wrapped in a list.
        merged_results = {"page_tables": results_array}
        return [json.dumps(merged_results, indent=4)]

    @staticmethod
    def _run_model_inference(model_inference_instance, input_data):
        """
        Runs model inference and handles JSON decoding.

        Only the first element of the model's output is used. A string result
        is parsed as JSON; a non-string result is returned unchanged.

        Returns:
            The decoded result, or an error dict when the string is not valid JSON.
        """
        result = model_inference_instance.inference(input_data)[0]
        if not isinstance(result, str):
            return result
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            return {"message": "Invalid JSON format in LLM output", "valid": "false"}

    @staticmethod
    def is_pdf(file_path):
        """Checks if a file is a PDF based on its extension."""
        return file_path.lower().endswith('.pdf')
if __name__ == "__main__":
    # Manual smoke run: python -m sparrow_parse.extractors.vllm_extractor
    extractor = VLLMExtractor()

    # Disabled usage example (requires a model backend and a local input file).
    #
    # # export HF_TOKEN="hf_"
    # config = {
    #     "method": "mlx",  # Could be 'huggingface', 'mlx' or 'local_gpu'
    #     "model_name": "mlx-community/Qwen2.5-VL-7B-Instruct-8bit",
    #     # "hf_space": "katanaml/sparrow-qwen2-vl-7b",
    #     # "hf_token": os.getenv('HF_TOKEN'),
    #     # Additional fields for local GPU inference
    #     # "device": "cuda", "model_path": "model.pth"
    # }
    #
    # # The factory picks the inference implementation matching the config.
    # factory = InferenceFactory(config)
    # model_inference_instance = factory.get_inference_instance()
    #
    # input_data = [
    #     {
    #         "file_path": "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.png",
    #         "text_input": "retrieve document data. return response in JSON format"
    #     }
    # ]
    #
    # # Inference is run the same way regardless of the chosen implementation.
    # results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
    #                                                    generic_query=False,
    #                                                    crop_size=0,
    #                                                    debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/",
    #                                                    debug=True,
    #                                                    mode=None)
    #
    # for i, result in enumerate(results_array):
    #     print(f"Result for page {i + 1}:", result)
    # print(f"Number of pages: {num_pages}")