# omni/main.py
# Uploaded by Samfredoly in commit 2a729e6 ("Upload 14 files", verified).
import os
import io
import base64
from pathlib import Path
from typing import Optional, List
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import cv2
import numpy as np
from PIL import Image
import logging
# Process-wide logging at INFO level; this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application; the __main__ guard at the bottom serves it with uvicorn.
app = FastAPI(
    title="OmniParser-v2.0 API",
    description="Extract UI elements and cursor coordinates from screenshots",
    version="1.0.0",
)

# Fully permissive CORS: every origin, method and header is accepted.
# NOTE(review): FastAPI's CORS documentation states that allow_origins,
# allow_methods and allow_headers should not be ["*"] when
# allow_credentials=True. Consider explicit origins, or dropping credentials
# if no cookie/auth-based browser client needs them.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Parser handle, created on first use by load_omniparser(); None until then.
omni_parser = None
class ParseRequest(BaseModel):
    """JSON request body accepted by ``POST /parse-base64``.

    NOTE(review): ``extract_text`` and ``extract_icons`` are declared here,
    but ``parse_base64`` calls ``parse_ui_elements(image)`` without passing
    them on, so they currently have no effect. Confirm whether they should
    be forwarded.
    """
    # Base64 text of the encoded image; decoded by extract_image_from_base64().
    image_base64: str
    # Client flag, default True (see NOTE above: not consumed yet).
    extract_text: bool = True
    # Client flag, default True (see NOTE above: not consumed yet).
    extract_icons: bool = True
class UIElement(BaseModel):
    """One detected UI element, as listed in ``ParseResponse.elements``."""
    # Identifier of the element within a single response
    # (the current placeholder parser numbers them 1, 2, ...).
    element_id: int
    # Human-readable label for the element, e.g. "Button" or "Search".
    label: str
    # Bounding box corners; presumably pixel coordinates in the input
    # image -- TODO confirm once the real parser is wired in.
    bbox: List[int]  # [x1, y1, x2, y2]
    # Element category, e.g. "button" or "textfield".
    element_type: str
    # Detection score; placeholder values are 0.95 / 0.92, presumably
    # in [0, 1] -- verify against the real model output.
    confidence: float
class ParseResponse(BaseModel):
    """Response body returned by the ``/parse`` and ``/parse-base64`` endpoints."""
    # Elements detected in the image.
    elements: List[UIElement]
    # Dimensions of the decoded input image (PIL ``Image.width`` / ``Image.height``).
    image_width: int
    image_height: int
    # Seconds elapsed while the endpoint handled the request
    # (difference of two clock readings taken in the endpoint).
    processing_time: float
    # Model name reported to clients; the endpoints leave it at this default.
    model_used: str = "OmniParser-v2.0"
def load_omniparser():
    """Return the shared parser handle, creating it on the first call.

    The handle is cached in the module-level ``omni_parser`` so repeated calls
    return the same object. For now the "model" is a placeholder dict naming
    ``microsoft/OmniParser-v2.0``; the real initialisation should replace it.

    Raises:
        Whatever the initialisation raises, after logging the failure.
    """
    global omni_parser
    # Fast path: already initialised.
    if omni_parser is not None:
        return omni_parser
    try:
        logger.info("Loading OmniParser-v2.0 from HuggingFace...")
        # TODO: replace this placeholder with the actual OmniParser set-up.
        omni_parser = dict(loaded=True, model_name="microsoft/OmniParser-v2.0")
        logger.info("OmniParser loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load OmniParser: {e}")
        raise
    return omni_parser
def extract_image_from_base64(image_base64: str) -> Image.Image:
    """Decode base64 text into a PIL image.

    Accepts either raw base64 or a data URL such as
    ``data:image/png;base64,iVBORw0...``, which is what browsers produce
    from ``canvas.toDataURL()`` / ``FileReader.readAsDataURL()``.

    Args:
        image_base64: The base64 payload, optionally wrapped in a data URL.

    Returns:
        The opened image. ``Image.open`` only reads the header here; pixel
        data is decoded when first accessed.

    Raises:
        ValueError: If the text cannot be base64-decoded or the bytes are not
            an image format Pillow recognises. The original exception is
            chained as ``__cause__``.
    """
    try:
        payload = image_base64.strip()
        # Drop a data-URL header. Without this, b64decode silently discards
        # ':', ';' and ',' but decodes the header letters ("dataimage/png...")
        # as image bytes, corrupting the payload. Raw base64 can never start
        # with "data:" because ':' is not in the base64 alphabet.
        if payload.startswith("data:"):
            _, sep, body = payload.partition(",")
            if sep:
                payload = body
        image_data = base64.b64decode(payload)
        return Image.open(io.BytesIO(image_data))
    except Exception as e:
        raise ValueError(f"Failed to decode image: {e}") from e
def parse_ui_elements(image: Image.Image) -> List[UIElement]:
    """Detect UI elements in *image*.

    Ensures the parser handle is loaded and logs the image size. The current
    implementation is a stub: it does not look at the pixels and always
    returns the same two elements (a "Button" and a "Search" text field).

    Any exception is logged and then re-raised unchanged.
    """
    try:
        load_omniparser()
        logger.info(f"Processing image of size: {image.size}")
        # TODO: replace these canned rows with real OmniParser inference.
        # Columns: element_id, label, bbox [x1, y1, x2, y2], element_type, confidence
        canned_rows = (
            (1, "Button", [10, 10, 100, 50], "button", 0.95),
            (2, "Search", [150, 10, 400, 50], "textfield", 0.92),
        )
        return [
            UIElement(
                element_id=ident,
                label=text,
                bbox=box,
                element_type=kind,
                confidence=score,
            )
            for ident, text, box, kind, score in canned_rows
        ]
    except Exception as e:
        logger.error(f"Error parsing UI elements: {e}")
        raise
@app.get("/")
async def root():
    """Describe the service and list the routes it exposes.

    Returns a JSON object with a name, a status flag and one
    "<path> - <description>" string per route.
    """
    return {
        "message": "OmniParser-v2.0 API",
        "status": "running",
        # Keep this list in sync with the @app routes registered in this module.
        "endpoints": [
            "/docs - API documentation",
            "/health - Health check",
            "/parse - Parse UI elements from screenshot",
            "/parse-base64 - Parse UI elements from base64-encoded image",
        ],
    }
@app.get("/health")
async def health_check():
    """Report whether the parser can be loaded.

    Returns ``{"status": "healthy", ...}`` with HTTP 200 when
    ``load_omniparser()`` succeeds. Otherwise returns HTTP 503 with the error
    text under ``"error"``.
    """
    try:
        load_omniparser()
    except Exception as exc:
        failure = {"status": "unhealthy", "error": str(exc)}
        return JSONResponse(status_code=503, content=failure)
    return {"status": "healthy", "model": "OmniParser-v2.0"}
@app.post("/parse", response_model=ParseResponse)
async def parse_screenshot(file: UploadFile = File(...)):
    """
    Parse UI elements from a screenshot.
    - **file**: Image file (PNG, JPG, or any other format Pillow can open)
    Returns the detected UI elements with bounding boxes, the image size, and
    the server-side processing time in seconds.
    Responds with 400 if the upload cannot be read as an image, and with 500
    if the parser itself fails.
    """
    import time

    # perf_counter() is monotonic. time.time() follows the wall clock and can
    # jump (NTP sync, manual changes), producing negative or inflated durations.
    start_time = time.perf_counter()

    # Problems with the uploaded bytes are the client's fault -> 400.
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
    except Exception as e:
        logger.error(f"Error in parse endpoint: {e}")
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Failures past this point are server-side, so they must not be reported
    # as a bad request.
    try:
        elements = parse_ui_elements(image)
        return ParseResponse(
            elements=elements,
            image_width=image.width,
            image_height=image.height,
            processing_time=time.perf_counter() - start_time,
        )
    except Exception as e:
        logger.exception(f"Error in parse endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/parse-base64", response_model=ParseResponse)
async def parse_base64(request: ParseRequest):
    """
    Parse UI elements from base64-encoded image.
    Request body:
    - **image_base64**: Base64-encoded image string
    - **extract_text**: Extract text from elements (default: True)
    - **extract_icons**: Extract icons (default: True)
    Note: `extract_text` and `extract_icons` are accepted but not yet used.
    Responds with 400 if the image cannot be decoded, and with 500 if the
    parser itself fails.
    """
    import time

    # Monotonic clock: durations are immune to wall-clock adjustments.
    start_time = time.perf_counter()

    # extract_image_from_base64 signals undecodable input with ValueError,
    # which is the client's fault -> 400.
    try:
        image = extract_image_from_base64(request.image_base64)
    except ValueError as e:
        logger.error(f"Error in parse-base64 endpoint: {e}")
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Failures past this point are server-side -> 500.
    try:
        # NOTE(review): request.extract_text / request.extract_icons are not
        # forwarded because parse_ui_elements takes no such options yet.
        elements = parse_ui_elements(image)
        return ParseResponse(
            elements=elements,
            image_width=image.width,
            image_height=image.height,
            processing_time=time.perf_counter() - start_time,
        )
    except Exception as e:
        logger.exception(f"Error in parse-base64 endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # Defaults match the original hard-coded 0.0.0.0:8000. The HOST and PORT
    # environment variables let a deployment pick a different bind address or
    # port without editing code. An invalid PORT fails loudly at startup.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )