#!/usr/bin/env python3
"""API endpoint test script for Dots.OCR.

This script tests a deployed Dots.OCR API endpoint using real ID card images
stored next to this file. It can be used to validate the complete pipeline
(health check, image upload, response structure and extracted fields) in a
production environment. Run with ``--help`` to see the available options.
"""

import os
import sys
import json
import time
import requests
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List
import argparse

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class DotsOCRAPITester:
    """Test client for the Dots.OCR API endpoint."""

    def __init__(self, base_url: str, timeout: int = 30):
        """Initialize the API tester.

        Args:
            base_url: Base URL of the deployed API (e.g., "http://localhost:7860").
                A trailing slash is stripped so endpoint paths can be appended.
            timeout: Request timeout in seconds, applied to every request.
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        # One shared session, so the common headers below apply to every request.
        self.session = requests.Session()

        # Set common headers
        self.session.headers.update({
            'User-Agent': 'DotsOCR-APITester/1.0'
        })

    def health_check(self) -> Dict[str, Any]:
        """Check API health status.

        Returns:
            The decoded JSON body of ``GET /health`` on success, or
            ``{"error": <message>}`` if the request, the HTTP status check or
            JSON decoding fails.
        """
        try:
            response = self.session.get(
                f"{self.base_url}/health",
                timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Best-effort probe: report the failure to the caller instead of raising.
            logger.error(f"Health check failed: {e}")
            return {"error": str(e)}

    def test_ocr_endpoint(
        self,
        image_path: str,
        roi: Optional[Dict[str, float]] = None,
        expected_fields: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Test the OCR endpoint with an image file.

        Uploads the image as multipart field ``file`` to ``POST /v1/id/ocr``,
        optionally with a JSON-encoded ``roi`` form field. It then validates
        the response structure and the expected extracted fields.

        Args:
            image_path: Path to the image file
            roi: Optional ROI coordinates as {x1, y1, x2, y2}
            expected_fields: List of expected field names to validate

        Returns:
            Test results dictionary.
            - On success: ``success=True``, ``request_time``, ``response``,
              ``validation``, ``field_validation`` and ``status_code``.
            - On failure: ``success=False`` and ``error``, plus
              ``status_code`` for request/HTTP errors.

        Raises:
            OSError: If ``image_path`` cannot be opened (e.g. it does not exist).
        """
        logger.info(f"Testing OCR endpoint with {image_path}")

        data = {}
        if roi:
            data['roi'] = json.dumps(roi)
            logger.info(f"Using ROI: {roi}")

        # Open outside the ``try`` so an unreadable path still raises to the
        # caller; the ``with`` block closes the handle on every path, including
        # timeouts and HTTP errors.
        with open(image_path, 'rb') as image_file:
            try:
                # perf_counter is monotonic, so wall-clock adjustments can't skew timing.
                start_time = time.perf_counter()
                response = self.session.post(
                    f"{self.base_url}/v1/id/ocr",
                    files={'file': image_file},
                    data=data,
                    timeout=self.timeout
                )
                request_time = time.perf_counter() - start_time

                # Check response
                response.raise_for_status()
                result = response.json()

                # Validate response structure
                validation_result = self._validate_response(result)

                # Check expected fields
                field_validation = self._validate_expected_fields(result, expected_fields)

                return {
                    "success": True,
                    "request_time": request_time,
                    "response": result,
                    "validation": validation_result,
                    "field_validation": field_validation,
                    "status_code": response.status_code
                }

            except requests.exceptions.RequestException as e:
                logger.error(f"Request failed: {e}")
                return {
                    "success": False,
                    "error": str(e),
                    # e.response is None when no HTTP response was received.
                    "status_code": getattr(e.response, 'status_code', None)
                }
            except Exception as e:
                logger.error(f"Unexpected error: {e}")
                return {
                    "success": False,
                    "error": str(e)
                }

    def _validate_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the API response structure.

        - Required top-level keys: ``request_id``, ``media_type``,
          ``processing_time`` and ``detections``.
        - ``detections`` must be a list of objects. A detection without
          ``extracted_fields`` or ``mrz_data`` produces only a warning.
        - ``processing_time`` must be a number. A negative value produces a
          warning.

        Args:
            response: Decoded JSON body returned by the API.

        Returns:
            ``{"valid": bool, "errors": [...], "warnings": [...]}``.
        """
        validation = {
            "valid": True,
            "errors": [],
            "warnings": []
        }

        # A JSON array/string/number body cannot carry any of the keys below.
        if not isinstance(response, dict):
            validation["errors"].append("response must be a JSON object")
            validation["valid"] = False
            return validation

        # Required fields
        required_fields = ['request_id', 'media_type', 'processing_time', 'detections']
        for field in required_fields:
            if field not in response:
                validation["errors"].append(f"Missing required field: {field}")
                validation["valid"] = False

        # Validate detections
        if 'detections' in response:
            if not isinstance(response['detections'], list):
                validation["errors"].append("detections must be a list")
                validation["valid"] = False
            else:
                for i, detection in enumerate(response['detections']):
                    if not isinstance(detection, dict):
                        validation["errors"].append(f"detection {i} must be a dictionary")
                        validation["valid"] = False
                        continue
                    # Missing payload sections are reported but not fatal.
                    if 'extracted_fields' not in detection:
                        validation["warnings"].append(f"detection {i} missing extracted_fields")
                    if 'mrz_data' not in detection:
                        validation["warnings"].append(f"detection {i} missing mrz_data")

        # Validate processing time
        if 'processing_time' in response:
            processing_time = response['processing_time']
            # bool is a subclass of int in Python, but JSON true/false is not a duration.
            if isinstance(processing_time, bool) or not isinstance(processing_time, (int, float)):
                validation["errors"].append("processing_time must be a number")
                validation["valid"] = False
            elif processing_time < 0:
                validation["warnings"].append("processing_time is negative")

        return validation

    def _validate_expected_fields(
        self,
        response: Dict[str, Any],
        expected_fields: Optional[List[str]]
    ) -> Dict[str, Any]:
        """Validate that expected fields are present in the response.

        Each expected field is checked in every detection's
        ``extracted_fields``; a field counts as found only when its value is
        not ``None``. Results are labelled ``"<field> (detection <i>)"``. If
        the response has no detections at all, every expected field is
        reported as ``"<field> (no detections)"`` in ``missing_fields``.

        Args:
            response: API response dictionary
            expected_fields: List of expected field names

        Returns:
            ``{"valid": bool, "found_fields": [...], "missing_fields": [...]}``.
            ``valid`` is True only if nothing is missing.
        """
        if not expected_fields:
            return {"valid": True, "found_fields": [], "missing_fields": []}

        found_fields: List[str] = []
        missing_fields: List[str] = []

        detections = response.get('detections') if isinstance(response, dict) else None
        if not isinstance(detections, list) or not detections:
            # Nothing was detected, so none of the expected fields can be present.
            missing_fields = [f"{field_name} (no detections)" for field_name in expected_fields]
            return {"valid": False, "found_fields": found_fields, "missing_fields": missing_fields}

        # Check all detections for fields
        for i, detection in enumerate(detections):
            extracted_fields = detection.get('extracted_fields') if isinstance(detection, dict) else None
            # Tolerate a missing, null or malformed extracted_fields value.
            if not isinstance(extracted_fields, dict):
                extracted_fields = {}

            for field_name in expected_fields:
                if extracted_fields.get(field_name) is not None:
                    found_fields.append(f"{field_name} (detection {i})")
                else:
                    missing_fields.append(f"{field_name} (detection {i})")

        return {
            "valid": len(missing_fields) == 0,
            "found_fields": found_fields,
            "missing_fields": missing_fields
        }

    def test_multiple_images(
        self,
        image_paths: List[str],
        roi: Optional[Dict[str, float]] = None,
        expected_fields: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Test multiple images and return aggregated results.

        Args:
            image_paths: List of image file paths. Paths that are not regular
                files are recorded as failures without contacting the API.
            roi: Optional ROI coordinates, applied to every image.
            expected_fields: Optional expected field names, validated for
                every image (see ``test_ocr_endpoint``).

        Returns:
            Aggregated test results:
            - ``total_images``, ``successful_tests`` and ``failed_tests``.
            - ``success_rate``: 0 when ``image_paths`` is empty.
            - ``average_processing_time``: averaged over successful requests
              only.
            - ``results``: one entry per image, each carrying an ``image`` key.
        """
        logger.info(f"Testing {len(image_paths)} images")

        results = []
        successful_tests = 0
        total_processing_time = 0.0

        for image_path in image_paths:
            # isfile (unlike exists) also rejects directories, which open() can't read.
            if not os.path.isfile(image_path):
                logger.warning(f"Image not found: {image_path}")
                results.append({
                    "image": image_path,
                    "success": False,
                    "error": "File not found"
                })
                continue

            try:
                result = self.test_ocr_endpoint(image_path, roi, expected_fields)
            except OSError as e:
                # E.g. permission denied: record it instead of aborting the batch.
                logger.error(f"Could not read image {image_path}: {e}")
                result = {"success": False, "error": str(e)}

            results.append({
                "image": image_path,
                **result
            })

            if result.get("success", False):
                successful_tests += 1
                total_processing_time += result.get("request_time", 0)

        return {
            "total_images": len(image_paths),
            "successful_tests": successful_tests,
            "failed_tests": len(image_paths) - successful_tests,
            "success_rate": successful_tests / len(image_paths) if image_paths else 0,
            "average_processing_time": total_processing_time / successful_tests if successful_tests > 0 else 0,
            "results": results
        }


def _detection_count(response_body: Any) -> int:
    """Return the number of detections in an API response body (0 if absent or malformed)."""
    if not isinstance(response_body, dict):
        return 0
    detections = response_body.get('detections')
    return len(detections) if isinstance(detections, list) else 0


def _log_validation(result: Dict[str, Any], indent: str = "") -> None:
    """Log the structural and expected-field validation results of one successful test."""
    validation = result.get("validation", {})
    for error in validation.get("errors", []):
        logger.warning(f"{indent}⚠️ Response validation error: {error}")
    for warning in validation.get("warnings", []):
        # Warnings are informational (e.g. a card side without MRZ data).
        logger.debug(f"{indent}Response validation warning: {warning}")

    field_validation = result.get("field_validation", {})
    if field_validation.get("found_fields"):
        logger.info(f"{indent}✅ Found fields: {', '.join(field_validation['found_fields'])}")
    if field_validation.get("missing_fields"):
        logger.warning(f"{indent}⚠️ Missing fields: {', '.join(field_validation['missing_fields'])}")


def main():
    """Main test function.

    Parses CLI options, checks API health, then runs the OCR test on the
    bundled ID card images found next to this script.

    Exits with status 1 on any of:
    - an invalid ``--roi``;
    - a failed health check;
    - no test images found;
    - any failed image test.
    """
    parser = argparse.ArgumentParser(description="Test Dots.OCR API endpoint")
    parser.add_argument(
        "--url",
        default="http://localhost:7860",
        help="API base URL (default: http://localhost:7860)"
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=30,
        help="Request timeout in seconds (default: 30)"
    )
    parser.add_argument(
        "--roi",
        type=str,
        help="ROI coordinates as JSON string (e.g., '{\"x1\": 0.1, \"y1\": 0.1, \"x2\": 0.9, \"y2\": 0.9}')"
    )
    parser.add_argument(
        "--expected-fields",
        nargs="+",
        help="Expected field names to validate (e.g., document_number surname given_names)"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging"
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Parse ROI if provided
    roi = None
    if args.roi:
        try:
            roi = json.loads(args.roi)
        except json.JSONDecodeError as e:
            logger.error(f"Invalid ROI JSON: {e}")
            sys.exit(1)
        # The tester takes the ROI as a {x1, y1, x2, y2} mapping, not a list/scalar.
        if not isinstance(roi, dict):
            logger.error("Invalid ROI JSON: expected an object such as "
                         "{\"x1\": 0.1, \"y1\": 0.1, \"x2\": 0.9, \"y2\": 0.9}")
            sys.exit(1)

    # Initialize tester
    tester = DotsOCRAPITester(args.url, args.timeout)

    # Health check
    logger.info("🔍 Checking API health...")
    health = tester.health_check()
    if "error" in health:
        logger.error(f"❌ API health check failed: {health['error']}")
        sys.exit(1)

    logger.info(f"✅ API is healthy: {health}")

    # Test images
    test_images = [
        "tom_id_card_front.jpg",
        "tom_id_card_back.jpg"
    ]

    # Check if test images exist (resolved relative to this script, not the CWD)
    existing_images = []
    for image in test_images:
        image_path = Path(__file__).parent / image
        if image_path.exists():
            existing_images.append(str(image_path))
        else:
            logger.warning(f"Test image not found: {image_path}")

    if not existing_images:
        logger.error("❌ No test images found")
        sys.exit(1)

    # Expected fields for validation
    expected_fields = args.expected_fields or [
        "document_number",
        "surname",
        "given_names",
        "nationality",
        "date_of_birth",
        "gender"
    ]

    # Run tests
    logger.info(f"🚀 Starting API tests with {len(existing_images)} images...")

    if len(existing_images) == 1:
        # Single image test
        result = tester.test_ocr_endpoint(existing_images[0], roi, expected_fields)

        if result["success"]:
            logger.info("✅ Single image test passed")
            logger.info(f"⏱️  Processing time: {result['request_time']:.2f}s")
            logger.info(f"📄 Detections: {_detection_count(result.get('response'))}")
            _log_validation(result)
        else:
            logger.error(f"❌ Single image test failed: {result.get('error', 'Unknown error')}")
            sys.exit(1)

    else:
        # Multiple images test
        results = tester.test_multiple_images(existing_images, roi, expected_fields)

        logger.info("📊 Test Results:")
        logger.info(f"   Total images: {results['total_images']}")
        logger.info(f"   Successful: {results['successful_tests']}")
        logger.info(f"   Failed: {results['failed_tests']}")
        logger.info(f"   Success rate: {results['success_rate']:.1%}")
        logger.info(f"   Average processing time: {results['average_processing_time']:.2f}s")

        # Print detailed results
        for result in results["results"]:
            image_name = Path(result["image"]).name
            if result["success"]:
                logger.info(f"   ✅ {image_name}: {result['request_time']:.2f}s")
                _log_validation(result, indent="      ")
            else:
                logger.error(f"   ❌ {image_name}: {result.get('error', 'Unknown error')}")

        if results["failed_tests"] > 0:
            sys.exit(1)

    logger.info("🎉 All tests completed successfully!")


if __name__ == "__main__":
    main()