| from dotenv import load_dotenv |
| load_dotenv() |
|
|
| from fastapi import FastAPI, UploadFile, File, Form |
| from fastapi.responses import JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from transformers import pipeline as hf_pipeline, AutoTokenizer, AutoModelForTokenClassification |
| from doctr.io import DocumentFile |
| from doctr.models import ocr_predictor |
| from img2table.document import Image as Img2TableImage |
| from img2table.ocr import DocTR |
| import cv2 |
| import numpy as np |
| from PIL import Image |
| import io |
| import json |
| import os |
| import tempfile |
| import base64 |
| from typing import Dict, Any, Optional, List |
| import difflib |
| import re |
| import httpx |
| from bs4 import BeautifulSoup |
|
|
| |
| from docling.document_converter import DocumentConverter, InputFormat, ImageFormatOption |
| from docling.datamodel.pipeline_options import PdfPipelineOptions |
| from docling_ocr_onnxtr import OnnxtrOcrOptions |
|
|
| |
| from router_chat import router as chat_router |
| from faq_store import initialize_faq_store |
|
|
# The single FastAPI application object for this module.
app = FastAPI(title="ScanAssured OCR & NER API")


@app.on_event("startup")
async def startup_event():
    """Startup hook: call initialize_faq_store() once when the app starts."""
    initialize_faq_store()


# Register the routes exported by router_chat on the application.
app.include_router(chat_router)
|
|
| |
def _load_json_resource(filename: str, label: str) -> tuple:
    """Load an optional JSON data file located next to this module.

    Args:
        filename: File name, resolved relative to this module's directory.
        label: Human-readable noun used in the "Loaded N ..." log line.

    Returns:
        A ``(path, data)`` tuple. ``data`` is the parsed JSON content, or an
        empty dict when the file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    path = os.path.join(os.path.dirname(__file__), filename)
    data = {}
    if os.path.exists(path):
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale encoding, which differs between operating systems.
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Loaded {len(data)} {label}")
    return path, data


# Drug interaction data; its keys feed the medication spelling dictionary.
interactions_path, DRUG_INTERACTIONS = _load_json_resource(
    'interactions_data.json', 'drug interaction entries'
)

# Lab-test name -> metadata (including optional 'aliases').
medlineplus_map_path, MEDLINEPLUS_MAP = _load_json_resource(
    'medlineplus_map.json', 'MedlinePlus test mappings'
)

# Previously stored MedlinePlus entries loaded from disk.
medlineplus_cache_path, MEDLINEPLUS_CACHE = _load_json_resource(
    'medlineplus_cache.json', 'MedlinePlus cached entries'
)
|
|
| |
# CORS configuration.
#
# Allowed origins can be restricted by setting CORS_ALLOW_ORIGINS to a
# comma-separated list (e.g. "https://app.example.com,https://admin.example.com").
# When the variable is unset or empty the previous behaviour is kept: any
# origin is accepted.
#
# NOTE(security): a wildcard origin combined with allow_credentials=True lets
# any website make credentialed requests to this API. The FastAPI CORS
# documentation says origins must be listed explicitly when credentials are
# enabled, so set CORS_ALLOW_ORIGINS in production deployments.
_cors_allowed_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allowed_origins or ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Named OCR presets. Each preset pairs a detection architecture ("det") with a
# recognition architecture ("reco") and carries a display name/description.
OCR_PRESETS = {
    "high_accuracy": {
        "det": "db_resnet50",
        "reco": "crnn_vgg16_bn",
        "name": "High Accuracy",
        "description": "Best quality, slower processing",
    },
    "balanced": {
        "det": "db_resnet50",
        "reco": "crnn_mobilenet_v3_small",
        "name": "Balanced (Recommended)",
        "description": "Good quality and speed",
    },
    "fast": {
        "det": "db_mobilenet_v3_large",
        "reco": "crnn_mobilenet_v3_small",
        "name": "Fast",
        "description": "Fastest processing, slightly lower quality",
    },
}

# Architecture names listed for the detection and recognition stages.
OCR_DETECTION_MODELS = [
    "db_resnet50",
    "db_mobilenet_v3_large",
    "linknet_resnet18",
]
OCR_RECOGNITION_MODELS = [
    "crnn_vgg16_bn",
    "crnn_mobilenet_v3_small",
    "parseq",
]
|
|
| |
# Supported NER models, keyed by model id. get_ner_pipeline() rejects any id
# that is not a key of this mapping.
NER_MODELS = {
    "Clinical-AI-Apollo/Medical-NER": {
        "name": "Medical NER (Recommended)",
        "description": "Medications, diseases, lab values, procedures, dosages",
        "entities": [
            "MEDICATION",
            "DOSAGE",
            "FREQUENCY",
            "DURATION",
            "DISEASE_DISORDER",
            "SIGN_SYMPTOM",
            "DIAGNOSTIC_PROCEDURE",
            "THERAPEUTIC_PROCEDURE",
            "LAB_VALUE",
            "SEVERITY",
        ],
    },
    "samrawal/bert-base-uncased_clinical-ner": {
        "name": "Clinical Notes",
        "description": "Optimized for clinical/medical notes",
        "entities": ["PROBLEM", "TREATMENT", "TEST"],
    },
}
|
|
| |
# Process-wide caches so each model/converter is only constructed once.
# NER pipelines are keyed by model id.
ner_model_cache: Dict[str, Any] = {}
# OCR predictors are keyed by "<det_arch>_<reco_arch>".
ocr_model_cache: Dict[str, Any] = {}
# Docling converters are keyed by "docling_<det_arch>_<reco_arch>".
docling_converter_cache: Dict[str, Any] = {}
|
|
def get_docling_converter(det_arch: str = "db_mobilenet_v3_large", reco_arch: str = "crnn_vgg16_bn"):
    """Return a Docling DocumentConverter configured with OnnxTR OCR.

    Converters are memoised in ``docling_converter_cache`` under the key
    ``docling_<det_arch>_<reco_arch>``.

    Args:
        det_arch: Detection architecture passed to ``OnnxtrOcrOptions``.
        reco_arch: Recognition architecture passed to ``OnnxtrOcrOptions``.

    Returns:
        The cached or newly built converter, or ``None`` if construction
        failed (the error and traceback are printed).
    """
    cache_key = f"docling_{det_arch}_{reco_arch}"

    try:
        cached = docling_converter_cache[cache_key]
    except KeyError:
        pass
    else:
        print(f"Using cached Docling converter: {cache_key}")
        return cached

    try:
        print(f"Initializing Docling converter: det={det_arch}, reco={reco_arch}...")

        options = PdfPipelineOptions(
            ocr_options=OnnxtrOcrOptions(det_arch=det_arch, reco_arch=reco_arch)
        )
        # Enable table structure recognition, OCR and external plugins.
        options.do_table_structure = True
        options.do_ocr = True
        options.allow_external_plugins = True

        image_format = ImageFormatOption(pipeline_options=options)
        new_converter = DocumentConverter(format_options={InputFormat.IMAGE: image_format})

        docling_converter_cache[cache_key] = new_converter
        print(f"Docling converter {cache_key} initialized successfully!")
        return new_converter
    except Exception as e:
        print(f"ERROR: Failed to initialize Docling converter: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
def run_docling_pipeline(file_content: bytes) -> Dict[str, Any]:
    """Run the Docling pipeline on raw image bytes.

    The bytes are written to a temporary ``.png`` file because the converter
    is invoked with a file path; the file is always removed afterwards.

    Args:
        file_content: Encoded image bytes.

    Returns:
        On success: ``{"success": True, "markdown_text", "plain_text",
        "tables", "primary_table"}`` where ``primary_table`` is the first
        parsed table or ``None``.
        On failure: ``{"success": False, "error": <message>}``.
    """
    try:
        converter = get_docling_converter()
        if converter is None:
            return {"error": "Docling converter not available", "success": False}

        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(file_content)
            tmp_path = tmp_file.name

        try:
            print("Running Docling pipeline...")
            result = converter.convert(source=tmp_path)
            document = result.document

            markdown_text = document.export_to_markdown()

            # Use the plain-text export when the document offers one,
            # otherwise fall back to the markdown rendering.
            if hasattr(document, 'export_to_text'):
                plain_text = document.export_to_text()
            else:
                plain_text = markdown_text

            docling_tables = []
            if hasattr(document, 'tables') and document.tables:
                for table in document.tables:
                    table_data = _parse_docling_table(table)
                    if table_data:
                        docling_tables.append(table_data)

            print(f"Docling: {len(markdown_text)} chars markdown, {len(docling_tables)} tables")

            return {
                "success": True,
                "markdown_text": markdown_text,
                "plain_text": plain_text,
                "tables": docling_tables,
                "primary_table": docling_tables[0] if docling_tables else None,
            }
        finally:
            # Best-effort cleanup: only filesystem errors are ignored, so
            # KeyboardInterrupt/SystemExit are no longer swallowed here.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    except Exception as e:
        print(f"Docling pipeline error: {e}")
        import traceback
        traceback.print_exc()
        return {"error": str(e), "success": False}
|
|
|
|
def _parse_docling_table(table) -> Optional[Dict]:
    """Parse a Docling table into ``{cells, num_rows, num_columns}`` format.

    The DataFrame export is preferred. If it is unavailable or empty, the
    markdown export is returned instead with an empty ``cells`` grid.

    Args:
        table: Object that may provide ``export_to_dataframe()`` and/or
            ``export_to_markdown()``.

    Returns:
        A dict with ``cells`` (header row first), ``num_rows``,
        ``num_columns`` and ``method`` (plus ``markdown`` for the fallback),
        or ``None`` if nothing could be extracted or parsing failed.
    """

    def cell_text(value) -> str:
        # Missing DataFrame values usually arrive as float NaN rather than
        # None; both become an empty cell instead of the literal text "nan".
        if value is None or (isinstance(value, float) and value != value):
            return ''
        return str(value).strip()

    try:
        if hasattr(table, 'export_to_dataframe'):
            df = table.export_to_dataframe()
            if df is not None and not df.empty:
                header = [str(col) if col is not None else '' for col in df.columns.tolist()]
                cells = [header]
                for _, row in df.iterrows():
                    cells.append([cell_text(val) for val in row.tolist()])

                return {
                    "cells": cells,
                    "num_rows": len(cells),
                    "num_columns": len(header),
                    "method": "docling_tableformer"
                }

        if hasattr(table, 'export_to_markdown'):
            md = table.export_to_markdown()
            if md:
                return {
                    "cells": [],
                    "num_rows": 0,
                    "num_columns": 0,
                    "method": "docling_tableformer",
                    "markdown": md
                }

        return None
    except Exception as e:
        print(f"Docling table parse error: {e}")
        return None
|
|
|
|
| |
def get_ocr_predictor(det_arch: str, reco_arch: str):
    """Return a docTR OCR predictor for the given architecture pair.

    Predictors are memoised in ``ocr_model_cache`` under the key
    ``<det_arch>_<reco_arch>``.

    Returns:
        The cached or freshly loaded predictor, or ``None`` when loading
        raised (the error is printed).
    """
    cache_key = f"{det_arch}_{reco_arch}"

    try:
        cached = ocr_model_cache[cache_key]
    except KeyError:
        pass
    else:
        print(f"Using cached OCR model: {cache_key}")
        return cached

    try:
        print(f"Loading OCR model: det={det_arch}, reco={reco_arch}...")
        # Pages are treated as upright: no straightening or orientation
        # detection is requested from docTR.
        settings = dict(
            det_arch=det_arch,
            reco_arch=reco_arch,
            pretrained=True,
            assume_straight_pages=True,
            straighten_pages=False,
            detect_orientation=False,
            preserve_aspect_ratio=True,
        )
        loaded = ocr_predictor(**settings)
        ocr_model_cache[cache_key] = loaded
        print(f"OCR model {cache_key} loaded successfully!")
        return loaded
    except Exception as e:
        print(f"ERROR: Failed to load OCR model {cache_key}: {e}")
        return None
|
|
| |
def get_ner_pipeline(model_id: str):
    """Return a Hugging Face NER pipeline for ``model_id``.

    Pipelines are memoised in ``ner_model_cache`` keyed by model id.

    Raises:
        ValueError: If ``model_id`` is not a key of ``NER_MODELS``.

    Returns:
        The cached or freshly loaded pipeline, or ``None`` when loading
        raised (the error is printed).
    """
    if model_id not in NER_MODELS:
        raise ValueError(f"Unknown NER model ID: {model_id}")

    try:
        cached = ner_model_cache[model_id]
    except KeyError:
        pass
    else:
        print(f"Using cached NER model: {model_id}")
        return cached

    try:
        print(f"Loading NER model: {model_id}...")
        loaded = hf_pipeline(
            "ner",
            model=AutoModelForTokenClassification.from_pretrained(model_id),
            tokenizer=AutoTokenizer.from_pretrained(model_id),
            aggregation_strategy="simple",
        )
        ner_model_cache[model_id] = loaded
        print(f"NER model {model_id} loaded successfully!")
        return loaded
    except Exception as e:
        print(f"ERROR: Failed to load NER model {model_id}: {e}")
        return None
|
|
|
|
def _edit_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein distance between ``s1`` and ``s2``.

    Only two rows of the dynamic-programming table are kept; the shorter
    string is used for the row width.
    """
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    if not shorter:
        return len(longer)

    previous = list(range(len(shorter) + 1))
    for row_no, ch_long in enumerate(longer, start=1):
        current = [row_no]
        for col_no, ch_short in enumerate(shorter):
            current.append(
                min(
                    previous[col_no + 1] + 1,                  # insertion
                    current[col_no] + 1,                       # deletion
                    previous[col_no] + (ch_long != ch_short),  # substitution
                )
            )
        previous = current
    return previous[-1]
|
|
|
|
| |
|
|
# Entity label -> set of known lowercase terms. Filled lazily by
# _build_entity_dicts().
_entity_dicts: dict[str, set] = {}


def _build_entity_dicts():
    """Populate ``_entity_dicts`` from DRUG_INTERACTIONS and MEDLINEPLUS_MAP.

    Medication terms are the comma-separated parts of each DRUG_INTERACTIONS
    key; lab terms are each MEDLINEPLUS_MAP key plus its ``aliases``. Terms
    are lowercased and anything shorter than four characters is skipped.
    """
    global _entity_dicts

    medication_terms: set[str] = {
        piece
        for name in DRUG_INTERACTIONS
        for piece in (chunk.strip().lower() for chunk in str(name).split(','))
        if len(piece) >= 4
    }

    lab_terms: set[str] = set()
    for test_name, info in MEDLINEPLUS_MAP.items():
        candidates = [test_name, *info.get('aliases', [])]
        lab_terms.update(term.lower() for term in candidates if len(term) >= 4)

    # Several labels share the same underlying set object.
    _entity_dicts = {
        'MEDICATION': medication_terms,
        'LAB_VALUE': lab_terms,
        'DIAGNOSTIC_PROCEDURE': lab_terms,
        'TREATMENT': medication_terms,
        'CHEM': medication_terms,
        'CHEMICAL': medication_terms,
    }
    print(f"Entity dicts built: {len(medication_terms)} medication terms, {len(lab_terms)} lab terms")
|
|
|
|
def _find_closest(word: str, dictionary: set) -> tuple:
    """Find the dictionary term with the smallest edit distance to ``word``.

    Terms whose length differs from the word by more than 3 characters are
    skipped without computing a distance. Ties are broken by choosing the
    lexicographically smallest term, so the result does not depend on set
    iteration order (which varies between interpreter runs).

    Args:
        word: Word to look up; compared case-insensitively.
        dictionary: Set of lowercase candidate terms.

    Returns:
        ``(best_match, best_dist)``, or ``(None, 999)`` when no term was
        eligible.
    """
    best_match, best_dist = None, 999
    word_lower = word.lower()
    for term in dictionary:
        if abs(len(term) - len(word_lower)) > 3:
            continue
        dist = _edit_distance(word_lower, term)
        if dist == 0:
            # Exact match: nothing can be closer, and it is unique in a set.
            return term, 0
        closer = dist < best_dist
        wins_tie = dist == best_dist and best_match is not None and term < best_match
        if closer or wins_tie:
            best_dist = dist
            best_match = term
    return best_match, best_dist
|
|
|
|
def _match_case(original: str, replacement: str) -> str:
    """Return ``replacement`` re-cased to follow the casing of ``original``.

    * all-uppercase original  -> uppercase replacement
    * leading capital         -> capitalized replacement
    * anything else (including an empty original) -> lowercase replacement
    """
    if not original:
        # Guard: indexing original[0] below would raise on an empty string.
        return replacement.lower()
    if original.isupper():
        return replacement.upper()
    if original[0].isupper():
        return replacement.capitalize()
    return replacement.lower()
|
|
|
|
def correct_with_ner_entities(
    words_with_boxes: list,
    ner_entities: list,
    text: str,
    confidence_threshold: float = 0.75,
) -> dict:
    """Second-pass correction using NER entity labels as context.

    For each NER entity whose label has a term dictionary, every token of
    the entity that was read with low OCR confidence is compared against that
    dictionary and replaced in ``text`` when a term within edit distance 2 is
    found.

    Args:
        words_with_boxes: OCR words as dicts with ``'word'`` and optionally
            ``'confidence'`` (defaults to 1.0).
        ner_entities: NER results as dicts with ``'entity_group'`` and
            ``'word'``.
        text: The text to correct.
        confidence_threshold: Tokens with OCR confidence at or above this
            value are left untouched.

    Returns:
        ``{'corrected_text': str, 'corrections': list[dict]}``.
    """
    if not _entity_dicts:
        _build_entity_dicts()

    non_letters = re.compile(r'[^a-zA-Z]')

    # Lowest OCR confidence seen per word. Each word is indexed both verbatim
    # and with non-letters stripped, because lookups below use the stripped
    # token; otherwise OCR words with attached punctuation ("Asprin,") would
    # never be found and would default to full confidence.
    word_conf: dict[str, float] = {}
    for w in words_with_boxes:
        raw_key = w['word'].lower()
        conf = w.get('confidence', 1.0)
        for key in {raw_key, non_letters.sub('', raw_key)}:
            word_conf[key] = min(word_conf.get(key, 1.0), conf)

    corrections = []
    corrected_text = text

    for entity in ner_entities:
        entity_type = entity.get('entity_group', '')
        entity_word = entity.get('word', '').strip()
        lookup_dict = _entity_dicts.get(entity_type)
        if not lookup_dict or not entity_word:
            continue

        for token in entity_word.split():
            clean_token = non_letters.sub('', token)
            if not clean_token.isalpha() or len(clean_token) < 4:
                continue

            # Only second-guess words the OCR engine itself was unsure about.
            ocr_conf = word_conf.get(clean_token.lower(), 1.0)
            if ocr_conf >= confidence_threshold:
                continue

            best_match, best_dist = _find_closest(clean_token, lookup_dict)
            if best_match is None or best_dist > 2:
                continue
            if best_match.lower() == clean_token.lower():
                continue

            replacement = _match_case(clean_token, best_match)
            # Replace the first whole-word occurrence in the running text.
            match = re.search(r'\b' + re.escape(clean_token) + r'\b',
                              corrected_text, re.IGNORECASE)
            if not match:
                continue

            start, end = match.start(), match.end()
            corrected_text = corrected_text[:start] + replacement + corrected_text[end:]
            corrections.append({
                'original': clean_token,
                'corrected': replacement,
                'confidence': round(1.0 - best_dist / max(len(clean_token), len(best_match)), 4),
                'ocr_confidence': round(ocr_conf, 4),
                'edit_distance': best_dist,
                'source': 'ner',
                'entity_type': entity_type,
            })
            # The corrected spelling is now trusted.
            word_conf[replacement.lower()] = 1.0

    return {'corrected_text': corrected_text, 'corrections': corrections}
|
|
|
|
| |
def deskew_image(image: np.ndarray) -> np.ndarray:
    """Rotate ``image`` to cancel a small skew estimated from Hough lines.

    Line segments are detected on a Canny edge map; the median angle of the
    segments within 45 degrees of horizontal is taken as the skew. The image
    is rotated only when that median exceeds 0.5 degrees in magnitude.

    Returns:
        The rotated image, or the input unchanged when no skew is detected
        or any step raises (a warning is printed in that case).
    """
    try:
        gray = image if len(image.shape) != 3 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150, apertureSize=3)
        segments = cv2.HoughLinesP(edges, 1, np.pi/180, threshold=100, minLineLength=100, maxLineGap=10)

        if segments is None or len(segments) == 0:
            return image

        # Angle in degrees of every detected segment.
        slopes = (
            np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi
            for x1, y1, x2, y2 in (segment[0] for segment in segments)
        )
        angles = [slope for slope in slopes if abs(slope) < 45]
        if not angles:
            return image

        skew = np.median(angles)
        if not abs(skew) > 0.5:
            return image

        height, width = image.shape[:2]
        rotation = cv2.getRotationMatrix2D((width // 2, height // 2), skew, 1.0)
        return cv2.warpAffine(image, rotation, (width, height), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
    except Exception as e:
        print(f"Deskew warning: {e}")
        return image
|
|
def preprocess_for_doctr(file_content: bytes) -> np.ndarray:
    """Decode and clean up an image before it is handed to docTR.

    Steps: decode as BGR, deskew, CLAHE on the L channel in LAB space,
    non-local-means colour denoising, then convert BGR to RGB.

    Raises:
        ValueError: If the bytes cannot be decoded as an image.
    """
    decoded = cv2.imdecode(np.frombuffer(file_content, np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        raise ValueError("Failed to decode image")

    straightened = deskew_image(decoded)

    # Contrast equalisation is applied to the lightness channel only.
    lab = cv2.cvtColor(straightened, cv2.COLOR_BGR2LAB)
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab[:, :, 0] = equalizer.apply(lab[:, :, 0])
    enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

    denoised = cv2.fastNlMeansDenoisingColored(enhanced, None, 6, 6, 7, 21)
    return cv2.cvtColor(denoised, cv2.COLOR_BGR2RGB)
|
|
def basic_cleanup(text: str) -> str:
    """Collapse every run of whitespace to one space and trim both ends."""
    tokens = text.split()
    return " ".join(tokens)
|
|
|
|
| |
|
|
| |
# Holds the single img2table DocTR OCR wrapper under the key 'doctr'.
img2table_ocr_cache = {}


def get_img2table_ocr():
    """Return the shared img2table ``DocTR`` instance, creating it on first use."""
    try:
        return img2table_ocr_cache['doctr']
    except KeyError:
        instance = DocTR()
        img2table_ocr_cache['doctr'] = instance
        return instance
|
|
|
|
def _img2table_cell_text(cell) -> str:
    """Convert one entry of a table ``content`` row to stripped text."""
    if cell is None:
        return ''
    if isinstance(cell, str):
        return cell.strip()
    if hasattr(cell, 'value'):
        return str(cell.value).strip() if cell.value else ''
    if hasattr(cell, 'text'):
        return str(cell.text).strip() if cell.text else ''
    return str(cell).strip()


def _img2table_table_rows(table) -> list:
    """Return the rows of an extracted table as lists of strings.

    The ``df`` attribute is preferred (column labels become the first row);
    otherwise a list-shaped ``content`` attribute is used.
    """
    df = getattr(table, 'df', None)
    if df is not None:
        rows = [[str(col) if col is not None else '' for col in df.columns.tolist()]]
        for _, row in df.iterrows():
            rows.append([str(val).strip() if val is not None else '' for val in row.tolist()])
        return rows

    content = getattr(table, 'content', None)
    if content is not None:
        rows = []
        if isinstance(content, list):
            for row in content:
                if isinstance(row, (list, tuple)):
                    rows.append([_img2table_cell_text(cell) for cell in row])
                elif isinstance(row, dict):
                    rows.append([str(v).strip() if v else '' for v in row.values()])
        return rows

    if hasattr(table, '_content'):
        # Diagnostic only: this shape is not parsed.
        print(f"Table has _content: {type(table._content)}")
    return []


def extract_tables_with_img2table(image_bytes: bytes, img_width: int, img_height: int) -> dict:
    """Detect and extract tables from an image with img2table.

    The image bytes are written to a temporary ``.png`` file for img2table.
    The file is now removed even when extraction raises; previously it was
    only removed on the success path.

    Args:
        image_bytes: Encoded image bytes.
        img_width: Not used by this function.
        img_height: Not used by this function.

    Returns:
        ``{'is_table': False, ...}`` when nothing usable was found or an
        error occurred (then with ``'error'``); otherwise ``'is_table': True``
        with the largest table's ``cells``, ``num_rows``, ``num_columns``
        plus ``tables`` (all tables) and ``total_tables``.
    """
    try:
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(image_bytes)
            tmp_path = tmp_file.name

        try:
            img2table_img = Img2TableImage(src=tmp_path)
            ocr = get_img2table_ocr()
            tables = img2table_img.extract_tables(
                ocr=ocr,
                implicit_rows=True,
                implicit_columns=True,
                borderless_tables=True,
                min_confidence=50
            )
        finally:
            # Best-effort cleanup; only filesystem errors are ignored.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

        if not tables:
            return {'is_table': False, 'tables': []}

        all_tables = []
        for table in tables:
            cells = _img2table_table_rows(table)

            # Keep only tables with more than one row and some non-blank text.
            if cells and len(cells) > 1:
                has_content = any(any(c.strip() for c in row) for row in cells)
                if has_content:
                    num_cols = max(len(row) for row in cells)
                    all_tables.append({
                        'cells': cells,
                        'num_rows': len(cells),
                        'num_columns': num_cols
                    })
                    print(f"Extracted table with {len(cells)} rows and {num_cols} columns")

        if not all_tables:
            print("No valid tables extracted")
            return {'is_table': False, 'tables': []}

        # The primary table is the one with the largest rows x columns area.
        primary_table = max(all_tables, key=lambda t: t['num_rows'] * t['num_columns'])
        print(f"Primary table: {primary_table['num_rows']}x{primary_table['num_columns']}")

        return {
            'is_table': True,
            'cells': primary_table['cells'],
            'num_rows': primary_table['num_rows'],
            'num_columns': primary_table['num_columns'],
            'tables': all_tables,
            'total_tables': len(all_tables)
        }

    except Exception as e:
        print(f"img2table extraction error: {e}")
        import traceback
        traceback.print_exc()
        return {'is_table': False, 'error': str(e)}
|
|
|
|
def format_table_as_markdown(table_data: dict) -> str:
    """Render ``table_data['cells']`` as an aligned markdown table.

    Short rows are padded with empty cells, every column is at least three
    characters wide, and a separator line follows the first row. Returns an
    empty string when ``is_table`` is falsy or there are no cells/columns.
    """
    cells = table_data.get('cells')
    if not table_data.get('is_table') or not cells:
        return ''

    num_cols = max(len(row) for row in cells)
    if num_cols == 0:
        return ''

    # Stringify once and pad every row to the full column count.
    padded = [
        [str(cell) for cell in list(row) + [''] * (num_cols - len(row))]
        for row in cells
    ]
    widths = [
        max(3, max(len(row[col]) for row in padded))
        for col in range(num_cols)
    ]

    rendered = [
        '| ' + ' | '.join(text.ljust(width) for text, width in zip(row, widths)) + ' |'
        for row in padded
    ]
    divider = '|' + '|'.join('-' * (width + 2) for width in widths) + '|'
    rendered.insert(1, divider)

    return '\n'.join(rendered)
|
|
|
|
def extract_text_with_table_detection(image_bytes: bytes, img_width: int, img_height: int) -> tuple:
    """Run img2table extraction and render the result as markdown.

    Returns:
        ``(markdown_text, table_data)``. When no table is found this is
        ``('', {'is_table': False})``.
    """
    table_data = extract_tables_with_img2table(image_bytes, img_width, img_height)

    if not table_data.get('is_table'):
        return '', {'is_table': False}

    return format_table_as_markdown(table_data), table_data
|
|
|
|
| |
|
|
def _group_cell_boxes_into_rows(items, tolerance: int = 10) -> list:
    """Group cell bounding boxes into rows by their top coordinate.

    A box joins the first existing row whose top coordinate is within
    ``tolerance`` of its own; otherwise it starts a new row. Items without a
    ``bbox`` attribute are ignored.

    Returns:
        Rows ordered top to bottom, each a list of bboxes ordered left to
        right.
    """
    rows = {}
    for item in items:
        if not hasattr(item, 'bbox'):
            continue
        box = item.bbox
        top = box[1]
        anchor = next((y for y in rows if abs(y - top) < tolerance), None)
        if anchor is None:
            anchor = top
            rows[anchor] = []
        rows[anchor].append(box)

    return [
        sorted(boxes, key=lambda b: b[0])
        for _, boxes in sorted(rows.items(), key=lambda entry: entry[0])
    ]


def extract_tables_two_stage(image_bytes: bytes, img_width: int, img_height: int, ocr_predictor) -> dict:
    """Two-stage table extraction.

    1. Detect the table structure with img2table WITHOUT OCR.
    2. Crop each cell from the image and run docTR OCR on it individually,
       which keeps multi-line text together within a cell.

    The temporary image file is now removed even when img2table raises;
    previously it was only removed on the success path.

    Args:
        image_bytes: Encoded image bytes.
        img_width: Not used; dimensions are taken from the decoded image.
        img_height: Not used; dimensions are taken from the decoded image.
        ocr_predictor: OCR predictor forwarded to ``ocr_single_cell``.

    Returns:
        ``{'is_table': False, ...}`` when nothing was extracted or an error
        occurred; otherwise ``'is_table': True`` with the largest table's
        ``cells``, ``num_rows``, ``num_columns`` plus ``tables``,
        ``total_tables`` and ``'method': 'two_stage'``.
    """
    try:
        img_array = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
        if img_array is None:
            return {'is_table': False, 'error': 'Failed to decode image'}

        actual_height, actual_width = img_array.shape[:2]

        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(image_bytes)
            tmp_path = tmp_file.name

        try:
            img2table_img = Img2TableImage(src=tmp_path)
            # ocr=None: only the grid structure is wanted at this stage.
            tables = img2table_img.extract_tables(
                ocr=None,
                implicit_rows=True,
                implicit_columns=True,
                borderless_tables=True
            )
        finally:
            # Best-effort cleanup; only filesystem errors are ignored.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

        if not tables:
            print("Two-stage: No tables detected")
            return {'is_table': False, 'tables': []}

        padding = 2  # extra pixels around each cell crop
        all_tables = []

        for table_idx, table in enumerate(tables):
            print(f"Two-stage: Processing table {table_idx + 1}")

            if not hasattr(table, 'bbox') or not hasattr(table, 'content'):
                continue

            # NOTE(review): `_items` is a private attribute; confirm the
            # objects returned by extract_tables() actually expose it,
            # otherwise every table is skipped here.
            items = getattr(table, '_items', None)
            if not items:
                continue

            table_cells = []
            for row_boxes in _group_cell_boxes_into_rows(items):
                row_texts = []
                for bbox in row_boxes:
                    # Pad the crop slightly and clamp it to the image bounds.
                    x1 = max(0, int(bbox[0]) - padding)
                    y1 = max(0, int(bbox[1]) - padding)
                    x2 = min(actual_width, int(bbox[2]) + padding)
                    y2 = min(actual_height, int(bbox[3]) + padding)

                    cell_img = img_array[y1:y2, x1:x2]
                    if cell_img.size == 0:
                        row_texts.append('')
                        continue

                    cell_img_rgb = cv2.cvtColor(cell_img, cv2.COLOR_BGR2RGB)
                    row_texts.append(ocr_single_cell(cell_img_rgb, ocr_predictor))

                if row_texts:
                    table_cells.append(row_texts)

            if not table_cells:
                continue

            # Pad ragged rows so every row has the same number of columns.
            num_cols = max(len(row) for row in table_cells)
            normalized_cells = [row + [''] * (num_cols - len(row)) for row in table_cells]

            all_tables.append({
                'cells': normalized_cells,
                'num_rows': len(normalized_cells),
                'num_columns': num_cols,
                'method': 'two_stage'
            })
            print(f"Two-stage: Extracted {len(normalized_cells)}x{num_cols} table")

        if not all_tables:
            return {'is_table': False, 'tables': []}

        # The primary table is the one with the largest rows x columns area.
        primary_table = max(all_tables, key=lambda t: t['num_rows'] * t['num_columns'])

        return {
            'is_table': True,
            'cells': primary_table['cells'],
            'num_rows': primary_table['num_rows'],
            'num_columns': primary_table['num_columns'],
            'tables': all_tables,
            'total_tables': len(all_tables),
            'method': 'two_stage'
        }

    except Exception as e:
        print(f"Two-stage extraction error: {e}")
        import traceback
        traceback.print_exc()
        return {'is_table': False, 'error': str(e), 'method': 'two_stage'}
|
|
|
|
def ocr_single_cell(cell_image: np.ndarray, ocr_predictor) -> str:
    """OCR one cropped cell image with docTR and return its text.

    The array is encoded to PNG bytes, loaded through ``DocumentFile`` and
    passed to ``ocr_predictor``. Non-blank lines are joined with single
    spaces. Returns ``''`` for an empty image or when OCR raises.
    """
    try:
        if cell_image.size == 0:
            return ''

        buffer = io.BytesIO()
        Image.fromarray(cell_image).save(buffer, format='PNG')

        document = DocumentFile.from_images([buffer.getvalue()])
        prediction = ocr_predictor(document)

        line_texts = (
            ' '.join(word.value for word in line.words).strip()
            for page in prediction.pages
            for block in page.blocks
            for line in block.lines
        )
        return ' '.join(text for text in line_texts if text)

    except Exception as e:
        print(f"Cell OCR error: {e}")
        return ''
|
|
|
|
def extract_text_two_stage(image_bytes: bytes, img_width: int, img_height: int, ocr_predictor) -> tuple:
    """Run two-stage table extraction and render the result as markdown.

    Returns:
        ``(markdown_text, table_data)``. When no table is found this is
        ``('', {'is_table': False, 'method': 'two_stage'})``.
    """
    table_data = extract_tables_two_stage(image_bytes, img_width, img_height, ocr_predictor)

    if not table_data.get('is_table'):
        return '', {'is_table': False, 'method': 'two_stage'}

    return format_table_as_markdown(table_data), table_data
|
|
|
|
| |
|
|
def extract_tables_borderless(doctr_result, min_columns: int = 2, min_rows: int = 2) -> dict:
    """Infer a table without grid lines from docTR word positions.

    Steps:
        1. Collect every word with its bounding box.
        2. Derive column boundaries (``detect_column_boundaries``).
        3. Cluster words into rows (``detect_row_boundaries``).
        4. Assign words to cells (``build_table_cells``) and reject the
           result when fewer than 30% of the cells contain text.

    Returns:
        ``{'is_table': True, 'cells', 'num_rows', 'num_columns', 'method',
        'fill_ratio'}`` on success, otherwise ``{'is_table': False, ...}``
        with a ``'reason'`` or an ``'error'``.
    """

    def as_record(word) -> dict:
        (left, top), (right, bottom) = word.geometry[0], word.geometry[1]
        return {
            'text': word.value,
            'x_min': left,
            'x_max': right,
            'y_min': top,
            'y_max': bottom,
            'x_center': (left + right) / 2,
            'y_center': (top + bottom) / 2,
            'height': bottom - top
        }

    try:
        all_words = [
            as_record(word)
            for page in doctr_result.pages
            for block in page.blocks
            for line in block.lines
            for word in line.words
        ]

        if len(all_words) < 4:
            return {'is_table': False, 'reason': 'Too few words'}

        print(f"Borderless: Analyzing {len(all_words)} words")

        columns = detect_column_boundaries(all_words)
        column_count = len(columns)
        if column_count < min_columns:
            print(f"Borderless: Only {column_count} columns detected, need {min_columns}")
            return {'is_table': False, 'reason': f'Only {column_count} columns found'}
        print(f"Borderless: Detected {column_count} columns")

        rows = detect_row_boundaries(all_words)
        row_count = len(rows)
        if row_count < min_rows:
            print(f"Borderless: Only {row_count} rows detected, need {min_rows}")
            return {'is_table': False, 'reason': f'Only {row_count} rows found'}
        print(f"Borderless: Detected {row_count} rows")

        cells = build_table_cells(all_words, columns, rows)

        # Share of cells that contain any non-blank text.
        filled = sum(1 for row in cells for cell in row if cell.strip())
        capacity = len(cells) * column_count
        fill_ratio = filled / capacity if capacity > 0 else 0

        if fill_ratio < 0.3:
            print(f"Borderless: Low fill ratio {fill_ratio:.2f}, probably not a table")
            return {'is_table': False, 'reason': f'Low fill ratio: {fill_ratio:.2f}'}

        print(f"Borderless: Built {len(cells)}x{column_count} table with {fill_ratio:.2f} fill ratio")

        return {
            'is_table': True,
            'cells': cells,
            'num_rows': len(cells),
            'num_columns': column_count,
            'method': 'borderless',
            'fill_ratio': fill_ratio
        }

    except Exception as e:
        print(f"Borderless extraction error: {e}")
        import traceback
        traceback.print_exc()
        return {'is_table': False, 'error': str(e), 'method': 'borderless'}
|
|
|
|
def detect_column_boundaries(words: list, min_gap: float = 0.03) -> list:
    """Derive column spans from gaps between the words' left edges.

    A gap between two consecutive distinct ``x_min`` values becomes a
    column separator when it is at least ``min_gap`` wide and at least two
    rows have words on both sides of its midpoint. When no gap qualifies,
    ``cluster_columns_by_alignment`` is used instead.

    Returns:
        A list of ``(x_start, x_end)`` tuples; ``[]`` for no words and
        ``[(0, 1)]`` when fewer than two distinct left edges exist.
    """
    if not words:
        return []

    left_edges = sorted({w['x_min'] for w in words})
    if len(left_edges) < 2:
        return [(0, 1)]

    separators = []
    for left, right in zip(left_edges, left_edges[1:]):
        width = right - left
        if width < min_gap:
            continue
        midpoint = (left + right) / 2
        # Only keep gaps that recur on at least two rows.
        if count_rows_with_gap(words, midpoint, width * 0.5) >= 2:
            separators.append(midpoint)

    if not separators:
        return cluster_columns_by_alignment(words, min_gap)

    separators = sorted(set(separators))

    # Consecutive edges 0, sep_1, ..., sep_k, 1.0 delimit the columns.
    edges = [0, *separators, 1.0]
    return list(zip(edges, edges[1:]))
|
|
|
|
def count_rows_with_gap(words: list, gap_x: float, tolerance: float) -> int:
    """Count the rows that have words on both sides of ``gap_x``.

    Words are bucketed into rows by rounding ``y_center`` to the nearest
    0.05. A row counts when some word ends before ``gap_x - tolerance`` and
    some word starts after ``gap_x + tolerance``.
    """
    bands = {}
    for entry in words:
        band = round(entry['y_center'] * 20) / 20
        bands.setdefault(band, []).append(entry)

    left_limit = gap_x - tolerance
    right_limit = gap_x + tolerance

    return sum(
        1
        for members in bands.values()
        if any(m['x_max'] < left_limit for m in members)
        and any(m['x_min'] > right_limit for m in members)
    )
|
|
|
|
def cluster_columns_by_alignment(words: list, min_gap: float) -> list:
    """Cluster columns by grouping words whose left edges line up.

    Used when gap detection does not find clear separators. Sorted ``x_min``
    values are chained into one cluster while consecutive values differ by
    at most ``min_gap``.

    Args:
        words: Word dicts with an ``'x_min'`` entry.
        min_gap: Largest distance between neighbouring left edges that still
            belongs to the same cluster.

    Returns:
        A list of ``(x_start, x_end)`` tuples clamped to ``[0, 1]``;
        ``[(0, 1)]`` when fewer than two clusters are found, and ``[]`` for
        empty input (matching ``detect_column_boundaries``; previously this
        raised IndexError).
    """
    if not words:
        return []

    x_mins = sorted(w['x_min'] for w in words)

    clusters = [[x_mins[0]]]
    for previous, current in zip(x_mins, x_mins[1:]):
        if current - previous <= min_gap:
            clusters[-1].append(current)
        else:
            clusters.append([current])

    if len(clusters) < 2:
        return [(0, 1)]

    columns = []
    last_index = len(clusters) - 1
    for i, cluster in enumerate(clusters):
        # Start slightly left of the cluster so its own words fall inside.
        x_start = min(cluster) - 0.01
        if i < last_index:
            # Split the space halfway to the next cluster.
            x_end = (max(cluster) + min(clusters[i + 1])) / 2
        else:
            x_end = 1.0
        columns.append((max(0, x_start), min(1, x_end)))

    return columns
|
|
|
|
def detect_row_boundaries(words: list, y_tolerance: float = 0.02) -> list:
    """Cluster words into rows by vertical position.

    Words are visited in order of ``y_min``. A word joins the current row
    when it overlaps the row's most recently added word by more than 30% of
    the smaller height, or when their vertical centres are closer than
    ``y_tolerance``; otherwise it starts a new row.

    Returns:
        A list of ``(y_start, y_end, row_words)`` tuples, top to bottom.
    """
    if not words:
        return []

    ordered = sorted(words, key=lambda w: w['y_min'])

    groups = [[ordered[0]]]
    for candidate in ordered[1:]:
        anchor = groups[-1][-1]

        overlap = min(candidate['y_max'], anchor['y_max']) - max(candidate['y_min'], anchor['y_min'])
        shortest = min(candidate['height'], anchor['height'])
        centres_close = abs(candidate['y_center'] - anchor['y_center']) < y_tolerance

        if overlap > shortest * 0.3 or centres_close:
            groups[-1].append(candidate)
        else:
            groups.append([candidate])

    return [
        (min(w['y_min'] for w in group), max(w['y_max'] for w in group), group)
        for group in groups
    ]
|
|
|
|
def build_table_cells(words: list, columns: list, rows: list) -> list:
    """
    Turn detected rows and columns into a grid of cell strings.

    For every (y_min, y_max, row_words) entry in rows, the row's words are
    read left to right and each is placed in the first column whose
    [start, end) interval contains the word's 'x_min'. Words landing in the
    same cell are joined with a space; words that fall in no column are
    dropped. The `words` argument is not consulted.
    """
    column_count = len(columns)
    grid = []

    for _top, _bottom, members in rows:
        cells = [''] * column_count

        for entry in sorted(members, key=lambda m: m['x_min']):
            left_edge = entry['x_min']
            target = next(
                (idx for idx, (lo, hi) in enumerate(columns) if lo <= left_edge < hi),
                None,
            )
            if target is None:
                continue
            if cells[target]:
                cells[target] = cells[target] + ' ' + entry['text']
            else:
                cells[target] = entry['text']

        grid.append(cells)

    return grid
|
|
|
|
def extract_text_borderless(doctr_result) -> tuple:
    """
    Run borderless table extraction and render the outcome.

    Returns (markdown_text, table_data) when a table was found, otherwise
    an empty string and a minimal "no table" marker dict.
    """
    detection = extract_tables_borderless(doctr_result)

    if not detection.get('is_table'):
        return '', {'is_table': False, 'method': 'borderless'}

    return format_table_as_markdown(detection), detection
|
|
|
|
| |
|
|
def extract_tables_block_geometry(doctr_result, min_columns: int = 2, min_rows: int = 2) -> dict:
    """
    Detect a table from the block-level layout in docTR's .export() output.

    Only the first exported page is inspected. Text blocks are grouped into
    rows when they overlap vertically; rows holding at least min_columns
    blocks are candidates, and the rows sharing the most common block count
    become the table. Any exception is caught and reported in the returned
    dict rather than raised.

    Returns a dict with 'is_table' and 'method'; on success also 'cells',
    'num_rows', 'num_columns' and 'fill_ratio', otherwise a 'reason' (or
    'error') string.
    """
    def _rejected(reason):
        return {'is_table': False, 'reason': reason, 'method': 'block_geometry'}

    try:
        exported = doctr_result.export()

        if not (exported and 'pages' in exported and exported['pages']):
            return _rejected('No pages in export')

        raw_blocks = exported['pages'][0].get('blocks', [])

        if len(raw_blocks) < 2:
            return _rejected(f'Only {len(raw_blocks)} blocks found')

        print(f"Block-geometry: Analyzing {len(raw_blocks)} blocks")

        # Flatten every block into its text plus bounding-box measurements.
        boxes = []
        for raw in raw_blocks:
            corners = raw.get('geometry', [])
            if len(corners) < 2:
                continue

            left, top = corners[0]
            right, bottom = corners[1]

            line_texts = []
            for ln in raw.get('lines', []):
                tokens = [tok.get('value', '') for tok in ln.get('words', [])]
                if tokens:
                    line_texts.append(' '.join(tokens))

            text = ' '.join(line_texts).strip()
            if not text:
                continue

            boxes.append({
                'text': text,
                'x_min': left,
                'x_max': right,
                'y_min': top,
                'y_max': bottom,
                'y_center': (top + bottom) / 2,
                'x_center': (left + right) / 2,
                'height': bottom - top,
            })

        if len(boxes) < min_columns:
            return _rejected(f'Only {len(boxes)} text blocks')

        # Top-to-bottom sweep: a block joins the open row when it overlaps
        # the row's latest block by more than 30% of the shorter height.
        boxes.sort(key=lambda b: b['y_min'])

        bands = []
        band = [boxes[0]]
        for box in boxes[1:]:
            anchor = band[-1]
            shared = min(box['y_max'], anchor['y_max']) - max(box['y_min'], anchor['y_min'])
            shorter = min(box['height'], anchor['height'])

            if shorter > 0 and shared / shorter > 0.3:
                band.append(box)
            else:
                bands.append(band)
                band = [box]
        bands.append(band)

        print(f"Block-geometry: Found {len(bands)} potential rows")

        candidates = [members for members in bands if len(members) >= min_columns]

        if len(candidates) < min_rows:
            print(f"Block-geometry: Only {len(candidates)} multi-block rows, need {min_rows}")
            return _rejected(f'Only {len(candidates)} multi-block rows')

        # Keep only rows whose block count equals the most frequent count.
        widths = [len(members) for members in candidates]
        dominant_width = max(set(widths), key=widths.count)
        uniform = [members for members in candidates if len(members) == dominant_width]

        if len(uniform) < min_rows:
            print(f"Block-geometry: Only {len(uniform)} rows with {dominant_width} columns")
            return _rejected('Inconsistent column counts')

        print(f"Block-geometry: {len(uniform)} rows with {dominant_width} columns")

        # Left-to-right text per row, padded to a common width.
        text_rows = [
            [b['text'] for b in sorted(members, key=lambda b: b['x_min'])]
            for members in uniform
        ]
        width = max(len(r) for r in text_rows) if text_rows else 0
        padded = [r + [''] * (width - len(r)) for r in text_rows]

        filled = sum(1 for r in padded for cell in r if cell.strip())
        capacity = len(padded) * width
        fill_ratio = filled / capacity if capacity > 0 else 0

        if fill_ratio < 0.3:
            print(f"Block-geometry: Low fill ratio {fill_ratio:.2f}")
            return _rejected(f'Low fill ratio: {fill_ratio:.2f}')

        print(f"Block-geometry: Built {len(padded)}x{width} table with {fill_ratio:.2f} fill ratio")

        return {
            'is_table': True,
            'cells': padded,
            'num_rows': len(padded),
            'num_columns': width,
            'method': 'block_geometry',
            'fill_ratio': fill_ratio,
        }

    except Exception as e:
        print(f"Block-geometry extraction error: {e}")
        import traceback
        traceback.print_exc()
        return {'is_table': False, 'error': str(e), 'method': 'block_geometry'}
|
|
|
|
def extract_text_block_geometry(doctr_result) -> tuple:
    """
    Run block-geometry table extraction and render the outcome.

    Returns (markdown_text, table_data) when a table was found, otherwise
    an empty string and a minimal "no table" marker dict.
    """
    detection = extract_tables_block_geometry(doctr_result)

    if not detection.get('is_table'):
        return '', {'is_table': False, 'method': 'block_geometry'}

    return format_table_as_markdown(detection), detection
|
|
|
|
def extract_text_structured(result) -> str:
    """
    Rebuild reading-order text from a docTR result.

    Words inside each OCR line are ordered by the x of their top-left
    corner. Lines are then ordered by a coarse vertical band (mean word y
    rounded to the nearest 0.05) and by their left-most x; lines sharing a
    band are merged into one output line separated by spaces.
    """
    def _band(y_value):
        return round(y_value * 20) / 20

    collected = []
    for page in result.pages:
        for block in page.blocks:
            for line in block.lines:
                tokens = [
                    {'text': w.value, 'x': w.geometry[0][0], 'y': w.geometry[0][1]}
                    for w in line.words
                ]
                if not tokens:
                    continue

                tokens.sort(key=lambda t: t['x'])

                joined = " ".join(t['text'] for t in tokens).strip()
                mean_y = sum(t['y'] for t in tokens) / len(tokens)
                left_most = min(t['x'] for t in tokens)

                if joined:
                    collected.append({'text': joined, 'y': mean_y, 'x': left_most})

    collected.sort(key=lambda entry: (_band(entry['y']), entry['x']))

    # Merge consecutive entries that share a band; -1 marks "no band yet".
    output = []
    pending = []
    last_band = -1
    for entry in collected:
        band = _band(entry['y'])
        if last_band != -1 and band != last_band:
            if pending:
                output.append(" ".join(pending))
            pending = [entry['text']]
        else:
            pending.append(entry['text'])
        last_band = band

    if pending:
        output.append(" ".join(pending))

    return "\n".join(output)
|
|
def generate_synthesized_image(doctr_result) -> Optional[str]:
    """
    Render the first page produced by docTR's synthesize() as a PNG.

    Returns the PNG bytes as a base64 string. Returns None when synthesize()
    yields no pages or when any step raises; the failure is printed, not
    propagated.
    """
    try:
        pages = doctr_result.synthesize()

        nothing_rendered = not pages or len(pages) == 0
        if nothing_rendered:
            print("Synthesize: No pages returned")
            return None

        # Encode the first synthesized page as PNG in memory.
        buffer = io.BytesIO()
        Image.fromarray(pages[0]).save(buffer, format='PNG')

        encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
        print(f"Synthesize: Generated image ({len(encoded)} chars base64)")
        return encoded

    except Exception as e:
        print(f"Synthesize error: {e}")
        return None
|
|
|
|
def extract_words_with_boxes(result) -> list:
    """
    Flatten a docTR result into one record per recognised word.

    Each record is {'word', 'confidence', 'bbox'} where bbox is
    [[x0, y0], [x1, y1]] copied from the word's geometry. Records appear in
    page / block / line / word order.
    """
    def _corners(geometry):
        # Copy both corners into plain nested lists.
        return [
            [geometry[0][0], geometry[0][1]],
            [geometry[1][0], geometry[1][1]],
        ]

    return [
        {
            'word': token.value,
            'confidence': token.confidence,
            'bbox': _corners(token.geometry),
        }
        for page in result.pages
        for block in page.blocks
        for line in block.lines
        for token in line.words
    ]
|
|
def check_drug_interactions(detected_drugs: List[str], interaction_table: Optional[Dict] = None) -> List[Dict]:
    """
    Check for known interactions between detected drugs.

    Every pair of detected drugs is looked up in both directions
    (table[a][b] first, then table[b][a]); at most one warning is reported
    per pair. Names are compared lower-cased and stripped, so the table is
    expected to use lower-case keys.

    Args:
        detected_drugs: Drug names as detected in the document.
        interaction_table: Nested mapping drug -> drug -> interaction dict.
            Defaults to the module-level DRUG_INTERACTIONS.

    Returns:
        List of dicts with keys 'drug1', 'drug2', 'severity', 'description'
        and 'recommendation'. 'drug1' is the drug whose table entry held
        the interaction; both names keep their original spelling.
    """
    table = DRUG_INTERACTIONS if interaction_table is None else interaction_table
    normalized = [d.lower().strip() for d in detected_drugs]
    interactions = []

    for i, first in enumerate(normalized):
        for j in range(i + 1, len(normalized)):
            second = normalized[j]

            # Try the forward direction, then the reverse one. Indices are
            # carried along so duplicates map to their own original name.
            for owner, other, owner_idx, other_idx in (
                (first, second, i, j),
                (second, first, j, i),
            ):
                if owner in table and other in table[owner]:
                    interaction = table[owner][other]
                    interactions.append({
                        'drug1': detected_drugs[owner_idx],
                        'drug2': detected_drugs[other_idx],
                        'severity': interaction.get('severity', 'info'),
                        'description': interaction.get('description', ''),
                        'recommendation': interaction.get('recommendation'),
                    })
                    break

    return interactions
|
|
| |
|
|
def parse_reference_range(range_str: str):
    """
    Parse reference range strings from lab documents.

    Accepted forms (surrounding parentheses are optional):
        "(13.5 - 18.0)"        -> (13.5, 18.0)
        "(< 200)", "(<= 200)"  -> (None, 200.0)
        "(> 60)", "(>= 60)"    -> (60.0, None)
    The Unicode signs for less/greater-or-equal are accepted as well, and
    the range separator may be a hyphen, en dash, em dash or minus sign.
    Inclusive and exclusive bounds are not distinguished.

    Returns:
        (low, high) as floats, where either can be None; (None, None) when
        the string is empty or not recognised.
    """
    if not range_str:
        return None, None

    # Drop surrounding parentheses and whitespace.
    s = range_str.strip().strip('()').strip()

    # Upper bound only: "< X", "<= X" or the Unicode less-or-equal sign.
    m = re.match(r'^(?:<=?|\u2264)\s*(\d+\.?\d*)$', s)
    if m:
        return None, float(m.group(1))

    # Lower bound only: "> X", ">= X" or the Unicode greater-or-equal sign.
    m = re.match(r'^(?:>=?|\u2265)\s*(\d+\.?\d*)$', s)
    if m:
        return float(m.group(1)), None

    # Two-sided range; anchored at the start only, so trailing text such as
    # a unit is tolerated.
    m = re.match(r'(\d+\.?\d*)\s*[-\u2013\u2014\u2212]\s*(\d+\.?\d*)', s)
    if m:
        return float(m.group(1)), float(m.group(2))

    return None, None
|
|
|
|
def extract_lab_values_from_words(words_with_boxes: List[Dict]) -> List[Dict]:
    """
    Extract lab values using word positions from docTR.

    Words are grouped into rows by the vertical centre of their bounding
    boxes and each row is read left to right. Every token is classified as
    part of the reference range (anything between parentheses), a flag
    marker ('*'), a unit, the numeric result, or part of the test name. A
    row is reported only when it yields a test name, a numeric value and a
    reference range that parse_reference_range can parse.

    Args:
        words_with_boxes: Dicts with 'word' (text) and 'bbox'
            ([[x0, y0], [x1, y1]]), the shape extract_words_with_boxes
            returns.

    Returns:
        List of dicts with keys 'test_name', 'value', 'unit', 'ref_low',
        'ref_high', 'ref_range_str' and 'is_flagged_in_document'.
    """
    extracted = []
    if not words_with_boxes:
        return extracted

    # Group words into rows: a word joins the open row while its vertical
    # centre stays within ROW_TOLERANCE of the row's running centre.
    # NOTE(review): the tolerance presumes normalized 0-1 coordinates.
    ROW_TOLERANCE = 0.015
    rows = []
    sorted_words = sorted(words_with_boxes, key=lambda w: (w['bbox'][0][1], w['bbox'][0][0]))

    current_row = []
    current_y = None

    for word_info in sorted_words:
        y_center = (word_info['bbox'][0][1] + word_info['bbox'][1][1]) / 2
        if current_y is None or abs(y_center - current_y) < ROW_TOLERANCE:
            current_row.append(word_info)
            if current_y is None:
                current_y = y_center
            else:
                current_y = (current_y + y_center) / 2
        else:
            if current_row:
                rows.append(sorted(current_row, key=lambda w: w['bbox'][0][0]))
            current_row = [word_info]
            current_y = y_center

    if current_row:
        rows.append(sorted(current_row, key=lambda w: w['bbox'][0][0]))

    UNITS = {'mg/dl', 'mmol/l', 'g/dl', 'u/l', 'miu/l', 'ng/dl', 'pg/ml',
             'ug/dl', 'ng/ml', 'fl', 'pg', '%', 'mm/hr', 'mg/l', 'mg/mmol',
             'ug/l', 'ml/min/1.73m2'}

    # Characters a broken-off piece of "ml/min/1.73m2" may consist of.
    UNIT_FRAGMENT_CHARS = '/.1372m'

    SKIP_WORDS = {'result', 'unit', 'ref.range', 'ref', 'range', 'reference',
                  'date', 'request', 'no', 'no:'}

    for row in rows:
        words_text = [w['word'] for w in row]
        row_str = ' '.join(words_text).lower()

        # Skip table header rows and short section-title rows.
        if 'result' in row_str and ('unit' in row_str or 'ref' in row_str):
            continue
        if 'profile' in row_str and len(words_text) <= 3:
            continue
        if 'function' in row_str and len(words_text) <= 3:
            continue

        name_parts = []
        value = None
        unit = ''
        range_parts = []
        is_flagged = False
        in_range = False

        for w in row:
            word = w['word'].strip()
            word_lower = word.lower().strip('()')

            if not word:
                continue

            # Everything from an opening to a closing parenthesis belongs
            # to the reference range.
            if '(' in word or in_range:
                in_range = True
                range_parts.append(word)
                if ')' in word:
                    in_range = False
                continue

            # A lone asterisk flags the row.
            if word == '*':
                is_flagged = True
                continue

            # Exact unit match.
            if word_lower in UNITS:
                unit = word
                continue

            # Parse the token as a number up front so the fragment check
            # below cannot swallow results such as "1.2", "7.3" or "12".
            cleaned_word = word.lstrip('*').strip()
            try:
                num = float(cleaned_word)
            except ValueError:
                num = None

            # Non-numeric leftovers of a split unit (e.g. "1.73m2", "/",
            # "m") are discarded.
            if num is None and not word_lower.strip(UNIT_FRAGMENT_CHARS):
                continue

            # Scientific-notation count units such as "x10^9/L".
            if 'x10' in word_lower or '10⁹' in word or '10¹²' in word:
                unit = word
                continue

            # Numeric result: the first number in the row wins.
            if num is not None:
                if value is None:
                    value = num
                    if '*' in word:
                        is_flagged = True
                continue

            # Column-header vocabulary is never part of a test name.
            if word_lower in SKIP_WORDS:
                continue

            # Tokens made only of CJK ideographs/parentheses are skipped.
            if all('\u4e00' <= c <= '\u9fff' or c in '()()' for c in word):
                continue

            # Whatever remains and contains a letter is part of the name.
            if any(c.isalpha() for c in word):
                name_parts.append(word)

        range_str = ' '.join(range_parts).strip('() ')
        ref_low, ref_high = parse_reference_range(range_str)

        test_name = ' '.join(name_parts).strip()

        if test_name and value is not None and (ref_low is not None or ref_high is not None):
            # All-caps names of three or more words are skipped
            # (presumably section headings rather than tests).
            if test_name.upper() == test_name and len(test_name.split()) > 2:
                continue

            extracted.append({
                'test_name': test_name,
                'value': value,
                'unit': unit,
                'ref_low': ref_low,
                'ref_high': ref_high,
                'ref_range_str': range_str,
                'is_flagged_in_document': is_flagged,
            })

    return extracted
|
|
|
|
def extract_lab_values_from_text(structured_text: str) -> List[Dict]:
    """
    Pull lab results out of OCR text, one candidate per line.

    A line is considered only if it ends in a parenthesised reference range
    that parse_reference_range understands. The value is then located by
    three patterns tried in order (number followed by a known unit, number
    at the end of the line, number after whitespace/CJK text). The test
    name is the text before the first CJK character, or before the value
    when the line has none.
    """
    found = []
    if not structured_text:
        return found

    range_pattern = r'\(([<>\u2264\u2265]?\s*\d+\.?\d*(?:\s*[-\u2013]\s*\d+\.?\d*)?)\)\s*$'
    value_with_unit = r'\*?\s*(\d+\.?\d*)\s+(mg/dL|mmol/L|g/dL|U/L|mIU/L|ng/dL|pg/mL|ug/dL|ng/mL|fL|pg|%|mm/hr|mg/L|mg/mmol|x10\^?\d+/L|mL/min/1\.73m2|ug/L)'
    value_at_end = r'\*?\s*(\d+\.?\d+)\s*$'
    value_after_gap = r'(?:[\u4e00-\u9fff\s]+|\s+)\*?\s*(\d+\.?\d*)\s'

    for raw_line in structured_text.split('\n'):
        text = raw_line.strip()
        if len(text) < 5:
            continue

        # Trailing "(range)" is mandatory; it is cut off before value search.
        bounds = re.search(range_pattern, text)
        if not bounds:
            continue
        range_text = bounds.group(1)
        low, high = parse_reference_range(range_text)
        text = text[:bounds.start()].strip()

        if low is None and high is None:
            continue

        hit = (
            re.search(value_with_unit, text, re.IGNORECASE)
            or re.search(value_at_end, text)
            or re.search(value_after_gap, text)
        )
        if not hit:
            continue

        try:
            number = float(hit.group(1))
        except (ValueError, IndexError):
            continue

        # Only the first pattern captures a unit as a second group.
        unit = hit.group(2) if hit.lastindex and hit.lastindex >= 2 else ''

        cjk = re.search(r'[\u4e00-\u9fff]', text)
        cut = cjk.start() if cjk else hit.start()
        name = text[:cut].strip().rstrip(':').strip()

        flagged = '*' in text[:hit.end()]

        if name and len(name) >= 2:
            found.append({
                'test_name': name,
                'value': number,
                'unit': unit,
                'ref_low': low,
                'ref_high': high,
                'ref_range_str': range_text or '',
                'is_flagged_in_document': flagged,
            })

    return found
|
|
|
|
def extract_lab_values_from_table(table_data: Dict) -> List[Dict]:
    """
    Read lab results out of a detected table.

    table_data['cells'] is a list of rows, each a list of cell strings. The
    first row is scanned for name / result / unit / range headers. When both
    a name and a result column are found, the remaining rows are read by
    column; otherwise each cell is classified by its content. Only rows
    whose reference range parses are reported.
    """
    results = []
    grid = table_data.get('cells', [])
    if not grid or len(grid) < 2:
        return results

    def _record(name, number, unit_text, low, high, range_text, flagged):
        return {
            'test_name': name,
            'value': number,
            'unit': unit_text,
            'ref_low': low,
            'ref_high': high,
            'ref_range_str': range_text,
            'is_flagged_in_document': flagged,
        }

    # Locate columns from the header row; a later matching header wins.
    name_col = value_col = unit_col = range_col = -1
    for position, header in enumerate(grid[0]):
        label = header.strip().lower()
        if any(kw in label for kw in ['test', 'name', 'parameter', 'investigation']):
            name_col = position
        elif 'result' in label:
            value_col = position
        elif 'unit' in label:
            unit_col = position
        elif any(kw in label for kw in ['ref', 'range', 'normal', 'reference']):
            range_col = position

    body_rows = grid[1:]

    if name_col == -1 or value_col == -1:
        # No usable header: classify each cell by what it looks like.
        known_units = ['mg/dl', 'mmol/l', 'g/dl', 'u/l', 'miu/l', 'ng/dl',
                       'pg/ml', 'ug/dl', 'ng/ml', 'fl', 'pg', '%', 'mm/hr',
                       'mg/l', 'mg/mmol', 'ug/l', 'ml/min/1.73m2']

        for row in body_rows:
            if len(row) < 2:
                continue

            name = None
            number = None
            unit_text = ''
            range_text = ''
            flagged = False

            for cell in row:
                token = cell.strip()
                if not token:
                    continue

                range_hit = re.match(r'^\(([<>\u2264\u2265]?\s*\d+\.?\d*(?:\s*[-\u2013]\s*\d+\.?\d*)?)\)$', token)
                if range_hit:
                    range_text = range_hit.group(1)
                    continue

                # A bare number counts as the value only after a name.
                number_hit = re.match(r'^\*?\s*(\d+\.?\d*)$', token)
                if number_hit and name is not None and number is None:
                    number = float(number_hit.group(1))
                    flagged = '*' in token
                    continue

                if token.lower() in known_units:
                    unit_text = token
                    continue

                if re.match(r'x10\^?\d+/L', token, re.IGNORECASE):
                    unit_text = token
                    continue

                # First cell with letters that is not purely CJK/whitespace.
                if any(c.isalpha() for c in token) and name is None:
                    if not all('\u4e00' <= c <= '\u9fff' or c.isspace() for c in token):
                        name = token

            if name and number is not None and range_text:
                low, high = parse_reference_range(range_text)
                if low is not None or high is not None:
                    results.append(_record(name, number, unit_text, low, high, range_text, flagged))
        return results

    # Header-driven path: read each row by column position.
    for row in body_rows:
        if len(row) <= max(name_col, value_col):
            continue

        name = row[name_col].strip() if name_col < len(row) else ''
        value_text = row[value_col].strip() if value_col < len(row) else ''
        unit_text = row[unit_col].strip() if unit_col >= 0 and unit_col < len(row) else ''
        range_text = row[range_col].strip().strip('()') if range_col >= 0 and range_col < len(row) else ''

        if not name or not value_text:
            continue

        flagged = '*' in value_text
        number_hit = re.search(r'(\d+\.?\d*)', value_text)
        if not number_hit:
            continue

        number = float(number_hit.group(1))
        low, high = parse_reference_range(range_text)

        if low is not None or high is not None:
            results.append(_record(name, number, unit_text, low, high, range_text, flagged))

    return results
|
|
|
|
def classify_lab_value(value: float, ref_low, ref_high, *,
                       critical_low_factor: float = 0.7,
                       critical_high_factor: float = 1.5) -> str:
    """
    Classify a lab value against its reference range.

    A value below ref_low is 'low', or 'critical_low' when it is also below
    ref_low * critical_low_factor. A value above ref_high is 'high', or
    'critical_high' when it is also above ref_high * critical_high_factor.
    For negative bounds the factor is mirrored (2 - factor) so the critical
    threshold still lies further outside the range than the bound itself.

    Args:
        value: Measured value.
        ref_low: Lower reference bound, or None when there is none.
        ref_high: Upper reference bound, or None when there is none.
        critical_low_factor: Fraction of ref_low below which a low value is
            critical (default 0.7).
        critical_high_factor: Multiple of ref_high above which a high value
            is critical (default 1.5).

    Returns:
        One of 'critical_low', 'low', 'normal', 'high', 'critical_high'.
    """
    if ref_low is not None and value < ref_low:
        if ref_low >= 0:
            critical_below = ref_low * critical_low_factor
        else:
            # Scaling a negative bound by a factor < 1 would move the
            # threshold towards zero, i.e. inside the normal range.
            critical_below = ref_low * (2 - critical_low_factor)
        if value < critical_below:
            return 'critical_low'
        return 'low'

    if ref_high is not None and value > ref_high:
        if ref_high >= 0:
            critical_above = ref_high * critical_high_factor
        else:
            # Same mirroring for a negative upper bound.
            critical_above = ref_high * (2 - critical_high_factor)
        if value > critical_above:
            return 'critical_high'
        return 'high'

    return 'normal'
|
|
|
|
def match_test_to_medlineplus(test_name: str) -> Optional[Dict]:
    """
    Fuzzy-match a test name against the MedlinePlus map.

    Matching passes, in order: exact key, exact alias, substring (either
    direction) against keys and aliases, then the closest name by
    difflib.get_close_matches with a 0.7 cutoff.

    Args:
        test_name: Test name as extracted from the document.

    Returns:
        The matching MEDLINEPLUS_MAP entry, or None when the name is empty
        or whitespace, the map is empty, or nothing matches.
    """
    name_lower = test_name.lower().strip() if test_name else ''

    # An empty string is a substring of every key, so without this guard a
    # blank name would "match" the first entry of the map.
    if not name_lower:
        return None

    if not MEDLINEPLUS_MAP:
        return None

    # Pass 1: exact key.
    if name_lower in MEDLINEPLUS_MAP:
        return MEDLINEPLUS_MAP[name_lower]

    # Lower-case every alias list once for the remaining passes.
    entries = [
        (key, data, [a.lower() for a in data.get('aliases', [])])
        for key, data in MEDLINEPLUS_MAP.items()
    ]

    # Pass 2: exact alias.
    for _key, data, aliases in entries:
        if name_lower in aliases:
            return data

    # Pass 3: substring in either direction, key first, then aliases.
    for key, data, aliases in entries:
        if key in name_lower or name_lower in key:
            return data
        if any(alias in name_lower or name_lower in alias for alias in aliases):
            return data

    # Pass 4: closest spelling. Keys are registered before aliases so a key
    # takes precedence over an identical alias of another entry.
    by_name = {key: data for key, data, _aliases in entries}
    for _key, data, aliases in entries:
        for alias in aliases:
            by_name.setdefault(alias, data)

    close = difflib.get_close_matches(name_lower, list(by_name), n=1, cutoff=0.7)
    if close:
        return by_name[close[0]]

    return None
|
|
|
|
def get_medlineplus_info(slug: str, status: str) -> Dict:
    """
    Look up educational text for a lab test on MedlinePlus.

    A cached entry is served first, choosing its 'high' text when status
    contains "high" and its 'low' text otherwise. On a cache miss the test
    page is fetched, the first three text fragments following the heading
    that mentions both "results" and "mean" are joined, and the outcome is
    stored in MEDLINEPLUS_CACHE. Fetch errors are printed and yield an empty
    description.

    Returns a dict with 'url' and 'description'.
    """
    page_url = f"https://medlineplus.gov/lab-tests/{slug}/"

    if slug in MEDLINEPLUS_CACHE:
        entry = MEDLINEPLUS_CACHE[slug]
        side = 'low' if 'high' not in status else 'high'
        return {
            'url': entry.get('url', page_url),
            'description': entry.get(side, ''),
        }

    try:
        reply = httpx.get(page_url, timeout=5.0, follow_redirects=True)
        if reply.status_code == 200:
            parsed = BeautifulSoup(reply.text, 'html.parser')

            # Find the "What do the results mean?" style heading.
            anchor = None
            for heading in parsed.find_all(['h2', 'h3']):
                title = heading.get_text().lower()
                if 'results' in title and 'mean' in title:
                    anchor = heading
                    break

            summary = ''
            if anchor:
                # Collect sibling text up to the next section heading.
                fragments = []
                for node in anchor.find_next_siblings():
                    if node.name in ['h2', 'h3']:
                        break
                    snippet = node.get_text(strip=True)
                    if snippet:
                        fragments.append(snippet)
                summary = ' '.join(fragments[:3])

            # The page is not direction-specific, so both slots share it.
            MEDLINEPLUS_CACHE[slug] = {
                'url': page_url,
                'high': summary,
                'low': summary,
                'fetched_at': 'runtime'
            }

            return {
                'url': page_url,
                'description': summary,
            }
    except Exception as e:
        print(f"MedlinePlus fetch failed for {slug}: {e}")

    return {'url': page_url, 'description': ''}
|
|
|
|
def check_lab_values(structured_text: str, table_data: Optional[Dict], words_with_boxes: Optional[List[Dict]] = None) -> List[Dict]:
    """
    Extract lab values from OCR output and classify them against their
    reference ranges.

    Three extractors feed one list, in priority order; a later extractor
    only contributes tests whose (case-insensitive) name is not present yet:
    1. Word-position-based (when words_with_boxes is given)
    2. Table-based (when table_data reports a table)
    3. Text regex-based (always run)

    Returns one result dict per test with its value, status, reference
    range display string and MedlinePlus category/description/URL.
    """
    def _absorb(pool, newcomers):
        # Append newcomers whose test name is not in the pool yet.
        seen = {entry['test_name'].lower() for entry in pool}
        for candidate in newcomers:
            key = candidate['test_name'].lower()
            if key not in seen:
                pool.append(candidate)
                seen.add(key)

    pool = []
    if words_with_boxes:
        pool = extract_lab_values_from_words(words_with_boxes)
        print(f"Lab extraction (word-position): found {len(pool)} values")

    if table_data and table_data.get('is_table'):
        from_table = extract_lab_values_from_table(table_data)
        print(f"Lab extraction (table): found {len(from_table)} values")
        _absorb(pool, from_table)

    from_text = extract_lab_values_from_text(structured_text)
    print(f"Lab extraction (text-regex): found {len(from_text)} values")
    _absorb(pool, from_text)

    results = []
    for item in pool:
        low = item['ref_low']
        high = item['ref_high']
        status = classify_lab_value(item['value'], low, high)

        # Human-readable form of the reference range.
        if low is not None and high is not None:
            range_display = f"{low} - {high}"
        elif high is not None:
            range_display = f"< {high}"
        elif low is not None:
            range_display = f"> {low}"
        else:
            range_display = item.get('ref_range_str', '')

        description = ''
        medlineplus_url = None
        category = 'General'

        entry = match_test_to_medlineplus(item['test_name'])
        if entry:
            slug = entry.get('slug', '')
            category = entry.get('category', 'General')
            medlineplus_url = f"https://medlineplus.gov/lab-tests/{slug}/"

            # Educational text is looked up only for abnormal results.
            if status != 'normal' and slug:
                info = get_medlineplus_info(slug, status)
                description = info.get('description', '')
                medlineplus_url = info.get('url', medlineplus_url)

        results.append({
            'test_name': item['test_name'],
            'value': item['value'],
            'unit': item['unit'],
            'status': status,
            'ref_low': low,
            'ref_high': high,
            'reference_range': range_display,
            'category': category,
            'description': description,
            'medlineplus_url': medlineplus_url,
            'is_flagged_in_document': item.get('is_flagged_in_document', False),
        })

    return results
|
|
def map_entities_to_boxes(entities: list, words_with_boxes: list, cleaned_text: str) -> list:
    """
    Map NER entities back to word bounding boxes.

    An OCR word matches an entity when, compared case-insensitively, it
    contains one of the entity's whitespace-separated parts or is contained
    in one. The boxes of all matching words are merged into one enclosing
    box. OCR words that are empty after stripping are ignored, because an
    empty string is a substring of everything and would match every entity.

    Args:
        entities: Dicts with 'entity_group', 'score' and 'word'.
        words_with_boxes: Dicts with 'word' and 'bbox' ([[x0, y0], [x1, y1]]).
        cleaned_text: Not used by the matching.

    Returns:
        One dict per entity with 'entity_group', 'score', 'word' and 'bbox'
        (the merged box, or None when no OCR word matched).
    """
    # Normalise the OCR words once and drop blank ones.
    ocr_words = []
    for word_info in words_with_boxes:
        text = word_info['word'].lower().strip()
        if text:
            ocr_words.append((text, word_info['bbox']))

    entities_with_boxes = []

    for entity in entities:
        entity_parts = entity['word'].lower().strip().split()

        matched_boxes = [
            bbox
            for text, bbox in ocr_words
            if any(part in text or text in part for part in entity_parts)
        ]

        if matched_boxes:
            # Smallest box enclosing every matched word.
            min_x = min(box[0][0] for box in matched_boxes)
            min_y = min(box[0][1] for box in matched_boxes)
            max_x = max(box[1][0] for box in matched_boxes)
            max_y = max(box[1][1] for box in matched_boxes)
            combined_bbox = [[min_x, min_y], [max_x, max_y]]
        else:
            combined_bbox = None

        entities_with_boxes.append({
            'entity_group': entity['entity_group'],
            'score': entity['score'],
            'word': entity['word'],
            'bbox': combined_bbox
        })

    return entities_with_boxes
|
|
| |
|
|
@app.get("/")
async def root():
    """Liveness probe: report that the API process is up."""
    payload = {"status": "running", "message": "ScanAssured OCR & NER API"}
    return payload
|
|
@app.get("/models")
async def get_available_models():
    """List the OCR presets, OCR model choices, NER models and the OCR correction model."""
    preset_summaries = []
    for preset_id, preset_data in OCR_PRESETS.items():
        preset_summaries.append({
            "id": preset_id,
            "name": preset_data["name"],
            "description": preset_data["description"],
        })

    # Expose only the descriptive fields of each NER model entry.
    ner_summaries = {}
    for model_id, model_data in NER_MODELS.items():
        ner_summaries[model_id] = {
            "name": model_data["name"],
            "description": model_data["description"],
            "entities": model_data["entities"],
        }

    correction_model = {
        "id": "ner-dictionary",
        "name": "NER-Informed Dictionary Correction",
        "description": "Edit-distance correction against medical entity dictionaries, guided by NER entity labels",
    }

    return {
        "ocr_presets": preset_summaries,
        "ocr_detection_models": OCR_DETECTION_MODELS,
        "ocr_recognition_models": OCR_RECOGNITION_MODELS,
        "ner_models": ner_summaries,
        "ocr_correction_model": correction_model,
    }
|
|
| @app.post("/process") |
| async def process_image( |
| file: UploadFile = File(...), |
| ner_model_id: str = Form(...), |
| ocr_preset: str = Form("balanced"), |
| ocr_det_model: Optional[str] = Form(None), |
| ocr_reco_model: Optional[str] = Form(None), |
| enable_correction: str = Form("false"), |
| correction_threshold: str = Form("0.75"), |
| ): |
| """Process an image with OCR and NER.""" |
|
|
| |
| if ocr_det_model and ocr_reco_model: |
| det_arch = ocr_det_model |
| reco_arch = ocr_reco_model |
| else: |
| preset = OCR_PRESETS.get(ocr_preset, OCR_PRESETS["balanced"]) |
| det_arch = preset["det"] |
| reco_arch = preset["reco"] |
|
|
| |
| if ner_model_id not in NER_MODELS: |
| return JSONResponse( |
| status_code=400, |
| content={"detail": f"Unknown NER model: {ner_model_id}"} |
| ) |
|
|
| |
| ocr_predictor_instance = get_ocr_predictor(det_arch, reco_arch) |
| if not ocr_predictor_instance: |
| return JSONResponse( |
| status_code=503, |
| content={"detail": f"Failed to load OCR model: {det_arch}/{reco_arch}"} |
| ) |
|
|
| |
| ner_pipeline = get_ner_pipeline(ner_model_id) |
| if not ner_pipeline: |
| return JSONResponse( |
| status_code=503, |
| content={"detail": f"Failed to load NER model: {ner_model_id}"} |
| ) |
|
|
| try: |
| |
| file_content = await file.read() |
| preprocessed_img = preprocess_for_doctr(file_content) |
|
|
| |
| print("Running docTR OCR...") |
| |
| pil_img = Image.fromarray(preprocessed_img) |
| img_byte_arr = io.BytesIO() |
| pil_img.save(img_byte_arr, format='PNG') |
| img_bytes = img_byte_arr.getvalue() |
|
|
| doc = DocumentFile.from_images([img_bytes]) |
| result = ocr_predictor_instance(doc) |
|
|
| |
| img_height, img_width = preprocessed_img.shape[:2] |
|
|
| |
| structured_text = extract_text_structured(result) |
| cleaned_text = basic_cleanup(structured_text) |
| words_with_boxes = extract_words_with_boxes(result) |
|
|
| print(f"OCR Structured Text:\n{structured_text[:500]}...") |
| print(f"Extracted {len(words_with_boxes)} words with bounding boxes") |
|
|
| |
| print("Generating synthesized document image...") |
| synthesized_image = generate_synthesized_image(result) |
|
|
| |
| print("Running Docling pipeline for comparison...") |
| docling_result = run_docling_pipeline(file_content) |
|
|
| |
| print("Running img2table for table detection (Method 1: integrated OCR)...") |
| table_formatted_text, table_data = extract_text_with_table_detection( |
| img_bytes, img_width, img_height |
| ) |
|
|
| |
| print("Running two-stage table detection (Method 2: structure + cell OCR)...") |
| two_stage_text, two_stage_data = extract_text_two_stage( |
| img_bytes, img_width, img_height, ocr_predictor_instance |
| ) |
|
|
| |
| print("Running borderless table detection (Method 3: text position analysis)...") |
| borderless_text, borderless_data = extract_text_borderless(result) |
|
|
| |
| print("Running block-geometry table detection (Method 4: docTR block analysis)...") |
| block_geo_text, block_geo_data = extract_text_block_geometry(result) |
|
|
| |
| |
| if two_stage_data.get('is_table'): |
| display_text = two_stage_text |
| primary_table_data = two_stage_data |
| print(f"Using Two-stage: {two_stage_data.get('num_rows', 0)}x{two_stage_data.get('num_columns', 0)} table") |
| elif table_data.get('is_table'): |
| display_text = table_formatted_text |
| primary_table_data = table_data |
| print(f"Using img2table: {table_data.get('num_rows', 0)}x{table_data.get('num_columns', 0)} table") |
| elif borderless_data.get('is_table'): |
| display_text = borderless_text |
| primary_table_data = borderless_data |
| print(f"Using Borderless: {borderless_data.get('num_rows', 0)}x{borderless_data.get('num_columns', 0)} table") |
| elif block_geo_data.get('is_table'): |
| display_text = block_geo_text |
| primary_table_data = block_geo_data |
| print(f"Using Block-geometry: {block_geo_data.get('num_rows', 0)}x{block_geo_data.get('num_columns', 0)} table") |
| else: |
| display_text = structured_text |
| primary_table_data = {'is_table': False} |
| print("No table detected by any method, using regular OCR text") |
|
|
| |
| correction_enabled = enable_correction.lower() == "true" |
| correction_result = {'corrected_text': cleaned_text, 'corrections': []} |
|
|
| |
| ner_input_text = cleaned_text |
|
|
| |
| print("Running NER...") |
| entities = ner_pipeline(ner_input_text) |
|
|
| |
| structured_entities = [] |
| for entity in entities: |
| if entity.get('score', 0.0) > 0.1: |
| structured_entities.append({ |
| 'entity_group': entity['entity_group'], |
| 'score': float(entity['score']), |
| 'word': entity['word'].strip(), |
| }) |
|
|
| |
| entities_with_boxes = map_entities_to_boxes(structured_entities, words_with_boxes, ner_input_text) |
|
|
| |
| if correction_enabled: |
| ner_corr = correct_with_ner_entities( |
| words_with_boxes, structured_entities, |
| correction_result['corrected_text'], confidence_threshold=float(correction_threshold)) |
| if ner_corr['corrections']: |
| correction_result['corrections'].extend(ner_corr['corrections']) |
| correction_result['corrected_text'] = ner_corr['corrected_text'] |
| print(f"NER-informed correction: {len(ner_corr['corrections'])} additional fix(es)") |
|
|
| |
| detected_drugs = [] |
| for entity in structured_entities: |
| if entity['entity_group'] in ['CHEM', 'CHEMICAL', 'TREATMENT', 'MEDICATION']: |
| detected_drugs.append(entity['word']) |
|
|
| interactions = check_drug_interactions(detected_drugs) if detected_drugs else [] |
| print(f"Found {len(interactions)} drug interactions") |
|
|
| |
| lab_anomalies = check_lab_values(structured_text, primary_table_data, words_with_boxes) |
| print(f"Found {len(lab_anomalies)} lab values ({sum(1 for a in lab_anomalies if a['status'] != 'normal')} abnormal)") |
|
|
| return { |
| "structured_text": display_text, |
| "cleaned_text": cleaned_text, |
| "corrected_text": correction_result['corrected_text'] if correction_enabled else None, |
| "corrections": correction_result['corrections'] if correction_enabled else [], |
| "medical_entities": entities_with_boxes, |
| "interactions": interactions, |
| "lab_anomalies": lab_anomalies, |
| "model_id": NER_MODELS[ner_model_id]["name"], |
| "ocr_model": f"{det_arch} + {reco_arch}", |
| "image_width": img_width, |
| "image_height": img_height, |
| "synthesized_image": synthesized_image, |
| |
| "table_detected": primary_table_data.get('is_table', False), |
| "table_data": { |
| "num_columns": primary_table_data.get('num_columns', 0), |
| "num_rows": primary_table_data.get('num_rows', 0), |
| "cells": primary_table_data.get('cells', []), |
| "method": primary_table_data.get('method', 'unknown') |
| } if primary_table_data.get('is_table') else None, |
| |
| "table_comparison": { |
| "method1_img2table": { |
| "name": "img2table (line detection + integrated OCR)", |
| "detected": table_data.get('is_table', False), |
| "num_columns": table_data.get('num_columns', 0), |
| "num_rows": table_data.get('num_rows', 0), |
| "cells": table_data.get('cells', []), |
| "formatted_text": table_formatted_text if table_data.get('is_table') else None |
| }, |
| "method2_two_stage": { |
| "name": "Two-stage (structure detection + cell-by-cell OCR)", |
| "detected": two_stage_data.get('is_table', False), |
| "num_columns": two_stage_data.get('num_columns', 0), |
| "num_rows": two_stage_data.get('num_rows', 0), |
| "cells": two_stage_data.get('cells', []), |
| "formatted_text": two_stage_text if two_stage_data.get('is_table') else None |
| }, |
| "method3_borderless": { |
| "name": "Borderless (text position clustering)", |
| "detected": borderless_data.get('is_table', False), |
| "num_columns": borderless_data.get('num_columns', 0), |
| "num_rows": borderless_data.get('num_rows', 0), |
| "cells": borderless_data.get('cells', []), |
| "formatted_text": borderless_text if borderless_data.get('is_table') else None, |
| "fill_ratio": borderless_data.get('fill_ratio', 0) |
| }, |
| "method4_block_geometry": { |
| "name": "Block-geometry (docTR block grouping)", |
| "detected": block_geo_data.get('is_table', False), |
| "num_columns": block_geo_data.get('num_columns', 0), |
| "num_rows": block_geo_data.get('num_rows', 0), |
| "cells": block_geo_data.get('cells', []), |
| "formatted_text": block_geo_text if block_geo_data.get('is_table') else None, |
| "fill_ratio": block_geo_data.get('fill_ratio', 0) |
| } |
| }, |
| |
| "docling_result": { |
| "available": docling_result.get("success", False), |
| "markdown_text": docling_result.get("markdown_text", ""), |
| "plain_text": docling_result.get("plain_text", ""), |
| "table_detected": bool(docling_result.get("tables")), |
| "table_data": docling_result.get("primary_table"), |
| "error": docling_result.get("error"), |
| } if docling_result else { |
| "available": False, |
| "error": "Docling pipeline did not run", |
| } |
| } |
|
|
| except Exception as e: |
| print(f"Processing error: {e}") |
| import traceback |
| traceback.print_exc() |
| return JSONResponse( |
| status_code=500, |
| content={"detail": f"An error occurred during processing: {str(e)}"} |
| ) |
|
|