detr / app.py
jdavidaguil's picture
Update app.py
7266b1d
raw
history blame contribute delete
No virus
4.2 kB
import math
from PIL import Image, ImageDraw, ImageFont
import requests
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
# Inference-only script: disable autograd globally so forward passes do not
# record gradients.
torch.set_grad_enabled(False)

import ssl

# SECURITY NOTE(review): this swaps the default HTTPS context factory for the
# unverified one, which turns off certificate verification for HTTPS requests
# made through the standard library in this process (presumably added so the
# torch.hub download below succeeds on hosts with a broken CA bundle).
# Kept unchanged on purpose -- confirm whether it is still required.
ssl._create_default_https_context = ssl._create_unverified_context

# Load the pretrained DETR ResNet-50 checkpoint via torch.hub and put the
# network into evaluation mode.
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()
# Preprocessing applied to every input image before it reaches the model:
# resize to 800, convert to a tensor, then apply the standard PyTorch
# mean/std channel normalization.
transform = T.Compose(
    [
        T.Resize(800),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# COCO class names. plot_results2 looks labels up by the argmax class index,
# so the position of every entry (including the 'N/A' fillers) matters.
CLASSES = [
    # 0-9
    'N/A', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]
# RGB palette (0-255 per channel) used to draw detection boxes and labels.
COLORS = [
    [0, 114, 189],
    [217, 83, 25],
    [237, 177, 32],
    [126, 47, 142],
    [119, 172, 48],
    [77, 190, 238],
]
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    Args:
        x: Tensor whose last dimension holds the four values
            (x_c, y_c, w, h). Any number of leading dimensions is accepted,
            so a single box of shape [4], a list of boxes [N, 4] and a
            batch [B, N, 4] all work.

    Returns:
        Tensor of the same shape as ``x`` whose last dimension holds
        (x_min, y_min, x_max, y_max).
    """
    # Split and re-stack along the last axis instead of a hard-coded axis 1,
    # which only handled the exact [N, 4] layout.
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    return torch.stack(corners, dim=-1)
def rescale_bboxes(out_bbox, size):
    """Turn relative (cx, cy, w, h) boxes into absolute corner boxes.

    Args:
        out_bbox: Tensor of boxes in (cx, cy, w, h) form, expressed relative
            to the image size (the caller passes the model's 'pred_boxes').
        size: ``(width, height)`` of the target image, e.g. ``PIL.Image.size``.

    Returns:
        Tensor of (x_min, y_min, x_max, y_max) boxes with x values multiplied
        by the image width and y values by the image height.
    """
    img_w, img_h = size
    corners = box_cxcywh_to_xyxy(out_bbox)
    # Create the scale factors on the same device as the boxes; a CPU-only
    # tensor here would raise a device mismatch when the boxes live on a GPU.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h],
        dtype=torch.float32,
        device=corners.device,
    )
    return corners * scale
def plot_results2(pil_img, prob, boxes):
    """Draw labelled detection boxes on a copy of a PIL image.

    Args:
        pil_img: PIL image to annotate. It is copied, not modified.
        prob: Per-detection class scores, one row per box; the argmax of each
            row selects the label looked up in ``CLASSES``.
        boxes: Tensor of (xmin, ymin, xmax, ymax) boxes in pixel coordinates,
            aligned row-for-row with ``prob``.

    Returns:
        A new PIL image with one outlined rectangle and a
        ``"<class>: <score>"`` caption per detection.
    """
    from itertools import cycle

    annotated = pil_img.copy()
    canvas = ImageDraw.Draw(annotated)
    # cycle() repeats the palette for as long as there are detections; the
    # former ``COLORS * 100`` list silently truncated zip() after 600 boxes.
    for scores, (xmin, ymin, xmax, ymax), rgb in zip(prob, boxes.tolist(), cycle(COLORS)):
        colour = tuple(rgb)
        canvas.rectangle([(xmin, ymin), (xmax, ymax)], outline=colour, width=2)
        # Convert to plain Python numbers before indexing/formatting.
        best = int(scores.argmax())
        caption = f'{CLASSES[best]}: {float(scores[best]):0.2f}'
        canvas.text((xmin + 5, ymin + 5), caption, colour)
    return annotated
import gradio as gr
def process_image(im, threshold=0.9):
    """Run DETR on a PIL image and return an annotated copy.

    Args:
        im: Input PIL image (the Gradio image input is configured with
            ``type="pil"``). ``None`` is returned unchanged so an empty
            submission does not raise.
        threshold: Minimum top-class probability a prediction needs to be
            drawn. Defaults to 0.9, the value that used to be hard-coded
            (the old comment incorrectly said 0.7).

    Returns:
        A PIL image with the kept detections drawn on it, or ``None`` when no
        image was supplied.
    """
    if im is None:
        return None
    # The normalization in ``transform`` uses three-channel statistics, so
    # grayscale / RGBA / palette images are converted first.
    if im.mode != "RGB":
        im = im.convert("RGB")

    # Add a batch dimension and propagate through the model.
    batch = transform(im).unsqueeze(0)
    outputs = model(batch)

    # Class probabilities for the first (only) batch item; the last class
    # column is excluded from scoring (presumably DETR's "no object" slot --
    # confirm against the model documentation).
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold

    # Convert the kept boxes from [0; 1] to image scale and draw them.
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    return plot_results2(im, probas[keep], bboxes_scaled)
# Text shown on the demo page.
title = "Demo: DETR detection"
description = "Demo for Facebooks's DETR: What it is. Unlike traditional computer vision techniques, DETR approaches object detection as a direct set prediction problem. It consists of a set-based global loss, which forces unique predictions via bipartite matching, and a Transformer encoder-decoder architecture. Given a fixed small set of learned object queries, DETR reasons about the relations of the objects and the global image context to directly output the final set of predictions in parallel. Due to this parallel nature, DETR is very fast and efficient."
# Example inputs offered in the UI; the file name is handed to Gradio as-is.
examples = [['cats.jpg']]

# Wire process_image into a PIL-in / PIL-out Gradio interface, then serve it.
iface = gr.Interface(
    process_image,
    gr.inputs.Image(type="pil"),
    gr.outputs.Image(type="pil", label="DETR detections"),
    title=title,
    description=description,
    examples=examples,
    enable_queue=True,
)
iface.launch(debug=True)