# wound_detect / predict.py
# Header captured with this file: author Ani14, commit 3ff0ea1 ("Update predict.py"), 948 bytes.
import cv2
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, Response
from ultralytics import YOLO
# FastAPI application that serves the wound-detection endpoint.
app = FastAPI()

# Trained YOLO weights, given as a path relative to the process's working directory.
yolo_model_path: str = 'best.pt'

# The detector is built once at import time; every request reuses this instance.
yolo = YOLO(model=yolo_model_path)
def detect_wounds(image):
    """Run the module-level YOLO detector on *image* and return its boxes.

    Args:
        image: Input passed straight to ``yolo``; the ``/detect`` handler
            supplies an OpenCV-decoded image array.

    Returns:
        The boxes of the first prediction result as a plain Python list of
        ``[x1, y1, x2, y2]`` corner coordinates.
    """
    first_result = yolo(image)[0]
    return first_result.boxes.xyxy.tolist()
def draw_boxes(image, boxes):
    """Outline every box in *boxes* on *image* with a 2-pixel green rectangle.

    Drawing happens in place on ``image``, which is also returned so calls can
    be chained.

    Args:
        image: OpenCV image array to annotate.
        boxes: Iterable of 4-element ``(x1, y1, x2, y2)`` corner coordinates;
            each value is converted with ``int()`` (truncation) before drawing.

    Returns:
        The same ``image`` object that was passed in.
    """
    green = (0, 255, 0)  # green in OpenCV's BGR channel order
    for left, top, right, bottom in (map(int, coords) for coords in boxes):
        cv2.rectangle(image, (left, top), (right, bottom), green, 2)
    return image
@app.post("/detect")
async def detect(image: UploadFile = File(...)):
    """Detect wounds in an uploaded image and return it annotated as a JPEG.

    The upload is decoded with OpenCV, passed to ``detect_wounds``, and the
    returned boxes are drawn by ``draw_boxes``. The annotated image is
    JPEG-encoded in memory and sent back as the response body.

    Args:
        image: The uploaded file (multipart form field ``image``).

    Returns:
        A ``Response`` with ``media_type='image/jpeg'`` whose body is the
        annotated image.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 if the annotated image cannot be JPEG-encoded.
    """
    image_bytes = await image.read()
    # An empty buffer makes cv2.imdecode raise an assertion error (a 500),
    # so reject it up front as a client error.
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # For non-empty but undecodable data, cv2.imdecode returns None instead of raising.
    frame = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        raise HTTPException(status_code=400, detail="Uploaded file is not a decodable image.")

    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop. Before offloading it to a thread pool, check
    # Ultralytics' thread-safety guidance for the shared module-level `yolo`.
    wound_boxes = detect_wounds(frame)
    annotated = draw_boxes(frame, wound_boxes)

    # Encode in memory rather than writing a shared 'result.jpg' on disk.
    # FileResponse streams that file after the handler returns, so a concurrent
    # request could overwrite it in the meantime, and the file was never cleaned up.
    ok, jpeg = cv2.imencode(".jpg", annotated)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode result image.")
    return Response(content=jpeg.tobytes(), media_type='image/jpeg')