File size: 1,742 Bytes
4986f6d
 
b245237
 
4986f6d
 
7bf08cb
 
7cfde67
4986f6d
b245237
4986f6d
 
 
 
 
 
 
 
0890b20
 
4f3205b
 
 
0890b20
4f3205b
7cfde67
 
4f3205b
4986f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bf08cb
4986f6d
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import cv2
import numpy as np

from fastapi import APIRouter, File, Response, WebSocket, WebSocketDisconnect
from app.constants import classNames, colors
from app import detector
from mmcv import imfrombytes
from app.custom_mmcv.main import imshow_det_bboxes
from app import logger

# Router for the image endpoints: every route registered on it is served under
# the "/image" prefix and is tagged "Image".
router = APIRouter(prefix="/image", tags=["Image"])


@router.post("")
async def handleImageRequest(
    file: bytes = File(...),
    threshold: float = 0.3,
    raw: bool = False,
):
    """Run object detection on an uploaded image.

    Args:
        file: Raw bytes of the uploaded image.
        threshold: Minimum value forwarded to ``inferenceImage`` for filtering
            or drawing detections.
        raw: When true, return the detections as JSON instead of a rendered
            image.

    Returns:
        * ``raw=True``: a dict with ``"bboxes"`` and ``"labels"`` lists.
        * ``raw=False``: a ``Response`` carrying the annotated image encoded
          as JPEG (``image/jpeg``).
        * A 400 ``Response`` when the upload cannot be decoded, or a 500
          ``Response`` when inference or JPEG encoding fails.
    """
    # Stage 1: decode the upload. Only failures here are the client's fault,
    # so only these map to HTTP 400.
    try:
        img = imfrombytes(file, cv2.IMREAD_COLOR)
    except Exception as e:
        logger.error(e)
        return Response(content="Failed to read image", status_code=400)

    # NOTE(review): the decoder presumably yields None (rather than raising)
    # for bytes that are not a valid image -- confirm for the mmcv backend in
    # use. Rejecting it here keeps a bad upload from surfacing as a model error.
    if img is None:
        logger.error("Uploaded payload could not be decoded as an image")
        return Response(content="Failed to read image", status_code=400)

    # Stage 2: inference. A failure past this point is a server-side problem
    # and must not be reported as a bad request.
    try:
        if raw:
            bboxes, labels = inferenceImage(img, threshold, True)
            return {"bboxes": bboxes.tolist(), "labels": labels.tolist()}

        img = inferenceImage(img, threshold, False)
    except Exception as e:
        logger.error(e)
        return Response(content="Failed to run inference", status_code=500)

    # Stage 3: encode the annotated frame for the HTTP response.
    ret, jpeg = cv2.imencode(".jpg", img)
    if not ret:
        return Response(content="Failed to encode image", status_code=500)
    jpeg_bytes: bytes = jpeg.tobytes()

    return Response(content=jpeg_bytes, media_type="image/jpeg")


def inferenceImage(img, threshold: float, isRaw: bool = False):
    """Run the detector on ``img`` and return detections or a rendered image.

    Args:
        img: Image passed straight to ``detector`` (and to the drawing helper
            in the non-raw path).
        threshold: Cut-off applied to element 4 of each bbox in raw mode
            (presumably the confidence score -- confirm against the detector),
            or forwarded as ``score_thr`` to ``imshow_det_bboxes`` otherwise.
        isRaw: Selects the return form.

    Returns:
        ``(bboxes, labels)`` with every entry whose ``bbox[4]`` is below
        ``threshold`` removed when ``isRaw`` is true; otherwise whatever
        ``imshow_det_bboxes`` returns for the unfiltered detections.
    """
    # The detector's third return value is not used by either path.
    bboxes, labels, _ = detector(img)

    if not isRaw:
        # Rendering path: the drawing helper receives the threshold itself.
        return imshow_det_bboxes(
            img=img,
            bboxes=bboxes,
            labels=labels,
            class_names=classNames,
            colors=colors,
            score_thr=threshold,
        )

    # Raw path: collect the positions of the detections that fall below the
    # threshold, then drop those positions from both arrays so they stay aligned.
    rejected = [pos for pos, bbox in enumerate(bboxes) if bbox[4] < threshold]
    kept_bboxes = np.delete(bboxes, rejected, axis=0)
    kept_labels = np.delete(labels, rejected)
    return kept_bboxes, kept_labels