Spaces:
Running
Running
| import os | |
| import io | |
| import base64 | |
| from pathlib import Path | |
| from typing import Optional, List | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import logging | |
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object that the endpoints and uvicorn use.
app = FastAPI(
    title="OmniParser-v2.0 API",
    description="Extract UI elements and cursor coordinates from screenshots",
    version="1.0.0",
)

# CORS policy: any origin, method and header is accepted, with credentials
# enabled.
# NOTE(review): the FastAPI CORS documentation advises against combining
# wildcard origins with allow_credentials=True -- confirm this is intended
# before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Cache slot for the OmniParser model; stays None until load_omniparser()
# populates it on first use.
omni_parser = None
class ParseRequest(BaseModel):
    """Request body carrying a base64-encoded image and extraction flags."""

    # Base64 text of the image to parse.
    image_base64: str
    # Text-extraction flag; on unless the client sends False.
    extract_text: bool = True
    # Icon-extraction flag; on unless the client sends False.
    extract_icons: bool = True
class UIElement(BaseModel):
    """One UI element reported by the parser."""

    # Identifier of the element within a single parse result.
    element_id: int
    # Human-readable label of the element.
    label: str
    # Bounding box as [x1, y1, x2, y2].
    bbox: List[int]
    # Category string such as "button" or "textfield".
    element_type: str
    # Confidence score attached to the detection.
    confidence: float
class ParseResponse(BaseModel):
    """Payload returned by the parsing endpoints."""

    # Elements found in the image.
    elements: List[UIElement]
    # Dimensions of the parsed image.
    image_width: int
    image_height: int
    # Elapsed time measured by the handler, in seconds.
    processing_time: float
    # Name of the model reported to the client.
    model_used: str = "OmniParser-v2.0"
def load_omniparser():
    """Return the shared OmniParser handle, creating it on the first call.

    The handle is cached in the module-level ``omni_parser`` variable, so
    later calls return the cached object without reloading. Any exception
    raised while loading is logged and re-raised unchanged.
    """
    global omni_parser

    # Fast path: the handle has already been created by an earlier call.
    if omni_parser is not None:
        return omni_parser

    try:
        logger.info("Loading OmniParser-v2.0 from HuggingFace...")
        # Placeholder handle that only demonstrates the structure; replace it
        # with the actual OmniParser initialization.
        omni_parser = dict(
            loaded=True,
            model_name="microsoft/OmniParser-v2.0",
        )
        logger.info("OmniParser loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load OmniParser: {e}")
        raise

    return omni_parser
def extract_image_from_base64(image_base64: str) -> Image.Image:
    """Decode a base64 string (optionally a ``data:`` URL) into a PIL image.

    Args:
        image_base64: Base64 text of an encoded image. A leading data-URL
            header such as ``data:image/png;base64,`` is accepted and
            stripped before decoding.

    Returns:
        The image object produced by ``Image.open`` on the decoded bytes.

    Raises:
        ValueError: If the input cannot be base64-decoded or the decoded
            bytes cannot be opened as an image. The underlying exception is
            chained as ``__cause__``.
    """
    try:
        payload = image_base64.strip()
        # Browser and canvas APIs commonly produce
        # "data:<mime>;base64,<payload>". Without stripping the header its
        # characters would be decoded as if they were part of the image data.
        if payload.startswith("data:"):
            _, _, payload = payload.partition(",")
        image_data = base64.b64decode(payload)
        return Image.open(io.BytesIO(image_data))
    except Exception as e:
        # Chain the cause so the original traceback is kept for debugging.
        raise ValueError(f"Failed to decode image: {e}") from e
def parse_ui_elements(image: Image.Image) -> List[UIElement]:
    """Return the UI elements detected in ``image``.

    The model is loaded first via ``load_omniparser()``. Detection itself is
    still a placeholder: the same two mock elements are returned for every
    image. Any exception is logged and re-raised unchanged.
    """
    try:
        # Make sure the (lazily created) model handle exists.
        load_omniparser()
        logger.info(f"Processing image of size: {image.size}")

        # Mock results for demonstration; replace with actual OmniParser
        # parsing logic.
        button = UIElement(
            element_id=1,
            label="Button",
            bbox=[10, 10, 100, 50],
            element_type="button",
            confidence=0.95,
        )
        search_field = UIElement(
            element_id=2,
            label="Search",
            bbox=[150, 10, 400, 50],
            element_type="textfield",
            confidence=0.92,
        )
        return [button, search_field]
    except Exception as e:
        logger.error(f"Error parsing UI elements: {e}")
        raise
# Register the handler on the application; without a route decorator the
# function is never reachable over HTTP.
@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "message": "OmniParser-v2.0 API",
        "status": "running",
        "endpoints": [
            "/docs - API documentation",
            "/health - Health check",
            "/parse - Parse UI elements from screenshot",
        ],
    }
# Register the handler at the path that the root endpoint advertises.
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        # Healthy means the model handle can be obtained.
        load_omniparser()
        return {"status": "healthy", "model": "OmniParser-v2.0"}
    except Exception as e:
        # Report 503 with the failure text so a monitor can tell "up but
        # model unavailable" apart from a healthy instance.
        return JSONResponse(
            status_code=503,
            content={"status": "unhealthy", "error": str(e)},
        )
# Register the handler at the path that the root endpoint advertises and
# declare the response schema for validation and the OpenAPI docs.
@app.post("/parse", response_model=ParseResponse)
async def parse_screenshot(file: UploadFile = File(...)):
    """
    Parse UI elements from a screenshot.

    - **file**: Image file (PNG, JPG, etc.)

    Returns UI elements with bounding boxes and cursor coordinates.
    """
    import time

    # perf_counter() is monotonic, so the measured duration is unaffected by
    # wall-clock adjustments during the request.
    start_time = time.perf_counter()

    # An upload that cannot be read or opened as an image is a client
    # error -> 400.
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
    except Exception as e:
        logger.error(f"Error in parse endpoint: {e}")
        raise HTTPException(status_code=400, detail=str(e)) from e

    # A failure inside the parser is a server-side fault, so report 500
    # instead of blaming the request. The traceback goes to the log; the
    # client only receives a generic message.
    try:
        elements = parse_ui_elements(image)
    except Exception as e:
        logger.exception("Error in parse endpoint")
        raise HTTPException(
            status_code=500,
            detail="Failed to parse UI elements",
        ) from e

    processing_time = time.perf_counter() - start_time
    return ParseResponse(
        elements=elements,
        image_width=image.width,
        image_height=image.height,
        processing_time=processing_time,
    )
# Register the handler; the path matches the name this handler uses for
# itself in its log messages.
@app.post("/parse-base64", response_model=ParseResponse)
async def parse_base64(request: ParseRequest):
    """
    Parse UI elements from base64-encoded image.

    Request body:
    - **image_base64**: Base64-encoded image string
    - **extract_text**: Extract text from elements (default: True)
    - **extract_icons**: Extract icons (default: True)
    """
    import time

    # perf_counter() is monotonic, so the measured duration is unaffected by
    # wall-clock adjustments during the request.
    start_time = time.perf_counter()

    # A payload that cannot be decoded into an image is a client error -> 400.
    try:
        image = extract_image_from_base64(request.image_base64)
    except Exception as e:
        logger.error(f"Error in parse-base64 endpoint: {e}")
        raise HTTPException(status_code=400, detail=str(e)) from e

    # A failure inside the parser is a server-side fault, so report 500
    # instead of blaming the request. The traceback goes to the log; the
    # client only receives a generic message.
    # NOTE: extract_text / extract_icons are accepted but not forwarded,
    # because parse_ui_elements() takes only the image.
    try:
        elements = parse_ui_elements(image)
    except Exception as e:
        logger.exception("Error in parse-base64 endpoint")
        raise HTTPException(
            status_code=500,
            detail="Failed to parse UI elements",
        ) from e

    processing_time = time.perf_counter() - start_time
    return ParseResponse(
        elements=elements,
        image_width=image.width,
        image_height=image.height,
        processing_time=processing_time,
    )
if __name__ == "__main__":
    # Script entry point: serve the application with uvicorn on every
    # network interface, port 8000. uvicorn is imported here so that merely
    # importing this module does not require it.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )