import atexit
import functools
import os
import re
import tempfile
from queue import Queue
from threading import Event, Thread
import threading

from flask import Flask, request, jsonify
from paddleocr import PaddleOCR
from PIL import Image

# --- NLP analysis function ---
from nlp_service import analyze_text

# --- Configuration ---
LANG = 'en' # Default language, can be overridden if needed
NUM_WORKERS = 2  # Number of OCR worker threads

# --- PaddleOCR Model Manager ---
class PaddleOCRModelManager(object):
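    """Pool of OCR worker threads, each owning its own PaddleOCR instance.

    Requests are funneled through one shared queue so a model is never used
    from two threads at once; infer() blocks until a worker replies.
    """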
    def __init__(self,
                 num_workers,
                 model_factory):
        super().__init__()
        self._model_factory = model_factory
        self._queue = Queue()
        self._workers = []
        self._model_initialized_event = Event()
        print(f"Initializing {num_workers} OCR worker(s)...")
        for i in range(num_workers):
            print(f"Starting worker {i+1}...")
            worker = Thread(target=self._worker, daemon=True)
            worker.start()
            self._model_initialized_event.wait()  # Wait for this worker's model
            self._model_initialized_event.clear()
            self._workers.append(worker)
        print("All OCR workers initialized.")

    def infer(self, *args, **kwargs):
        result_queue = Queue(maxsize=1)
        self._queue.put((args, kwargs, result_queue))
        success, payload = result_queue.get()
        if success:
            return payload
        else:
            print(f"Error during OCR inference: {payload}")
            raise payload

    def close(self):
        print("Shutting down OCR workers...")
        for _ in self._workers:
            self._queue.put(None)  # one sentinel per worker
        for worker in self._workers:
            worker.join()
        print("OCR workers shut down.")

    def _worker(self):
        print(f"Worker thread {threading.current_thread().name}: Loading PaddleOCR model ({LANG})...")
        try:
            model = self._model_factory()
            print(f"Worker thread {threading.current_thread().name}: Model loaded.")
            self._model_initialized_event.set()
        except Exception as e:
            print(f"FATAL: Worker thread {threading.current_thread().name} failed to load model: {e}")
            self._model_initialized_event.set()  # unblock __init__ so startup does not hang
            return

        while True:
            item = self._queue.get()
            if item is None:
                print(f"Worker thread {threading.current_thread().name}: Exiting.")
                break
            args, kwargs, result_queue = item
            try:
                result = model.ocr(*args, **kwargs)
                if result and result[0]:
                    result_queue.put((True, result[0]))
                else:
                    result_queue.put((True, []))
            except Exception as e:
                print(f"Worker thread {threading.current_thread().name}: Error processing request: {e}")
                result_queue.put((False, e))
            finally:
                self._queue.task_done()
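
# Usage sketch (illustrative only; the real manager instance is created further below):
#   manager = PaddleOCRModelManager(num_workers=2,
#                                   model_factory=functools.partial(PaddleOCR, lang='en'))
#   lines = manager.infer('/path/to/receipt.png', cls=True)  # blocks until a worker replies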

# --- Amount Extraction Logic ---
def find_main_amount(ocr_results):
    if not ocr_results:
        return None

    amount_regex = re.compile(r'(?<!%)\b\d{1,3}(?:,?\d{3})*(?:\.\d{2})\b|\b\d+\.\d{2}\b|\b\d+\b(?!\.\d{1})')
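    # Illustrative matches: '1,234.56', '89.99', and bare integers like '45';
    # the (?<!%) lookbehind keeps the first alternative from matching right
    # after a percent sign.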
    
    # Keyword tiers, checked in order of reliability
    priority_keywords = ['grand total', 'total amount', 'amount due', 'to pay', 'bill total', 'total payable']
    secondary_keywords = ['total', 'balance', 'net amount', 'paid', 'charge', 'net total']
    lower_priority_keywords = ['subtotal', 'sub total']

    parsed_lines = []
    for i, line_info in enumerate(ocr_results):
        if not line_info or len(line_info) < 2 or len(line_info[1]) < 1:
            continue
        text = line_info[1][0].lower().strip()
        confidence = line_info[1][1]

        numbers_in_line = amount_regex.findall(text)
        float_numbers = []
        for num_str in numbers_in_line:
            try:
                # Skip standalone four-digit, year-like numbers on short lines
                # (e.g. a lone '2024' in a date header)
                if len(text) < 7 and '.' not in num_str and 1900 < int(num_str.replace(',', '')) < 2100:
                    if len(numbers_in_line) == 1 and len(num_str) == 4:
                        continue
                float_numbers.append(float(num_str.replace(',', '')))
            except ValueError:
                continue

        # Check for keywords
        has_priority_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in priority_keywords)
        has_secondary_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in secondary_keywords)
        has_lower_priority_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in lower_priority_keywords)

        parsed_lines.append({
            "index": i,
            "text": text,
            "numbers": float_numbers,
            "has_priority_keyword": has_priority_keyword,
            "has_secondary_keyword": has_secondary_keyword,
            "has_lower_priority_keyword": has_lower_priority_keyword,
            "confidence": confidence
        })

    # --- Strategy to find the best candidate ---
    
    # 1. Look for numbers on the SAME line as PRIORITY keywords OR the line IMMEDIATELY AFTER
    priority_candidates = []
    for i, line in enumerate(parsed_lines):
        if line["has_priority_keyword"]:
            if line["numbers"]:
                priority_candidates.extend(line["numbers"])
            # Check next line if current line has no numbers and next line exists
            elif i + 1 < len(parsed_lines) and parsed_lines[i+1]["numbers"]:
                priority_candidates.extend(parsed_lines[i+1]["numbers"])

    if priority_candidates:
        # Often the largest number on/near these lines is the final total
        return max(priority_candidates)

    # 2. Look for numbers on the SAME line as SECONDARY keywords OR the line IMMEDIATELY AFTER
    secondary_candidates = []
    for i, line in enumerate(parsed_lines):
         if line["has_secondary_keyword"]:
            if line["numbers"]:
                secondary_candidates.extend(line["numbers"])
            # Check next line if current line has no numbers and next line exists
            elif i + 1 < len(parsed_lines) and parsed_lines[i+1]["numbers"]:
                secondary_candidates.extend(parsed_lines[i+1]["numbers"])

    if secondary_candidates:
        # Only secondary keywords were found; return the largest number on/near those lines
        return max(secondary_candidates)

    # 3. Look for numbers on the SAME line as LOWER PRIORITY keywords (subtotal) OR the line IMMEDIATELY AFTER
    lower_priority_candidates = []
    for i, line in enumerate(parsed_lines):
        if line["has_lower_priority_keyword"]:
            if line["numbers"]:
                lower_priority_candidates.extend(line["numbers"])
            # Check next line if current line has no numbers and next line exists
            elif i + 1 < len(parsed_lines) and parsed_lines[i+1]["numbers"]:
                lower_priority_candidates.extend(parsed_lines[i+1]["numbers"])
    # A subtotal is only returned later, as a last resort

    # 4. Fallback: largest plausible number overall (excluding subtotals if other numbers exist)
    print("Warning: No numbers found on/near priority/secondary keyword lines. Using fallback.")
    all_numbers = []
    # Collect subtotal amounts so the fallback can avoid preferring them
    subtotal_numbers = {num for line in parsed_lines if line["has_lower_priority_keyword"] for num in line["numbers"]}
    # Numbers on the line after a lower-priority keyword also count as subtotals
    for i, line in enumerate(parsed_lines):
        if line["has_lower_priority_keyword"] and not line["numbers"] and i + 1 < len(parsed_lines):
            subtotal_numbers.update(parsed_lines[i+1]["numbers"])

    for line in parsed_lines:
        all_numbers.extend(line["numbers"])

    if all_numbers:
        unique_numbers = list(set(all_numbers))

        # Filter out potential quantities/years/small irrelevant numbers
        plausible_numbers = [n for n in unique_numbers if n >= 0.01] # Keep small decimals too
        # Stricter filter for large numbers: exclude large integers (likely IDs, phone numbers)
        # Keep numbers < 50000 OR numbers that have a non-zero decimal part
        plausible_numbers = [n for n in plausible_numbers if n < 50000 or (n != int(n))]

        # If we have plausible numbers other than subtotals, prefer them
        non_subtotal_plausible = [n for n in plausible_numbers if n not in subtotal_numbers]

        if non_subtotal_plausible:
            return max(non_subtotal_plausible)
        elif plausible_numbers: # Only subtotals (or nothing else plausible) were found
            return max(plausible_numbers)  # largest subtotal/plausible value as a last resort

    # 5. If still nothing, return None
    print("Warning: Could not determine main amount.")
    return None
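
# Illustrative call (hypothetical OCR output; PaddleOCR lines look like [box, (text, confidence)],
# with box1/box2 standing in for bounding-box coordinates):
#   find_main_amount([[box1, ("Grand Total 45.50", 0.98)], [box2, ("Cash 50.00", 0.97)]])
#   -> 45.5, because 'grand total' is a priority keyword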

# --- Flask App Setup ---
app = Flask(__name__)

# --- Initialize OCR Manager ---
ocr_model_factory = functools.partial(PaddleOCR, lang=LANG, use_angle_cls=True, use_gpu=False, show_log=False)
ocr_manager = PaddleOCRModelManager(num_workers=NUM_WORKERS, model_factory=ocr_model_factory)

# Register cleanup function
atexit.register(ocr_manager.close)

# --- API Endpoint ---
@app.route('/extract_expense', methods=['POST'])
def extract_expense():
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400

    if file:
        temp_file_path = None  # ensure the finally block can always check it
        try:
            # Save to a temporary file
            _, file_extension = os.path.splitext(file.filename)
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
                file.save(temp_file.name)
                temp_file_path = temp_file.name

            # Perform OCR
            ocr_result = ocr_manager.infer(temp_file_path, cls=True)

            # Process OCR results
            extracted_text = ""
            main_amount_ocr = None
            if ocr_result:
                extracted_lines = [line[1][0] for line in ocr_result if line and len(line) > 1 and len(line[1]) > 0]
                extracted_text = "\n".join(extracted_lines)
                main_amount_ocr = find_main_amount(ocr_result)  # keyword/regex heuristic

            # Construct the response (OCR results only; NLP runs via the /message endpoint)
            response_data = {
                "type": "photo",
                "extracted_text": extracted_text,
                "main_amount_ocr": main_amount_ocr, # Amount found by OCR regex logic
            }

            return jsonify(response_data)

        except Exception as e:
            print(f"Error processing file: {e}")
            import traceback
            traceback.print_exc()
            return jsonify({"error": f"An internal error occurred: {str(e)}"}), 500
        finally:
            if temp_file_path and os.path.exists(temp_file_path):
                os.remove(temp_file_path)

    return jsonify({"error": "File processing failed"}), 500

# --- NLP Message Endpoint ---
@app.route('/message', methods=['POST'])
def process_message():
    data = request.get_json()
    if not data or 'text' not in data:
        return jsonify({"error": "Missing 'text' field in JSON payload"}), 400

    text_message = data['text']
    if not text_message:
         return jsonify({"error": "'text' field cannot be empty"}), 400

    try:
        # Delegate to the imported NLP analysis function
        nlp_analysis_result = analyze_text(text_message)
        print(f"NLP Service Analysis Result: {nlp_analysis_result}")
        # Check if the NLP analysis itself reported an error/failure or requires fallback
        status = nlp_analysis_result.get("status")
        if status == "failed":
            nlp_error = nlp_analysis_result.get("message", "NLP processing failed")
            # Return the failure result from NLP service
            return jsonify(nlp_analysis_result), 400 # Use 400 for client-side errors like empty text
        elif status == "fallback_required":
             # Return the fallback result (e.g., for queries)
             return jsonify(nlp_analysis_result), 200 # Return 200, but indicate fallback needed
        
        # Return the successful analysis result
        return jsonify(nlp_analysis_result)

    except Exception as nlp_e:
        nlp_error = f"Error calling NLP analysis function: {nlp_e}"
        print(f"Error calling NLP function: {nlp_error}")
        return jsonify({"error": "An internal error occurred during NLP processing", "details": nlp_error}), 500

# --- Health Check Endpoint ---
@app.route('/health', methods=['GET'])
def health_check():
    # You could add more checks here (e.g., if OCR workers are alive)
    return jsonify({"status": "ok"}), 200


# --- Run the App ---
if __name__ == '__main__':
    # Use port 7860 as expected by Hugging Face Spaces
    # Use host='0.0.0.0' for accessibility within Docker/Spaces
    app.run(host='0.0.0.0', port=7860, debug=False)