"""FastAPI service that draws COCO object detections onto an uploaded image.

A Faster R-CNN (ResNet-50 FPN v2) model is created once at import time and
shared by all requests.  ``POST /detectObjects`` accepts an image upload and
responds with a PNG copy of the image with the detected boxes drawn on it.
"""

from io import BytesIO
from tempfile import NamedTemporaryFile

import torch
from fastapi import FastAPI, HTTPException, Response, status, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from torchvision.io import ImageReadMode, decode_image, read_image
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_resnet50_fpn_v2,
)
from torchvision.transforms.v2.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes

app = FastAPI(docs_url='/', title='Test PyTorch COCO Object Detection')

# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()


@app.get('/healthcheck')
async def healthcheck():
    """Return an empty HTTP 200 response."""
    return Response(status_code=status.HTTP_200_OK)


def _decode_upload(data: bytes) -> torch.Tensor:
    """Decode raw upload bytes into a 3-channel uint8 image tensor.

    Raises:
        HTTPException: 400 when the payload is empty, cannot be decoded, or
            does not decode to a single (C, H, W) image.
    """
    if not data:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Uploaded file is empty.',
        )

    # Decode straight from memory instead of round-tripping through a
    # temporary file.  bytearray gives torch a writable buffer, which avoids
    # the warning torch.frombuffer emits for read-only bytes objects.
    encoded = torch.frombuffer(bytearray(data), dtype=torch.uint8)
    try:
        # Force RGB so grayscale / RGBA uploads reach the model and the
        # drawing helper with the same three-channel layout.
        img = decode_image(encoded, mode=ImageReadMode.RGB)
    except (RuntimeError, ValueError) as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Uploaded file is not a supported image.',
        ) from err

    # NOTE(review): some torchvision versions return a 4-D tensor for
    # multi-frame images; reject anything that is not one (C, H, W) image.
    if img.ndim != 3:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Uploaded file must contain a single image.',
        )
    return img


def _detect_and_render(data: bytes) -> bytes:
    """Run detection on encoded image bytes and return an annotated PNG."""
    img = _decode_upload(data)

    # Autograd bookkeeping is pure overhead for inference.
    with torch.inference_mode():
        prediction = model([preprocess(img)])[0]

    categories = weights.meta["categories"]
    labels = [categories[idx] for idx in prediction["labels"].tolist()]
    annotated = draw_bounding_boxes(
        img,
        boxes=prediction["boxes"],
        labels=labels,
        colors="red",
        width=4,
        font_size=30,
    )

    with BytesIO() as bio:
        to_pil_image(annotated).save(bio, format='PNG')
        return bio.getvalue()


@app.post('/detectObjects')
async def infer(image: UploadFile):
    """Detect objects in the uploaded image and return it annotated as PNG.

    Args:
        image: The uploaded image file.

    Returns:
        A ``Response`` with media type ``image/png`` whose body is the
        uploaded image with red bounding boxes and category labels drawn on
        it.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as
            a single image.
    """
    data = await image.read()
    # Decoding, inference and PNG encoding are blocking, CPU-bound work; run
    # them in a worker thread so the event loop (and /healthcheck) stays
    # responsive while a detection is in progress.
    png_bytes = await run_in_threadpool(_detect_and_render, data)
    return Response(content=png_bytes, media_type='image/png')