Spaces:
Paused
Paused
| """HF Dots.OCR Text Extraction Endpoint | |
| This FastAPI application provides a Hugging Face Space endpoint for Dots.OCR | |
| text extraction with ROI support and standardized field extraction schema. | |
| """ | |
| import logging | |
| import os | |
| import time | |
| import uuid | |
| import json | |
| import re | |
| from typing import List, Optional, Dict, Any | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.responses import JSONResponse | |
| # Import local modules | |
| from .api_models import ( | |
| BoundingBox, | |
| ExtractedField, | |
| ExtractedFields, | |
| MRZData, | |
| OCRDetection, | |
| OCRResponse, | |
| ) | |
| from .enhanced_field_extraction import EnhancedFieldExtractor | |
| from .model_loader import load_model, extract_text, is_model_loaded, get_model_info | |
| from .preprocessing import process_document, validate_file_size, get_document_info | |
| from .response_builder import build_ocr_response, build_error_response | |
# --- Logging configuration --------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Accepted truthy spellings for the DOTS_OCR_DEBUG environment variable.
_TRUTHY = {"1", "true", "yes"}

# When DOTS_OCR_DEBUG is set, raise the root logger to DEBUG so that
# lower-level events from submodules are emitted as well.
_env_debug = os.getenv("DOTS_OCR_DEBUG", "0").lower() in _TRUTHY
if _env_debug:
    logging.getLogger().setLevel(logging.DEBUG)
    logger.info("DOTS_OCR_DEBUG enabled via environment — verbose logging active")

# Global model state; flipped to True by the lifespan hook on successful load.
# (Field extraction itself lives in the shared EnhancedFieldExtractor module.)
model_loaded = False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for model loading.

    FIX: FastAPI's ``lifespan=`` argument expects an async *context manager*
    factory; the ``@asynccontextmanager`` decorator (imported at the top of
    this module but previously unused) was missing, so the app would fail at
    startup with a bare async generator function.

    Sets the module-level ``model_loaded`` flag; never raises, so the app can
    still start in mock mode when the model is unavailable.
    """
    global model_loaded
    # Allow tests and lightweight environments to skip model loading.
    # Set DOTS_OCR_SKIP_MODEL_LOAD=1 to bypass heavy downloads during tests/CI.
    skip_model_load = os.getenv("DOTS_OCR_SKIP_MODEL_LOAD", "0") == "1"
    logger.info("Loading Dots.OCR model...")
    try:
        if skip_model_load:
            # Explicitly skip model loading for fast startup in tests/CI.
            model_loaded = False
            logger.warning(
                "DOTS_OCR_SKIP_MODEL_LOAD=1 set - skipping model load (mock mode)"
            )
        else:
            # Load the model using the model loader module.
            load_model()
            model_loaded = True
            logger.info("Dots.OCR model loaded successfully")
            # Log model information (only meaningful after a successful load).
            model_info = get_model_info()
            logger.info(f"Model info: {model_info}")
    except Exception as e:
        logger.error(f"Failed to load Dots.OCR model: {e}")
        # Don't raise - allow mock mode for development.
        model_loaded = False
        logger.warning("Model loading failed - using mock implementation")
    yield
    logger.info("Shutting down Dots.OCR endpoint...")
# FastAPI application instance. The ``lifespan`` hook above loads the OCR
# model at startup and logs shutdown; route handlers are defined below.
app = FastAPI(
    title="KYB Dots.OCR Text Extraction",
    description="Dots.OCR for identity document text extraction with ROI support",
    version="1.0.0",
    lifespan=lifespan,
)
# FIX: the handler was never registered with the app, so "/" would 404.
# NOTE(review): route path inferred — uptime probes typically hit "/"; confirm.
@app.get("/")
async def root():
    """Root route for uptime checks.

    Returns:
        A minimal liveness payload: ``{"status": "ok"}``.
    """
    return {"status": "ok"}
# FIX: the handler was never registered with the app, so it was unreachable.
# NOTE(review): "/health" path inferred from convention — confirm with clients.
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        A status document: ``"healthy"`` when the OCR model is loaded,
        otherwise ``"degraded"`` (mock mode can still serve requests).
        ``model_info`` is populated only when the model is loaded.
    """
    # Reading the module-level flag needs no ``global`` declaration.
    status = "healthy" if model_loaded else "degraded"
    model_info = get_model_info() if model_loaded else None
    return {
        "status": status,
        "version": "1.0.0",
        "model_loaded": model_loaded,
        "model_info": model_info,
    }
# FIX: the handler was never registered with the app, so it was unreachable.
# NOTE(review): "/extract" path inferred — confirm against API consumers.
@app.post("/extract")
async def extract_text_endpoint(
    file: UploadFile = File(..., description="Image or PDF file to process"),
    roi: Optional[str] = Form(None, description="ROI coordinates as JSON string"),
    debug: Optional[bool] = Form(
        None,
        description=(
            "Enable verbose debug logging for this request. Overrides env when True."
        ),
    ),
):
    """Extract text from identity document image or PDF.

    Args:
        file: Uploaded image or PDF document.
        roi: Optional JSON-encoded bounding box (x1/y1/x2/y2) restricting
            extraction to a region of interest.
        debug: Per-request debug override; when None, falls back to the
            DOTS_OCR_DEBUG environment variable.

    Returns:
        The OCR response assembled by ``build_ocr_response``.

    Raises:
        HTTPException: 503 when the model is unavailable and mock mode is
            disabled; 413 for oversized files; 400 for a malformed ROI or a
            document that cannot be processed; 500 for unexpected failures.
    """
    # Allow mock mode when model isn't loaded to support tests/CI and dev flows.
    allow_mock = os.getenv("DOTS_OCR_ALLOW_MOCK", "1") == "1"
    is_mock_mode = (not model_loaded) and allow_mock
    if not model_loaded and not allow_mock:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # BUGFIX: request_id/start_time were previously assigned *after* the debug
    # log below referenced request_id, raising NameError whenever debug
    # logging was enabled. They must be initialized before first use.
    start_time = time.time()
    request_id = str(uuid.uuid4())
    # Determine effective debug mode for this request.
    env_debug = os.getenv("DOTS_OCR_DEBUG", "0").lower() in {"1", "true", "yes"}
    debug_enabled = bool(debug) if debug is not None else env_debug
    if debug_enabled:
        logger.info(
            f"[debug] Request {request_id}: debug logging enabled (env={env_debug}, form={debug})"
        )
    if is_mock_mode:
        logger.warning(
            "Using mock mode — OCR text will be empty. To enable real inference, ensure the model loads successfully (unset DOTS_OCR_SKIP_MODEL_LOAD and provide resources)."
        )
    try:
        # Read file data.
        file_data = await file.read()
        # Validate file size.
        if not validate_file_size(file_data):
            raise HTTPException(status_code=413, detail="File size exceeds limit")
        # Get document information.
        doc_info = get_document_info(file_data)
        logger.info(f"Processing document: {doc_info}")
        # Parse ROI if provided.
        roi_coords = None
        if roi:
            try:
                roi_data = json.loads(roi)
                roi_bbox = BoundingBox(**roi_data)
                roi_coords = (roi_bbox.x1, roi_bbox.y1, roi_bbox.x2, roi_bbox.y2)
                logger.info(f"Using ROI: {roi_coords}")
            except Exception as e:
                logger.warning(f"Invalid ROI provided: {e}")
                raise HTTPException(status_code=400, detail=f"Invalid ROI format: {e}")
        # Process document (PDF to images or single image).
        try:
            processed_images = process_document(file_data, roi_coords)
            logger.info(f"Processed {len(processed_images)} images from document")
        except Exception as e:
            logger.error(f"Document processing failed: {e}")
            raise HTTPException(
                status_code=400, detail=f"Document processing failed: {e}"
            )
        # Process each image and extract text.
        ocr_texts = []
        page_metadata = []
        for i, image in enumerate(processed_images):
            try:
                # Extract text using the loaded model, or produce mock output
                # in mock mode.
                if is_mock_mode:
                    # In mock mode, we skip model inference and return empty text.
                    ocr_text = ""
                else:
                    ocr_text = extract_text(image)
                logger.info(
                    f"Page {i + 1} - Extracted text length: {len(ocr_text)} characters"
                )
                ocr_texts.append(ocr_text)
                # Collect page metadata.
                page_meta = {
                    "page_index": i,
                    "image_size": image.size,
                    "text_length": len(ocr_text),
                    "processing_successful": True,
                }
                page_metadata.append(page_meta)
            except Exception as e:
                # A failed page does not abort the request; record the error
                # and continue with the remaining pages.
                logger.error(f"Text extraction failed for page {i + 1}: {e}")
                # Add empty text for failed page.
                ocr_texts.append("")
                page_meta = {
                    "page_index": i,
                    "image_size": image.size if hasattr(image, "size") else (0, 0),
                    "text_length": 0,
                    "processing_successful": False,
                    "error": str(e),
                }
                page_metadata.append(page_meta)
        # Determine media type for response.
        media_type = "pdf" if doc_info["is_pdf"] else "image"
        processing_time = time.time() - start_time
        # Build response using the response builder.
        return build_ocr_response(
            request_id=request_id,
            media_type=media_type,
            processing_time=processing_time,
            ocr_texts=ocr_texts,
            page_metadata=page_metadata,
            debug=debug_enabled,
        )
    except HTTPException:
        # Re-raise HTTP exceptions as-is.
        raise
    except Exception as e:
        logger.error(f"OCR extraction failed: {e}")
        processing_time = time.time() - start_time
        error_response = build_error_response(
            request_id=request_id,
            error_message=f"OCR extraction failed: {str(e)}",
            processing_time=processing_time,
        )
        raise HTTPException(status_code=500, detail=error_response.dict())
if __name__ == "__main__":
    # Local/dev entry point; in production the ASGI server is expected to be
    # launched externally against ``app``.
    import uvicorn

    # NOTE(review): 7860 is the conventional Hugging Face Spaces port —
    # confirm against the deployment configuration before changing.
    uvicorn.run(app, host="0.0.0.0", port=7860)