# src/ml_extraction.py
import os
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from huggingface_hub import snapshot_download
from PIL import Image
from typing import List, Dict, Any, Tuple
import re
import numpy as np
from src.extraction import extract_invoice_number, extract_total, extract_address
from src.table_extraction import extract_table_items
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# --- CONFIGURATION ---
LOCAL_MODEL_PATH = "./models/layoutlmv3-doctr-trained"
HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-doctr-invoice-processor"
# --- Load LayoutLMv3 Model ---
def load_model_and_processor(model_path, hub_id):
    print("Loading processor from microsoft/layoutlmv3-base...")
    processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
    # Download the fine-tuned weights from the Hub if the local copy is missing or empty
    if not os.path.exists(model_path) or not os.listdir(model_path):
        print(f"Downloading model from Hub: {hub_id}...")
        snapshot_download(repo_id=hub_id, local_dir=model_path, local_dir_use_symlinks=False)
    try:
        model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
    except Exception:
        print(f"Fallback: Loading directly from Hub {hub_id}...")
        model = LayoutLMv3ForTokenClassification.from_pretrained(hub_id)
    return model, processor
# --- Load DocTR OCR Predictor ---
def load_doctr_predictor():
    """Initialize the DocTR OCR predictor and move it to the GPU if one is available."""
    print("Loading DocTR OCR predictor...")
    predictor = ocr_predictor(
        det_arch='db_resnet50',
        reco_arch='crnn_vgg16_bn',
        pretrained=True
    )
    if torch.cuda.is_available():
        print("🚀 Moving DocTR to GPU (CUDA)...")
        predictor.cuda()
    else:
        print("⚠️ GPU not found. Running on CPU (slow).")
    print("DocTR OCR predictor is ready.")
    return predictor
# --- Initialize models at import time ---
MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)
DOCTR_PREDICTOR = load_doctr_predictor()

if MODEL and PROCESSOR:
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(DEVICE)
    MODEL.eval()
    print(f"ML Model is ready on device: {DEVICE}")
else:
    DEVICE = None
    print("❌ Could not load ML model.")
def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]], List[List[int]]]:
    """
    Parse DocTR's hierarchical output (Page -> Block -> Line -> Word)
    into flat lists of words and bounding boxes for LayoutLMv3.

    DocTR returns coordinates in 0-1 scale (relative to the image).
    We convert them to:
      - unnormalized_boxes: pixel coordinates [x, y, width, height] for visualization
      - normalized_boxes: 0-1000 scale [x0, y0, x1, y1] for LayoutLMv3

    Args:
        doctr_result: Output from the DocTR predictor
        img_width: Original image width in pixels
        img_height: Original image height in pixels

    Returns:
        words: List of word strings
        unnormalized_boxes: List of [x, y, width, height] in pixel coordinates
        normalized_boxes: List of [x0, y0, x1, y1] in 0-1000 scale
    """
    words = []
    unnormalized_boxes = []
    normalized_boxes = []

    # DocTR hierarchy: Document -> Page -> Block -> Line -> Word
    for page in doctr_result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    # Skip empty words
                    if not word.value.strip():
                        continue
                    words.append(word.value)

                    # DocTR bbox format: ((x_min, y_min), (x_max, y_max)) in 0-1 scale
                    (x_min, y_min), (x_max, y_max) = word.geometry

                    # Convert to pixel coordinates for visualization
                    px_x0 = int(x_min * img_width)
                    px_y0 = int(y_min * img_height)
                    px_x1 = int(x_max * img_width)
                    px_y1 = int(y_max * img_height)

                    # Unnormalized box: [x, y, width, height] for the visualization overlay
                    unnormalized_boxes.append([
                        px_x0,
                        px_y0,
                        px_x1 - px_x0,  # width
                        px_y1 - px_y0   # height
                    ])

                    # Normalized box: [x0, y0, x1, y1] in 0-1000 scale for LayoutLMv3
                    # Clamp values so they stay within [0, 1000]
                    normalized_boxes.append([
                        max(0, min(1000, int(x_min * 1000))),
                        max(0, min(1000, int(y_min * 1000))),
                        max(0, min(1000, int(x_max * 1000))),
                        max(0, min(1000, int(y_max * 1000))),
                    ])

    return words, unnormalized_boxes, normalized_boxes
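
# Worked example for parse_doctr_output (illustrative values only): a word whose
# DocTR geometry is ((0.10, 0.20), (0.30, 0.25)) on a 1000x800 px image yields
#   unnormalized box: [100, 160, 200, 40]   (x, y, width, height in pixels)
#   normalized box:   [100, 200, 300, 250]  (0-1000 scale for LayoutLMv3)
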
def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
    # Map token-level predictions back to word level; the first sub-token of each word wins
    word_ids = encoding.word_ids(batch_index=0)
    word_level_preds = {}
    for idx, word_id in enumerate(word_ids):
        if word_id is not None:
            label_id = predictions[idx]
            if label_id != -100:
                if word_id not in word_level_preds:
                    word_level_preds[word_id] = id2label[label_id]

    # Aggregate BIO-tagged words into entities: B- starts an entity, I- extends it
    entities = {}
    for word_idx, label in word_level_preds.items():
        if label == 'O':
            continue
        entity_type = label[2:]
        word = words[word_idx]
        if label.startswith('B-'):
            entities[entity_type] = {"text": word, "bbox": [unnormalized_boxes[word_idx]]}
        elif label.startswith('I-') and entity_type in entities:
            entities[entity_type]['text'] += " " + word
            entities[entity_type]['bbox'].append(unnormalized_boxes[word_idx])

    for entity in entities.values():
        entity['text'] = entity['text'].strip()
    return entities
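
# Illustrative example for _process_predictions (labels assumed for illustration):
#   words  = ["Acme", "Corp", "Total", "$120.00"]
#   labels = ["B-COMPANY", "I-COMPANY", "O", "B-TOTAL"]
# would produce
#   {"COMPANY": {"text": "Acme Corp", "bbox": [box_0, box_1]},
#    "TOTAL":   {"text": "$120.00", "bbox": [box_3]}}
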
def extract_ml_based(image_path: str) -> Dict[str, Any]:
    if not MODEL or not PROCESSOR:
        raise RuntimeError("ML model is not loaded.")

    # 1. Load Image
    image = Image.open(image_path).convert("RGB")
    width, height = image.size

    # 2. Run DocTR OCR
    doc = DocumentFile.from_images(image_path)
    doctr_result = DOCTR_PREDICTOR(doc)

    # 3. Parse DocTR output to get words and boxes
    words, unnormalized_boxes, normalized_boxes = parse_doctr_output(
        doctr_result, width, height
    )

    # Reconstruct lines so the regex fallbacks can work line-by-line
    lines = []
    current_line = []
    if len(unnormalized_boxes) > 0:
        # Initialize with the first word's Y position and height
        current_y = unnormalized_boxes[0][1]
        current_h = unnormalized_boxes[0][3]
        for i, word in enumerate(words):
            y = unnormalized_boxes[i][1]
            h = unnormalized_boxes[i][3]
            # If the vertical gap exceeds 50% of the line height, start a new line
            if abs(y - current_y) > max(current_h, h) / 2:
                lines.append(" ".join(current_line))
                current_line = []
                current_y = y
                current_h = h
            current_line.append(word)
        # Append the last line
        if current_line:
            lines.append(" ".join(current_line))
    raw_text = "\n".join(lines)

    # Handle empty OCR result
    if not words:
        return {
            "vendor": None,
            "date": None,
            "address": None,
            "receipt_number": None,
            "bill_to": None,
            "total_amount": None,
            "items": [],
            "raw_text": "",
            "raw_predictions": {}
        }
    # 4. Inference with LayoutLMv3
    encoding = PROCESSOR(
        image, text=words, boxes=normalized_boxes,
        truncation=True, max_length=512, return_tensors="pt"
    )
    # Move tensors to the device for inference, but keep the original encoding for word_ids()
    model_inputs = {k: v.to(DEVICE) for k, v in encoding.items()}
    with torch.no_grad():
        outputs = MODEL(**model_inputs)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)

    # 5. Construct Output
    final_output = {
        "vendor": extracted_entities.get("COMPANY", {}).get("text"),
        "date": extracted_entities.get("DATE", {}).get("text"),
        "address": extracted_entities.get("ADDRESS", {}).get("text"),
        "receipt_number": extracted_entities.get("INVOICE_NO", {}).get("text"),
        "bill_to": extracted_entities.get("BILL_TO", {}).get("text"),
        "total_amount": None,
        "items": [],
        "raw_text": raw_text,
        "raw_predictions": extracted_entities  # Contains text and bbox data for each entity
    }
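
    # Illustrative shape of final_output at this point (values invented for the example):
    #   {"vendor": "Acme Corp", "date": "2023-01-15", "address": None,
    #    "receipt_number": "INV-0042", "bill_to": "Jane Doe", "total_amount": None,
    #    "items": [], "raw_text": "...", "raw_predictions": {...}}
    # The fallback steps below fill in whatever the model missed.
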
    # 6. Vendor Fallback (Spatial Heuristic)
    # If ML failed to find a vendor, assume the largest text at the top is the vendor
    if not final_output["vendor"] and unnormalized_boxes:
        # Filter for words in the top 20% of the image
        top_words_indices = [
            i for i, box in enumerate(unnormalized_boxes)
            if box[1] < height * 0.2
        ]
        if top_words_indices:
            # Find the word with the largest height (font size)
            largest_idx = max(top_words_indices, key=lambda i: unnormalized_boxes[i][3])
            final_output["vendor"] = words[largest_idx]

    # --- ADDRESS FALLBACK ---
    if not final_output["address"]:
        # Pass the extracted (or fallback) vendor name to anchor the search over the raw OCR text
        fallback_address = extract_address(raw_text, vendor_name=final_output["vendor"])
        if fallback_address:
            final_output["address"] = fallback_address

    # Backfill bounding boxes for the address fallback:
    # if regex found the address but ML didn't, find its boxes in the OCR data
    if final_output["address"] and "ADDRESS" not in final_output["raw_predictions"]:
        address_text = final_output["address"]
        address_boxes = []
        # The address may span multiple words, so we search for each word.
        # Split by comma first (since extract_address joins lines with ", ")
        address_parts = [part.strip() for part in address_text.split(",")]
        for part in address_parts:
            part_words = part.split()
            for target_word in part_words:
                for i, word in enumerate(words):
                    # Case-insensitive match
                    if target_word.lower() == word.lower() or target_word.lower() in word.lower():
                        address_boxes.append(unnormalized_boxes[i])
                        break  # Only match once per target word
        # If we found any boxes, inject them into raw_predictions
        if address_boxes:
            final_output["raw_predictions"]["ADDRESS"] = {
                "text": address_text,
                "bbox": address_boxes
            }
    # Fallbacks for the total amount and invoice number
    ml_total = extracted_entities.get("TOTAL", {}).get("text")
    if ml_total:
        try:
            cleaned = re.sub(r'[^\d.,]', '', ml_total).replace(',', '.')
            final_output["total_amount"] = float(cleaned)
        except (ValueError, TypeError):
            pass
    if final_output["total_amount"] is None:
        final_output["total_amount"] = extract_total(raw_text)
    if not final_output["receipt_number"]:
        final_output["receipt_number"] = extract_invoice_number(raw_text)

    # Backfill bounding boxes for regex results:
    # if regex found the number but ML didn't, we must find its box
    # in the OCR data so the UI can draw it
    if final_output["receipt_number"] and "INVOICE_NO" not in final_output["raw_predictions"]:
        target_val = final_output["receipt_number"].strip()
        found_box = None
        # 1. Try finding the exact word in the OCR list
        #    ('words' and 'unnormalized_boxes' are available from step 3)
        for i, word in enumerate(words):
            # Check for an exact match, or whether the word contains the target (e.g. "Inv#123")
            if target_val == word or (len(target_val) > 3 and target_val in word):
                found_box = unnormalized_boxes[i]
                break
        # 2. If found, inject it into raw_predictions
        if found_box:
            # The UI expects a list of boxes
            final_output["raw_predictions"]["INVOICE_NO"] = {
                "text": target_val,
                "bbox": [found_box]
            }
    # --- TABLE EXTRACTION (Geometric Heuristic) ---
    # Use the geometric fallback to extract line items from the table region
    if words and unnormalized_boxes:
        extracted_items = extract_table_items(words, unnormalized_boxes)
        if extracted_items:
            final_output["items"] = extracted_items

    return final_output
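
if __name__ == "__main__":
    # Minimal usage sketch: run the full pipeline on a single image and print a
    # compact summary. The sample path below is a hypothetical placeholder;
    # point it at a real invoice image before running.
    import json

    sample_path = "samples/invoice_001.png"  # hypothetical example path
    result = extract_ml_based(sample_path)
    # Drop the bounding-box payload so the console output stays readable
    summary = {k: v for k, v in result.items() if k != "raw_predictions"}
    print(json.dumps(summary, indent=2, ensure_ascii=False))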