# Metadata captured from the hosting page (not part of the program):
# IHappyPlant's picture
# fixes
# c92f607
# raw
# history blame contribute delete
# No virus
# 1.6 kB
from io import BytesIO
from tempfile import NamedTemporaryFile

import torch
from fastapi import FastAPI, HTTPException, Response, UploadFile, status
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from torchvision.io import ImageReadMode, decode_image, read_image
from torchvision.models.detection import (FasterRCNN_ResNet50_FPN_V2_Weights,
                                          fasterrcnn_resnet50_fpn_v2)
from torchvision.transforms.v2.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes
app = FastAPI(docs_url='/', title='Test PyTorch COCO Object Detection')

# Detector: Faster R-CNN ResNet-50 FPN v2 with torchvision's DEFAULT pretrained
# weights. box_score_thresh makes the model itself drop detections whose score
# is below 0.9, so every returned box is a high-confidence one.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
# nn.Module.eval() returns the module, so building it and switching it to
# evaluation mode (where it returns detections rather than losses) is one step.
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9).eval()

# Input transform bundled with the chosen weights; applied to each image before
# it is passed to the model.
preprocess = weights.transforms()
@app.get('/healthcheck')
async def healthcheck():
    """Report that the service is up.

    Returns:
        An empty-bodied ``Response`` carrying status 200 OK.
    """
    alive = Response(status_code=status.HTTP_200_OK)
    return alive
def _decode_rgb(data: bytes):
    """Decode encoded image bytes into a 3-channel (RGB) uint8 tensor.

    Raises:
        ValueError: if *data* is empty.
        RuntimeError: (typically) if torchvision cannot decode *data* because
            it is corrupt or in an unsupported format.
    """
    if not data:
        raise ValueError('empty upload')
    # bytearray gives torch.frombuffer a writable buffer (it warns on bytes).
    encoded = torch.frombuffer(bytearray(data), dtype=torch.uint8)
    # Force three channels: RGBA/palette uploads would otherwise reach the
    # model's normalization and draw_bounding_boxes (1 or 3 channels only)
    # as 4-channel tensors and make the request fail.
    return decode_image(encoded, mode=ImageReadMode.RGB)


def _render_detections(img) -> bytes:
    """Run the detector on *img* and return it PNG-encoded with boxes drawn."""
    # Inference needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        prediction = model([preprocess(img)])[0]
    categories = weights.meta["categories"]
    labels = [categories[i] for i in prediction["labels"].tolist()]
    annotated = draw_bounding_boxes(img, boxes=prediction["boxes"],
                                    labels=labels,
                                    colors="red",
                                    width=4, font_size=30)
    with BytesIO() as bio:
        to_pil_image(annotated).save(bio, format='PNG')
        return bio.getvalue()


@app.post('/detectObjects')
async def infer(image: UploadFile):
    """Detect objects in an uploaded image and return it annotated.

    The upload is decoded as RGB and run through the module-level detector.
    Every detection the model returns is drawn as a red box labelled with its
    category name.

    Args:
        image: Uploaded image file.

    Returns:
        A ``Response`` whose body is the annotated image encoded as PNG.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded.
    """
    data = await image.read()
    try:
        # Decoding is CPU-bound; keep it off the event loop like inference.
        img = await run_in_threadpool(_decode_rgb, data)
    except (RuntimeError, ValueError) as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Uploaded file could not be decoded as an image.',
        ) from err
    # Model inference and PNG encoding can take seconds on CPU; running them
    # in FastAPI's threadpool keeps other requests (e.g. /healthcheck) served.
    png = await run_in_threadpool(_render_detections, img)
    return Response(content=png, media_type='image/png')