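"""
Sparrow Parse VLLM extractor.

Runs visual LLM inference over documents: splits PDFs into page images,
optionally crops image borders or extracts detected tables, and returns
model results as JSON.
"""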
import json
from sparrow_parse.vllm.inference_factory import InferenceFactory
from sparrow_parse.helpers.pdf_optimizer import PDFOptimizer
from sparrow_parse.helpers.image_optimizer import ImageOptimizer
from sparrow_parse.processors.table_structure_processor import TableDetector
from rich import print
import os
import tempfile
import shutil


class VLLMExtractor:
    def __init__(self):
        pass

    def run_inference(self, model_inference_instance, input_data, tables_only=False,
                      generic_query=False, crop_size=None, debug_dir=None, debug=False, mode=None):
        """
        Main entry point for processing input data using a model inference instance.
        Handles generic queries, PDFs, and table extraction.
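
        Example (illustrative sketch; the file path below is a placeholder):

            extractor = VLLMExtractor()
            input_data = [{
                "file_path": "data/invoice.pdf",
                "text_input": "retrieve invoice data. return response in JSON format"
            }]
            results, num_pages = extractor.run_inference(model_inference_instance, input_data)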
        """
        if generic_query:
            input_data[0]["text_input"] = "retrieve document data. return response in JSON format"

        if debug:
            print("Input data:", input_data)

        file_path = input_data[0]["file_path"]
        if self.is_pdf(file_path):
            return self._process_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode)

        return self._process_non_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir)


    def _process_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode):
        """
        Handles processing and inference for PDF files, including page splitting and optional table extraction.
        """
        pdf_optimizer = PDFOptimizer()
        num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(input_data[0]["file_path"],
                                                                             debug_dir, convert_to_images=True)

        results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir)

        # Clean up temporary directory
        shutil.rmtree(temp_dir, ignore_errors=True)
        return results, num_pages


    def _process_non_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir):
        """
        Handles processing and inference for non-PDF files, with optional table extraction.
        """
        file_path = input_data[0]["file_path"]

        if tables_only:
            return self._extract_tables(model_inference_instance, file_path, input_data, debug, debug_dir), 1
        else:
            temp_dir = tempfile.mkdtemp()

            if crop_size:
                if debug:
                    print(f"Cropping image borders by {crop_size} pixels.")
                image_optimizer = ImageOptimizer()
                cropped_file_path = image_optimizer.crop_image_borders(file_path, temp_dir, debug_dir, crop_size)
                input_data[0]["file_path"] = cropped_file_path

            file_path = input_data[0]["file_path"]
            input_data[0]["file_path"] = [file_path]
            results = model_inference_instance.inference(input_data)

            shutil.rmtree(temp_dir, ignore_errors=True)

            return results, 1

    def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir):
        """
        Processes individual pages (PDF split) and handles table extraction or inference.

        Args:
            model_inference_instance: The model inference object.
            output_files: List of file paths for the split PDF pages.
            input_data: Input data for inference.
            tables_only: Whether to only process tables.
            crop_size: Size for cropping image borders.
            debug: Debug flag for logging.
            debug_dir: Directory for saving debug information.

        Returns:
            List of results from the processing or inference.
        """
        results_array = []

        if tables_only:
            if debug:
                print(f"Processing {len(output_files)} pages for table extraction.")
            # Process each page individually for table extraction
            for i, file_path in enumerate(output_files):
                tables_result = self._extract_tables(
                    model_inference_instance, file_path, input_data, debug, debug_dir, page_index=i
                )
                # _extract_tables returns a single-element list containing one JSON string
                results_array.extend(tables_result)
        else:
            if debug:
                print(f"Processing {len(output_files)} pages for inference at once.")

            temp_dir = tempfile.mkdtemp()
            cropped_files = []

            if crop_size:
                if debug:
                    print(f"Cropping image borders by {crop_size} pixels from {len(output_files)} images.")

                image_optimizer = ImageOptimizer()

                # Process each file in the output_files array
                for file_path in output_files:
                    cropped_file_path = image_optimizer.crop_image_borders(
                        file_path,
                        temp_dir,
                        debug_dir,
                        crop_size
                    )
                    cropped_files.append(cropped_file_path)

                # Use the cropped files for inference
                input_data[0]["file_path"] = cropped_files
            else:
                # If no cropping needed, use original files directly
                input_data[0]["file_path"] = output_files

            # Process all files at once
            results = model_inference_instance.inference(input_data)
            results_array.extend(results)

            # Clean up temporary directory
            shutil.rmtree(temp_dir, ignore_errors=True)

        return results_array


    def _extract_tables(self, model_inference_instance, file_path, input_data, debug, debug_dir, page_index=None):
        """
        Detects and processes tables from an input file.
        """
        table_detector = TableDetector()
        cropped_tables = table_detector.detect_tables(file_path, local=False, debug_dir=debug_dir, debug=debug)
        results_array = []
        temp_dir = tempfile.mkdtemp()

        for i, table in enumerate(cropped_tables):
            table_index = f"page_{page_index + 1}_table_{i + 1}" if page_index is not None else f"table_{i + 1}"
            print(f"Processing {table_index} for document {file_path}")

            output_filename = os.path.join(temp_dir, f"{table_index}.jpg")
            table.save(output_filename, "JPEG")

            input_data[0]["file_path"] = [output_filename]
            result = self._run_model_inference(model_inference_instance, input_data)
            results_array.append(result)

        shutil.rmtree(temp_dir, ignore_errors=True)

        # Merge results_array elements into a single JSON structure
        merged_results = {"page_tables": results_array}

        # Format the merged results as a JSON string with indentation
        formatted_results = json.dumps(merged_results, indent=4)

        # Return the formatted JSON string wrapped in a list
        return [formatted_results]


    @staticmethod
    def _run_model_inference(model_inference_instance, input_data):
        """
        Runs model inference and handles JSON decoding.
        """
        result = model_inference_instance.inference(input_data)[0]
        try:
            return json.loads(result) if isinstance(result, str) else result
        except json.JSONDecodeError:
            return {"message": "Invalid JSON format in LLM output", "valid": "false"}


    @staticmethod
    def is_pdf(file_path):
        """Checks if a file is a PDF based on its extension."""
        return file_path.lower().endswith('.pdf')


if __name__ == "__main__":
    # run locally: python -m sparrow_parse.extractors.vllm_extractor

    extractor = VLLMExtractor()

    # # export HF_TOKEN="hf_"
    # config = {
    #     "method": "mlx",  # Could be 'huggingface', 'mlx' or 'local_gpu'
    #     "model_name": "mlx-community/Qwen2.5-VL-7B-Instruct-8bit",
    #     # "hf_space": "katanaml/sparrow-qwen2-vl-7b",
    #     # "hf_token": os.getenv('HF_TOKEN'),
    #     # Additional fields for local GPU inference
    #     # "device": "cuda", "model_path": "model.pth"
    # }
    #
    # # Use the factory to get the correct instance
    # factory = InferenceFactory(config)
    # model_inference_instance = factory.get_inference_instance()
    #
    # input_data = [
    #     {
    #         "file_path": "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.png",
    #         "text_input": "retrieve document data. return response in JSON format"
    #     }
    # ]
    #
    # # Now you can run inference without knowing which implementation is used
    # results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
    #                                                    generic_query=False,
    #                                                    crop_size=0,
    #                                                    debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/",
    #                                                    debug=True,
    #                                                    mode=None)
    #
    # for i, result in enumerate(results_array):
    #     print(f"Result for page {i + 1}:", result)
    # print(f"Number of pages: {num_pages}")