# Captured from a Hugging Face Space whose status page reported: "Runtime error".
from io import BytesIO
from tempfile import NamedTemporaryFile

import torch
from fastapi import FastAPI, HTTPException, Response, UploadFile, status
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from torchvision.io import ImageReadMode, decode_image, read_image
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_resnet50_fpn_v2,
)
from torchvision.transforms.v2.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes
app = FastAPI(title='Test PyTorch COCO Object Detection', docs_url='/')

# Detector: Faster R-CNN (ResNet-50 FPN v2) loaded with torchvision's DEFAULT
# pretrained weights. box_score_thresh=0.9 asks the model to keep only
# high-confidence detections at inference time.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(box_score_thresh=0.9, weights=weights)
model.eval()  # evaluation mode before serving any predictions

# The preprocessing pipeline that ships with the chosen weights; applied to
# every decoded image before it is fed to the model.
preprocess = weights.transforms()
# NOTE(review): no route decorator (e.g. ``@app.get(...)``) is visible for this
# handler in this copy of the file, so as written it is not registered with
# ``app`` -- confirm against the original source.
async def healthcheck():
    """Reply with an empty HTTP 200 OK response (liveness probe)."""
    ok = status.HTTP_200_OK
    return Response(status_code=ok)
def _render_detections(img: torch.Tensor) -> bytes:
    """Run the detector on ``img`` and return it, with boxes drawn, as PNG bytes.

    Synchronous and compute-heavy; ``infer`` calls it from a worker thread so
    the event loop stays responsive.

    Args:
        img: Image tensor as produced by ``decode_image(..., mode=ImageReadMode.RGB)``
            (3 channels, uint8).

    Returns:
        The annotated image encoded as PNG.
    """
    # Pure inference: skip building the autograd graph for the forward pass.
    with torch.no_grad():
        prediction = model([preprocess(img)])[0]
    labels = [weights.meta["categories"][i] for i in prediction["labels"].tolist()]
    annotated = draw_bounding_boxes(
        img,
        boxes=prediction["boxes"],
        labels=labels,
        colors="red",
        width=4,
        font_size=30,
    )
    with BytesIO() as bio:
        to_pil_image(annotated).save(bio, format='PNG')
        return bio.getvalue()


# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible for this
# handler in this copy of the file, so as written it is not registered with
# ``app`` -- confirm against the original source.
async def infer(image: UploadFile):
    """Detect COCO objects in an uploaded image and return it annotated as PNG.

    Args:
        image: The uploaded image file.

    Returns:
        A ``Response`` with media type ``image/png`` containing the input image
        with each detection drawn as a red labelled box.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an image.
    """
    # Async read instead of blocking ``image.file.read()`` on the event loop.
    data = await image.read()
    if not data:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded file is empty.",
        )
    try:
        # Decode straight from memory (no temp file). Force 3-channel RGB so
        # grayscale or RGBA uploads match what the detector expects.
        img = decode_image(
            torch.frombuffer(bytearray(data), dtype=torch.uint8),
            mode=ImageReadMode.RGB,
        )
    except (RuntimeError, ValueError) as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Could not decode the uploaded file as an image.",
        ) from err
    # Model inference is CPU-bound; run it off the event loop so other
    # requests (e.g. health checks) are not stalled meanwhile.
    png = await run_in_threadpool(_render_detections, img)
    return Response(content=png, media_type='image/png')