"""Gradio demo: run the pretrained DETR (ResNet-50) detector on an image.

The script loads ``detr_resnet50`` from ``torch.hub``, wraps it in
``process_image`` (PIL image in, annotated PIL image out) and serves that
function through a Gradio ``Interface``.
"""

import itertools
import math
import ssl

import gradio as gr
import ipywidgets as widgets
import matplotlib.pyplot as plt
import requests
import torch
import torchvision.transforms as T
from IPython.display import display, clear_output
from PIL import Image, ImageDraw, ImageFont
from torch import nn
from torchvision.models import resnet50

# Inference only: gradients are never needed in this script.
torch.set_grad_enabled(False)

# NOTE(review): this replaces the default HTTPS context factory with the
# unverified one, i.e. certificate verification is disabled for every default
# HTTPS context created in this process. Presumably a workaround for
# certificate errors during the torch.hub download below -- confirm it is
# still required, since it weakens transport security.
ssl._create_default_https_context = ssl._create_unverified_context

model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# COCO classes
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# colors for visualization
COLORS = [[0, 114, 189], [217, 83, 25], [237, 177, 32],
          [126, 47, 142], [119, 172, 48], [77, 190, 238]]

# Minimum per-box class probability for a detection to be drawn.
CONFIDENCE_THRESHOLD = 0.9


# for output bounding box post-processing
def box_cxcywh_to_xyxy(x: torch.Tensor) -> torch.Tensor:
    """Convert boxes from (center_x, center_y, width, height) to corners.

    Args:
        x: Tensor whose dimension 1 holds exactly four values
            ``(x_c, y_c, w, h)`` per box.

    Returns:
        Tensor with ``(x_min, y_min, x_max, y_max)`` stacked along
        dimension 1.
    """
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox: torch.Tensor, size) -> torch.Tensor:
    """Convert center-format boxes to corner format and scale them to an image.

    Args:
        out_bbox: Boxes in ``(x_c, y_c, w, h)`` format; the caller passes
            the model's ``pred_boxes``, which are expressed in [0; 1].
        size: ``(width, height)`` pair of the target image.

    Returns:
        Boxes as ``(x_min, y_min, x_max, y_max)`` multiplied by the image
        width/height.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b


def plot_results2(pil_img: Image.Image, prob: torch.Tensor,
                  boxes: torch.Tensor) -> Image.Image:
    """Draw labelled bounding boxes on a copy of ``pil_img``.

    Args:
        pil_img: Image to annotate; it is copied, not modified.
        prob: Per-box class probabilities; row ``i`` belongs to box ``i`` and
            is indexed with the positions of ``CLASSES``.
        boxes: Per-box ``(x_min, y_min, x_max, y_max)`` coordinates in the
            pixel space of ``pil_img``.

    Returns:
        The annotated copy of the image.
    """
    out_img = pil_img.copy()
    draw = ImageDraw.Draw(out_img)
    # Cycle through the palette so any number of boxes gets a color.
    palette = itertools.cycle(COLORS)
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), palette):
        color = tuple(c)
        draw.rectangle([(xmin, ymin), (xmax, ymax)], outline=color, width=2)
        # Plain Python scalars for list indexing and string formatting.
        cl = int(p.argmax())
        score = float(p[cl])
        draw.text((xmin + 5, ymin + 5), f'{CLASSES[cl]}: {score:0.2f}', color)
    return out_img


def process_image(im: Image.Image) -> Image.Image:
    """Run DETR on ``im`` and return a copy with confident detections drawn.

    Args:
        im: Input PIL image. Non-RGB images are converted to RGB first
            because the normalization in ``transform`` is defined for three
            channels.

    Returns:
        A copy of the (RGB) image with one rectangle and label per detection
        whose top class probability exceeds ``CONFIDENCE_THRESHOLD``.
    """
    if im.mode != "RGB":
        im = im.convert("RGB")

    img = transform(im).unsqueeze(0)

    # propagate through the model
    outputs = model(img)

    # Class probabilities for the single image in the batch, without the last
    # logit column (presumably DETR's "no object" class -- confirm against
    # the model documentation).
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    # keep only predictions above the confidence threshold
    keep = probas.max(-1).values > CONFIDENCE_THRESHOLD

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

    imout = plot_results2(im, probas[keep], bboxes_scaled)
    return imout


title = "Demo: DETR detection"
description = (
    "Demo for Facebooks's DETR: What it is. Unlike traditional computer "
    "vision techniques, DETR approaches object detection as a direct set "
    "prediction problem. It consists of a set-based global loss, which "
    "forces unique predictions via bipartite matching, and a Transformer "
    "encoder-decoder architecture. Given a fixed small set of learned "
    "object queries, DETR reasons about the relations of the objects and "
    "the global image context to directly output the final set of "
    "predictions in parallel. Due to this parallel nature, DETR is very "
    "fast and efficient."
)
examples = [['cats.jpg']]

# NOTE(review): gr.inputs / gr.outputs and enable_queue look like legacy
# Gradio APIs -- confirm the installed Gradio version still accepts them.
iface = gr.Interface(fn=process_image,
                     inputs=gr.inputs.Image(type="pil"),
                     outputs=gr.outputs.Image(type="pil", label="DETR detections"),
                     title=title,
                     description=description,
                     examples=examples,
                     enable_queue=True)
iface.launch(debug=True)