| | """ |
| | Inference module for LearningStudio Callout Detection wrapper. |
| | |
| | This module: |
| | 1. Normalizes input to bytes (handles URLs, data URLs, raw base64) |
| | 2. Gets presigned S3 URL from API Gateway |
| | 3. Uploads image directly to S3 (bypasses API Gateway for large payloads) |
| | 4. Calls API Gateway to start detection job |
| | 5. Polls for completion |
| | 6. Transforms callouts to EMCO format |
| | """ |
| |
|
| | import os |
| | import base64 |
| | import time |
| | import logging |
| | from typing import Dict, Any, List, Optional, Tuple |
| |
|
| | import requests |
| |
|
| | |
# Module-level logging setup; the format/handlers come from the root config.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# API Gateway endpoint and key are injected via environment variables
# (configured as HF Inference Endpoint secrets); empty string when unset.
API_GATEWAY_URL = os.environ.get("API_GATEWAY_URL", "")
API_KEY = os.environ.get("API_KEY", "")

# Polling budget: allow a job up to 15 minutes, checking every 5 seconds.
MAX_WAIT_SECONDS = 900
POLL_INTERVAL_SECONDS = 5

# Default processing parameters forwarded to the detection backend.
DEFAULT_PARAMS = {
    "tiling": {"tile": 2048, "overlap": 0.30},
    "floodfill": {"erase_text": False, "min_fill_vs_text": 0.0},
    "preclean": {"denoise_sw": 8},
}
| |
|
| |
|
def normalize_to_bytes(image_input: str) -> Tuple[bytes, str]:
    """
    Normalize image input to raw bytes plus a best-effort filename.

    Handles three input forms:
      - HTTP/HTTPS URLs: downloads the image
      - Data URLs (``data:image/png;base64,...``): decodes the base64 payload
      - Raw base64 strings: decodes to bytes

    Args:
        image_input: Image URL, data URL, or base64 string.

    Returns:
        Tuple of (image_bytes, filename).

    Raises:
        ValueError: If the input is not a valid data URL or base64 string.
        requests.HTTPError: If downloading from a URL fails.
    """
    if image_input.startswith(("http://", "https://")):
        logger.info(f"Downloading image from URL: {image_input[:100]}...")
        response = requests.get(image_input, timeout=60)
        response.raise_for_status()

        # Derive a filename from the URL path; fall back to a generic name.
        from urllib.parse import urlparse
        parsed = urlparse(image_input)
        filename = os.path.basename(parsed.path) or "image.png"
        return response.content, filename

    if image_input.startswith("data:"):
        # Split the header ("data:image/png;base64") from the encoded payload.
        # Keep each try body minimal so a decode failure is not confused with
        # a malformed header, and chain/suppress context explicitly.
        try:
            header, encoded = image_input.split(",", 1)
        except ValueError:
            raise ValueError("Invalid data URL format") from None

        # Pick the extension from the MIME type, e.g. "image/jpeg" -> "jpeg".
        mime_part = header.split(";")[0].replace("data:", "")
        ext = mime_part.split("/")[-1] if "/" in mime_part else "png"
        try:
            # binascii.Error subclasses ValueError, so decode failures are
            # reported the same way the original code reported them.
            return base64.b64decode(encoded), f"image.{ext}"
        except ValueError:
            raise ValueError("Invalid data URL format") from None

    # Assume raw base64 (actual image format is unknown; default to PNG name).
    try:
        return base64.b64decode(image_input), "image.png"
    except Exception as e:
        raise ValueError(f"Invalid base64 string: {e}") from e
| |
|
| |
|
def get_upload_url(filename: str = "image.png") -> Dict[str, str]:
    """
    Request a presigned S3 upload URL from the API Gateway.

    Args:
        filename: Original filename for the image.

    Returns:
        Dict with ``job_id``, ``upload_url`` and ``s3_url`` keys.

    Raises:
        ValueError: If the API Gateway credentials are not configured.
        requests.HTTPError: If the gateway returns an error status.
    """
    if not (API_GATEWAY_URL and API_KEY):
        raise ValueError(
            "API_GATEWAY_URL and API_KEY must be set in environment variables. "
            "Configure these in your HF Inference Endpoint secrets."
        )

    endpoint = f"{API_GATEWAY_URL.rstrip('/')}/upload-url"
    logger.info(f"Getting upload URL from {endpoint}")
    response = requests.get(
        endpoint,
        headers={"x-api-key": API_KEY},
        params={"filename": filename},
        timeout=30,
    )
    response.raise_for_status()

    result = response.json()
    logger.info(f"Got upload URL for job_id={result.get('job_id')}")
    return result
| |
|
| |
|
def upload_to_s3(upload_url: str, image_bytes: bytes, content_type: str = "image/png") -> None:
    """
    Upload image directly to S3 using a presigned PUT URL.

    Args:
        upload_url: Presigned PUT URL.
        image_bytes: Image data to upload.
        content_type: MIME type sent with the upload. Defaults to "image/png"
            for backward compatibility (the previous implementation hard-coded
            it even for JPEG/other inputs); pass the real type when known so
            the header matches what the presigned URL was signed for.

    Raises:
        requests.HTTPError: If the upload fails.
    """
    logger.info(f"Uploading {len(image_bytes)} bytes to S3...")
    response = requests.put(
        upload_url,
        data=image_bytes,
        headers={"Content-Type": content_type},
        timeout=60
    )
    response.raise_for_status()
    logger.info("Upload complete")
| |
|
| |
|
def start_detection_job(job_id: str, s3_url: str, params: Optional[Dict] = None) -> str:
    """
    Kick off a detection job via the API Gateway.

    Args:
        job_id: Job ID returned by get_upload_url.
        s3_url: S3 URL returned by get_upload_url.
        params: Optional processing parameters forwarded to the backend.

    Returns:
        The same job ID, for subsequent polling.

    Raises:
        requests.HTTPError: If the gateway rejects the request.
    """
    payload = {"job_id": job_id, "s3_url": s3_url}
    if params:
        payload["params"] = params

    logger.info(f"Starting detection job {job_id}")
    response = requests.post(
        f"{API_GATEWAY_URL.rstrip('/')}/detect",
        headers={"x-api-key": API_KEY, "Content-Type": "application/json"},
        json=payload,
        timeout=30,
    )
    response.raise_for_status()

    result = response.json()
    logger.info(f"Detection job started: {result.get('status')}")
    return job_id
| |
|
| |
|
def poll_for_completion(job_id: str) -> Dict[str, Any]:
    """
    Poll API Gateway until the job reaches a terminal state or we time out.

    Args:
        job_id: Job ID to poll.

    Returns:
        The final status payload on success; otherwise a synthesized dict
        with ``status``, ``error`` and an empty ``callouts`` list.
    """
    url = f"{API_GATEWAY_URL.rstrip('/')}/status/{job_id}"
    headers = {"x-api-key": API_KEY}

    # Track a monotonic deadline so slow HTTP round-trips count against the
    # budget. The previous version only summed the sleep intervals, so total
    # wall time could overshoot MAX_WAIT_SECONDS by the cumulative latency.
    start = time.monotonic()
    while time.monotonic() - start < MAX_WAIT_SECONDS:
        elapsed = int(time.monotonic() - start)
        logger.info(f"Polling job {job_id} (elapsed: {elapsed}s)")

        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()

        result = response.json()
        status = result.get("status")

        if status == "SUCCEEDED":
            logger.info(f"Job {job_id} completed successfully")
            return result

        if status in ("FAILED", "TIMED_OUT", "ABORTED"):
            error_msg = result.get("error", f"Job {status.lower()}")
            logger.error(f"Job {job_id} failed: {error_msg}")
            return {
                "status": status,
                "error": error_msg,
                "callouts": []
            }

        time.sleep(POLL_INTERVAL_SECONDS)

    # Local "TIMEOUT" status (distinct from the backend's "TIMED_OUT") —
    # the caller checks for both.
    logger.error(f"Job {job_id} timed out after {MAX_WAIT_SECONDS}s")
    return {
        "status": "TIMEOUT",
        "error": f"Timeout waiting for results after {MAX_WAIT_SECONDS}s",
        "callouts": []
    }
| |
|
| |
|
def transform_to_emco_format(
    callouts: List[Dict],
    image_base64: str,
    image_width: int = 0,
    image_height: int = 0
) -> Dict[str, Any]:
    """
    Transform callouts from Lambda format to EMCO format.

    Lambda format (xywh):
        {"bbox": [x, y, w, h], "score": 0.95, ...}

    EMCO format (xyxy):
        {"bbox": {"x1": x, "y1": y, "x2": x+w, "y2": y+h}, "confidence": 0.95, ...}

    Args:
        callouts: List of callouts from Lambda.
        image_base64: Original image as base64.
        image_width: Image width in pixels (0 when unknown).
        image_height: Image height in pixels (0 when unknown).

    Returns:
        EMCO-compatible response dict with ``predictions``,
        ``total_detections``, ``image``, ``image_width``, ``image_height``.
    """
    predictions = []

    for i, callout in enumerate(callouts):
        # Pad short or missing bboxes with zeros so a malformed entry
        # degrades to a zero-area box instead of raising IndexError.
        bbox = list(callout.get("bbox", [])) + [0, 0, 0, 0]
        x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]

        prediction = {
            "id": i + 1,
            "label": "callout",
            "class_id": 0,
            # The backend may report either "score" or "confidence".
            "confidence": callout.get("score", callout.get("confidence", 1.0)),
            # Convert xywh -> xyxy corner coordinates.
            "bbox": {
                "x1": int(x),
                "y1": int(y),
                "x2": int(x + w),
                "y2": int(y + h)
            }
        }

        # Preserve OCR text when the backend supplied it.
        if "text" in callout:
            prediction["text"] = callout["text"]

        predictions.append(prediction)

    return {
        "predictions": predictions,
        "total_detections": len(predictions),
        "image": image_base64,
        "image_width": image_width,
        "image_height": image_height
    }
| |
|
| |
|
def _deep_merge(base: Dict, override: Dict) -> Dict:
    """Recursively merge ``override`` into a copy of ``base``.

    Nested dicts are merged key-by-key; any other value in ``override``
    replaces the corresponding ``base`` value. Neither input is mutated.
    """
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = _deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def inference(image_input: str, parameters: Optional[Dict] = None) -> Dict[str, Any]:
    """
    Run inference on an image.

    This is the main entry point for the HF wrapper.

    Flow:
        1. Normalize input to bytes
        2. Get presigned S3 URL
        3. Upload image directly to S3
        4. Start detection job (small JSON payload)
        5. Poll for completion
        6. Transform results to EMCO format

    Args:
        image_input: Image URL, data URL, or base64 string.
        parameters: Optional processing parameters. Merged recursively over
            DEFAULT_PARAMS, so overriding one nested key (e.g. tiling.tile)
            keeps the sibling defaults (e.g. tiling.overlap) intact.

    Returns:
        EMCO-compatible response with predictions, or an error payload
        (``error`` key, empty ``predictions``) if any stage fails.
    """
    try:
        logger.info("Normalizing input...")
        image_bytes, filename = normalize_to_bytes(image_input)

        # Keep the original image so it can be echoed back in the response.
        image_base64 = base64.b64encode(image_bytes).decode("utf-8")

        logger.info("Getting upload URL...")
        upload_info = get_upload_url(filename)
        job_id = upload_info["job_id"]
        upload_url = upload_info["upload_url"]
        s3_url = upload_info["s3_url"]

        logger.info("Uploading to S3...")
        upload_to_s3(upload_url, image_bytes)

        logger.info("Starting detection job...")
        # Deep merge instead of {**DEFAULT_PARAMS, **parameters}: the shallow
        # merge replaced whole sub-dicts, silently dropping default sub-keys
        # whenever the caller overrode any sibling key.
        merged_params = _deep_merge(DEFAULT_PARAMS, parameters or {})
        start_detection_job(job_id, s3_url, merged_params)

        logger.info("Polling for completion...")
        result = poll_for_completion(job_id)

        # Any terminal failure state from the poller maps to an error payload.
        if result.get("status") in ("FAILED", "TIMED_OUT", "ABORTED", "TIMEOUT"):
            return {
                "error": result.get("error", "Unknown error"),
                "predictions": [],
                "total_detections": 0,
                "image": image_base64
            }

        logger.info("Transforming results to EMCO format...")
        return transform_to_emco_format(
            result.get("callouts", []),
            image_base64,
            result.get("image_width", 0),
            result.get("image_height", 0)
        )

    except requests.exceptions.RequestException as e:
        logger.error(f"Request error: {e}")
        return {
            "error": f"Request error: {str(e)}",
            "predictions": [],
            "total_detections": 0,
            "image": ""
        }
    except ValueError as e:
        logger.error(f"Validation error: {e}")
        return {
            "error": str(e),
            "predictions": [],
            "total_detections": 0,
            "image": ""
        }
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        return {
            "error": f"Unexpected error: {str(e)}",
            "predictions": [],
            "total_detections": 0,
            "image": ""
        }
| |
|