Spaces:
Runtime error
Runtime error
import math | |
from PIL import Image, ImageDraw, ImageFont | |
import requests | |
import matplotlib.pyplot as plt | |
import ipywidgets as widgets | |
from IPython.display import display, clear_output | |
import torch | |
from torch import nn | |
from torchvision.models import resnet50 | |
import torchvision.transforms as T | |
# Inference only: switch autograd off globally.
torch.set_grad_enabled(False)

# NOTE(review): this swaps the ssl module's default HTTPS context factory for
# the unverified one, so certificates are not checked by anything that relies
# on that default afterwards. Presumably a workaround for the torch.hub
# download below -- confirm it is still required before keeping it.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# Fetch the pretrained 'detr_resnet50' entry point and put it in eval mode.
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(size=800),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# COCO classes, indexed by class id.
# NOTE(review): the 'N/A' entries look like placeholders for unused category
# ids so that list position matches the model's class index -- confirm.
CLASSES = [
    # 0-9
    'N/A', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]
# RGB palette (0-255 per channel) used to outline and label detections.
COLORS = [
    [0, 114, 189],
    [217, 83, 25],
    [237, 177, 32],
    [126, 47, 142],
    [119, 172, 48],
    [77, 190, 238],
]
# for output bounding box post-processing | |
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    Args:
        x: Tensor whose last dimension has size 4 and holds
            ``(cx, cy, w, h)`` for each box. Any number of leading
            dimensions is accepted, e.g. ``(4,)``, ``(N, 4)`` or
            ``(B, N, 4)``.

    Returns:
        Tensor with the same shape as ``x`` holding
        ``(x_min, y_min, x_max, y_max)`` for each box.
    """
    # Operate on the last dimension so unbatched and batched inputs both work.
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    return torch.stack(corners, dim=-1)
def rescale_bboxes(out_bbox, size):
    """Convert (cx, cy, w, h) boxes to corner form and scale them to an image.

    Args:
        out_bbox: Tensor of boxes in ``(cx, cy, w, h)`` form. The caller in
            this file passes the model's ``pred_boxes``; these are expected
            to be relative to the image size (values in [0, 1]).
        size: ``(width, height)`` of the target image.

    Returns:
        Tensor of ``(x_min, y_min, x_max, y_max)`` boxes with x values
        multiplied by the width and y values multiplied by the height.
    """
    img_w, img_h = size
    # Create the scale on the boxes' device so the multiply also works when
    # the boxes do not live on the default device.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h],
        dtype=torch.float32,
        device=out_bbox.device,
    )
    return box_cxcywh_to_xyxy(out_bbox) * scale
def plot_results2(pil_img, prob, boxes):
    """Draw labelled detection boxes on a copy of ``pil_img``.

    Args:
        pil_img: PIL image to annotate; it is copied, not modified.
        prob: Per-detection class scores, one row per box. The arg-max of
            each row is used as an index into ``CLASSES``.
        boxes: Tensor of ``(x_min, y_min, x_max, y_max)`` rows, one per
            detection, in the pixel coordinates of ``pil_img``.

    Returns:
        A new PIL image with one outlined rectangle and a
        ``"<class>: <score>"`` label per detection.
    """
    from itertools import cycle

    out_img = pil_img.copy()
    canvas = ImageDraw.Draw(out_img)
    # Cycle the palette instead of repeating it a fixed number of times, so
    # zip() never truncates the detections when there are many of them.
    for p, (xmin, ymin, xmax, ymax), rgb in zip(prob, boxes.tolist(), cycle(COLORS)):
        color = tuple(rgb)
        canvas.rectangle([(xmin, ymin), (xmax, ymax)], outline=color, width=2)
        # Plain int so it is an unambiguous list index.
        cl = int(p.argmax())
        canvas.text((xmin + 5, ymin + 5), f'{CLASSES[cl]}: {p[cl]:0.2f}', color)
    return out_img
import gradio as gr | |
def process_image(im, threshold=0.9):
    """Run DETR on a PIL image and return a copy annotated with detections.

    Args:
        im: Input PIL image.
        threshold: Minimum top class probability a prediction must exceed
            to be drawn. Defaults to 0.9, the value previously hard-coded.

    Returns:
        A PIL image with the kept detections drawn on it.
    """
    # `transform` normalizes with three-channel statistics, so make sure the
    # model input (and the image we draw on) is RGB.
    rgb = im if im.mode == 'RGB' else im.convert('RGB')
    batch = transform(rgb).unsqueeze(0)

    # propagate through the model
    outputs = model(batch)

    # Softmax over classes, then drop the last column (presumably DETR's
    # "no object" class -- confirm) before thresholding.
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], rgb.size)
    return plot_results2(rgb, probas[keep], bboxes_scaled)
# Text shown on the demo page.
title = "Demo: DETR detection"
description = (
    "Demo for Facebooks's DETR: What it is. "
    "Unlike traditional computer vision techniques, DETR approaches object "
    "detection as a direct set prediction problem. "
    "It consists of a set-based global loss, which forces unique predictions "
    "via bipartite matching, and a Transformer encoder-decoder architecture. "
    "Given a fixed small set of learned object queries, DETR reasons about "
    "the relations of the objects and the global image context to directly "
    "output the final set of predictions in parallel. "
    "Due to this parallel nature, DETR is very fast and efficient."
)
# One example row; the file is expected to sit next to this script.
examples = [['cats.jpg']]

# NOTE(review): gr.inputs / gr.outputs and enable_queue appear to be legacy
# Gradio APIs -- confirm they exist in the Gradio version this app pins.
image_input = gr.inputs.Image(type="pil")
image_output = gr.outputs.Image(type="pil", label="DETR detections")

iface = gr.Interface(
    fn=process_image,
    inputs=image_input,
    outputs=image_output,
    title=title,
    description=description,
    examples=examples,
    enable_queue=True,
)
iface.launch(debug=True)