# wi / main.py
# Last change: "Update main.py" by giveaccesstoall (commit 64c0245, verified).
import functools
import json
import os
import threading
from io import BytesIO

import uvicorn
from fastapi import FastAPI, Request,File, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.responses import ORJSONResponse
from fastapi.templating import Jinja2Templates
from PIL import Image, ImageDraw
from ultralytics import YOLO
# Redirect Hugging Face cache directories under /code/cache (presumably a
# writable location in the deployment container — confirm).
# NOTE(review): these are assigned after the imports above, so any library
# that reads them at import time will not see them — verify this ordering.
os.environ.update(
    {
        "HF_HOME": "/code/cache/huggingface/hub",
        "TRANSFORMERS_CACHE": "/code/cache/images",
    }
)

# ASGI application object; routes below register against it.
app = FastAPI()
# Jinja2 templates are resolved relative to the "templates" directory.
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def hello(request: Request):
    """Render and return the ``index.html`` landing page.

    The incoming request is placed in the template context under the
    ``"request"`` key.
    """
    context = {"request": request}
    page = templates.TemplateResponse("index.html", context)
    return page
@app.post("/detect")
async def detect(image_file: UploadFile = File(...)):
    """Run object detection on an uploaded image.

    Args:
        image_file: the uploaded image (multipart form field ``image_file``).

    Returns:
        An ``ORJSONResponse`` whose body is the list of
        ``[x1, y1, x2, y2, class_name, confidence]`` rows produced by
        ``detect_objects_on_image``.

    Raises:
        HTTPException: status 400 when the upload cannot be decoded as an
            image (previously this surfaced as a 500).
    """
    # Function-scope imports keep this edit self-contained.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    buf = await image_file.read()
    try:
        image = Image.open(BytesIO(buf))
        # Image.open is lazy; decode now so corrupt/truncated data fails here
        # rather than deep inside inference.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError is an OSError subclass.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    # Inference is synchronous and compute-heavy; running it on a worker
    # thread keeps the event loop free to serve other requests meanwhile.
    boxes = await run_in_threadpool(detect_objects_on_image, image)
    return ORJSONResponse(boxes)
# Guards the shared cached model: Ultralytics models are not documented as
# thread-safe, so serialise inference in case callers run it from several
# threads (e.g. a server threadpool).
_MODEL_LOCK = threading.Lock()


@functools.lru_cache(maxsize=4)
def _load_model(weights):
    """Load a YOLO model for *weights* once and reuse it on later calls."""
    return YOLO(weights)


def detect_objects_on_image(image, weights="YOLO_WEIGHTS"):
    """Detect objects in *image* with a YOLO model.

    The model is loaded once per weights path and cached, instead of being
    re-read from disk on every call.

    Args:
        image: any input accepted by ``YOLO.predict`` (the /detect route
            passes a PIL image).
        weights: weights path/name handed to ``YOLO``; defaults to the
            previously hard-coded ``"YOLO_WEIGHTS"``.

    Returns:
        A list of ``[x1, y1, x2, y2, class_name, confidence]`` rows: box
        corners rounded to ints, confidence rounded to 2 decimals. Empty
        when nothing is detected.
    """
    with _MODEL_LOCK:
        model = _load_model(weights)
        # A single image yields a single Results object.
        result = model.predict(image)[0]

    detections = []
    for box in result.boxes:
        x1, y1, x2, y2 = (round(v) for v in box.xyxy[0].tolist())
        # cls comes back as a float tensor; names is keyed by int class id.
        class_id = int(box.cls[0].item())
        confidence = round(box.conf[0].item(), 2)
        detections.append([x1, y1, x2, y2, result.names[class_id], confidence])
    return detections