from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import os
import logging
import gc  # garbage collector, used for explicit memory cleanup
from PIL import Image, UnidentifiedImageError
import torch
import asyncio
from utils import (
check_ocr_box,
get_yolo_model,
get_caption_model_processor,
get_som_labeled_img,
)
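# These helpers are assumed to come from the OmniParser utils module:
# check_ocr_box runs OCR over the screenshot and get_som_labeled_img draws
# Set-of-Mark labels over the detected interactive elements.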
from transformers import AutoProcessor, AutoModelForCausalLM
# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Load YOLO model
yolo_model = get_yolo_model(model_path="weights/best.pt")
# Handle device placement
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolo_model = yolo_model.to(device)
# Load caption model and processor
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)
try:
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).to(device)
except Exception as e:
    logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float32,  # float16 is poorly supported on CPU
        trust_remote_code=True,
    )
caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!")
# Initialize FastAPI app
app = FastAPI()
MAX_QUEUE_SIZE = 10 # Set a reasonable limit based on your system capacity
request_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
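# With a bounded queue, `await request_queue.put(...)` suspends once the queue
# is full, so overload turns into backpressure instead of unbounded memory growth.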
# Define response model
class ProcessResponse(BaseModel):
image: str # Base64 encoded image
parsed_content_list: str
label_coordinates: str
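# Illustrative payload shape (values are hypothetical and truncated):
# {
#   "image": "iVBORw0KGgoAAA...",                   # base64-encoded annotated PNG
#   "parsed_content_list": "icon 0: settings\n...",  # newline-joined element descriptions
#   "label_coordinates": "{'0': [0.01, 0.02, ...]}"  # stringified boxes (ratio coordinates)
# }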
# Background worker to process queue tasks
async def worker():
while True:
task = await request_queue.get()
try:
await task
except Exception as e:
logger.error(f"Error while processing task: {e}")
finally:
request_queue.task_done()
# Start worker on startup
@app.on_event("startup")
async def startup_event():
logger.info("Starting background worker...")
asyncio.create_task(worker())
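# Note: a single worker drains the queue, so queued jobs are processed one at a
# time and heavy GPU inference is not run concurrently across requests.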
# Image processing function with memory cleanup
async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
try:
# Define save path
image_save_path = "imgs/saved_image_demo.png"
os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
# Save image
image_input.save(image_save_path)
logger.debug(f"Image saved to: {image_save_path}")
# YOLO and caption model inference
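        # Scale annotation text and line thickness relative to a 3200 px
        # reference width so overlays stay readable at any screenshot size.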
box_overlay_ratio = image_input.size[0] / 3200
draw_bbox_config = {
"text_scale": 0.8 * box_overlay_ratio,
"text_thickness": max(int(2 * box_overlay_ratio), 1),
"text_padding": max(int(3 * box_overlay_ratio), 1),
"thickness": max(int(3 * box_overlay_ratio), 1),
}
ocr_bbox_rslt, is_goal_filtered = await asyncio.to_thread(
check_ocr_box,
image_save_path,
display_img=False,
output_bb_format="xyxy",
goal_filtering=None,
easyocr_args={"paragraph": False, "text_threshold": 0.9},
use_paddleocr=True,
)
text, ocr_bbox = ocr_bbox_rslt
        dino_labeled_img, label_coordinates, parsed_content_list = await asyncio.to_thread(
get_som_labeled_img,
image_save_path,
yolo_model,
BOX_TRESHOLD=box_threshold,
output_coord_in_ratio=True,
ocr_bbox=ocr_bbox,
draw_bbox_config=draw_bbox_config,
caption_model_processor=caption_model_processor,
ocr_text=text,
iou_threshold=iou_threshold,
)
# Convert labeled image to base64
        image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
# Join parsed content list
parsed_content_list_str = "\n".join([str(item) for item in parsed_content_list])
response = ProcessResponse(
image=img_str,
parsed_content_list=parsed_content_list_str,
label_coordinates=str(label_coordinates),
)
        # Memory cleanup: drop large intermediates before returning
        del image_input, text, ocr_bbox, dino_labeled_img, label_coordinates, parsed_content_list
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Free cached GPU memory
        gc.collect()  # Free CPU memory
return response
except Exception as e:
logger.error(f"Error in process function: {e}")
raise HTTPException(status_code=500, detail=f"Failed to process the image: {e}")
# API endpoint for processing images
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
image_file: UploadFile = File(...),
box_threshold: float = 0.05,
iou_threshold: float = 0.1,
):
try:
# Read image file
contents = await image_file.read()
try:
image_input = Image.open(io.BytesIO(contents)).convert("RGB")
except UnidentifiedImageError:
logger.error("Unsupported image format.")
raise HTTPException(status_code=400, detail="Unsupported image format.")
        # Hand the job to the background worker via the bounded queue so heavy
        # inference runs one request at a time and overload becomes backpressure.
        loop = asyncio.get_running_loop()
        done: asyncio.Future = loop.create_future()

        async def run_processing():
            try:
                done.set_result(await process(image_input, box_threshold, iou_threshold))
            except Exception as exc:
                done.set_exception(exc)

        await request_queue.put(run_processing())
        logger.info(f"Task queued. Current queue size: {request_queue.qsize()}")
        response = await done  # Wait for the worker to finish processing
        return response
except HTTPException as he:
raise he
except Exception as e:
logger.error(f"Error processing image: {e}")
raise HTTPException(status_code=500, detail=f"Internal server error: {e}")#