Spaces:
Runtime error
Runtime error
| from PIL import Image | |
| import requests | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms as T | |
| import os | |
| import random | |
| import cv2 | |
| import DETR.util.misc as utils | |
| from DETR.models import build_model | |
| from DETR.modules.ExplanationGenerator import Generator | |
| import argparse | |
class Namespace:
    """Lightweight attribute container.

    Every keyword argument given to the constructor becomes an instance
    attribute of the same name, so the object can be read with ``obj.name``
    the way an argparse result would be.
    """

    def __init__(self, **kwargs):
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
# COCO class names, indexed by the model's predicted class index.
# 'N/A' marks indices that have no category name, keeping list position
# aligned with the class index.
CLASSES = [
    # 0-9
    'N/A', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]

# Palette of RGB colors (floats in [0, 1]) used for drawing detection boxes.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
# Model-input preprocessing (standard PyTorch mean-std normalization):
#   1. resize so the shorter image side becomes 800 px,
#   2. convert the PIL image to a tensor,
#   3. normalize each channel with the mean/std pairs below.
transform = T.Compose(
    [
        T.Resize(800),
        T.ToTensor(),
        T.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from ``(cx, cy, w, h)`` to ``(x_min, y_min, x_max, y_max)``.

    Accepts any tensor whose last dimension has size 4, e.g. ``(N, 4)`` or
    ``(batch, N, 4)``; the output has the same shape. Coordinates keep
    whatever units the input uses (normalized or pixels).
    """
    # Unbind/stack along the last axis so leading batch dimensions are allowed;
    # for 2-D input this is identical to the previous dim=1 behaviour.
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    return torch.stack(
        [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h], dim=-1
    )
def rescale_bboxes(out_bbox, size):
    """Convert ``(cx, cy, w, h)`` boxes in [0, 1] to pixel ``(x_min, y_min, x_max, y_max)``.

    Args:
        out_bbox: tensor of boxes with ``(cx, cy, w, h)`` in its last dimension,
            expressed as fractions of the image size.
        size: ``(width, height)`` of the target image (the order of ``PIL.Image.size``).

    Returns:
        Tensor of the same shape as ``out_bbox`` holding corner coordinates in pixels.
    """
    img_w, img_h = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    # Create the scale on the boxes' own device; a CPU-only tensor here would
    # raise a device-mismatch error for GPU inputs.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h], dtype=torch.float32, device=boxes.device
    )
    return boxes * scale
def plot_results(pil_img, prob, boxes):
    """Show ``pil_img`` with a labelled rectangle for every detection.

    Args:
        pil_img: image drawn with ``plt.imshow``.
        prob: iterable of per-detection class-score rows, one per box; the
            argmax of each row selects the label from ``CLASSES``.
        boxes: array/tensor (anything with ``.tolist()``) holding one
            ``(x_min, y_min, x_max, y_max)`` row per detection, in the pixel
            coordinates of ``pil_img``.
    """
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    for i, (p, (xmin, ymin, xmax, ymax)) in enumerate(zip(prob, boxes.tolist())):
        # Cycle the palette by index: a fixed-length repeated list would make
        # zip() silently drop detections once the colors ran out.
        color = COLORS[i % len(COLORS)]
        ax.add_patch(
            plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                          fill=False, color=color, linewidth=3)
        )
        cl = int(p.argmax())
        ax.text(xmin, ymin, f'{CLASSES[cl]}: {float(p[cl]):0.2f}',
                fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()
# Device the model is moved to.
device = 'cpu'

# Configuration object passed to build_model; `resume` also supplies the
# checkpoint URL loaded below.
args = Namespace(
    aux_loss=True, backbone='resnet50', batch_size=2, bbox_loss_coef=5, clip_max_norm=0.1,
    coco_panoptic_path=None, coco_path=None, dataset_file='coco', dec_layers=6, device='cpu',
    dice_loss_coef=1, dilation=False, dim_feedforward=2048, dist_url='env://', distributed=False,
    dropout=0.1, enc_layers=6, eos_coef=0.1, epochs=300, eval=False, frozen_weights=None,
    giou_loss_coef=2, hidden_dim=256, lr=0.0001, lr_backbone=1e-05, lr_drop=200,
    mask_loss_coef=1, masks=False, nheads=8, num_queries=100, num_workers=2, output_dir='',
    position_embedding='sine', pre_norm=False, remove_difficult=False,
    resume='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth', seed=42,
    set_cost_bbox=5, set_cost_class=1, set_cost_giou=2, start_epoch=0, weight_decay=0.0001,
    world_size=1,
)

model, criterion, postprocessors = build_model(args)
model.to(device)

# check_hash=True verifies the download against the hash prefix in the file name.
checkpoint = torch.hub.load_state_dict_from_url(
    args.resume, map_location='cpu', check_hash=True
)
# NOTE(review): strict=False silently ignores missing/unexpected keys; presumably
# deliberate for this explainability variant of the model -- confirm.
model.load_state_dict(checkpoint['model'], strict=False)
# torch modules start in training mode and nothing above switches that off, so
# dropout (args.dropout=0.1) would stay active during inference and make the
# detections non-deterministic. eval() does not disable autograd, so gradient-based
# relevance computation still works.
model.eval()

gen = Generator(model)
def _forward_with_feature_shape(img):
    """Run ``model`` once on ``img`` and return ``(outputs, (h, w))``.

    ``(h, w)`` is the spatial size of the output of ``model.backbone[-2]``,
    captured with a temporary forward hook. The hook is removed even if the
    forward pass raises, so hooks never accumulate on the shared model.
    """
    captured = []
    hook = model.backbone[-2].register_forward_hook(
        lambda module, inputs, output: captured.append(output)
    )
    try:
        outputs = model(img)
    finally:
        hook.remove()
    # Presumably a dict of NestedTensor objects keyed by level name; only '0' is read.
    h, w = captured[0]['0'].tensors.shape[-2:]
    return outputs, (h, w)


def _normalize_cam(cam):
    """Min-max scale ``cam`` to [0, 1]; a constant map becomes zeros instead of NaN."""
    cam_min = cam.min()
    span = cam.max() - cam_min
    if span > 0:
        return (cam - cam_min) / span
    return torch.zeros_like(cam)


def evaluate(im, device, image_id=None, *, threshold=0.9):
    """Detect objects in ``im`` and plot a relevance map for each confident detection.

    Args:
        im: PIL image to analyse; ``im.size`` is used as ``(width, height)``.
        device: device the preprocessed image is moved to for the forward pass.
        image_id: accepted for interface compatibility; currently unused.
        threshold: a query is kept when its highest class probability exceeds this.

    Returns:
        A matplotlib figure with one column per kept detection: the normalized
        relevance map from ``gen.generate_ours`` on top, and the image with the
        detection box and class name below. The figure is also saved to
        ``detr.png``. Returns ``None`` when no detection passes ``threshold``.
    """
    # Mean-std normalized batch of one image.
    img = transform(im).unsqueeze(0).to(device)
    # One forward pass yields both the predictions and the feature-map size
    # needed to reshape each relevance map.
    outputs, (h, w) = _forward_with_feature_shape(img)

    # Drop the last logit column (presumably DETR's "no object" class).
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = (probas.max(-1).values > threshold).cpu()
    query_ids = keep.nonzero()
    if query_ids.shape[0] == 0:
        return None

    # Boxes come out in [0, 1]; scale them to image pixels. They are only drawn,
    # so detach them from the autograd graph.
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size).detach()

    # squeeze=False keeps `axs` 2-D even for a single detection, so `axs.T`
    # always yields one (top, bottom) pair of axes per detection.
    fig, axs = plt.subplots(
        ncols=len(bboxes_scaled), nrows=2, figsize=(22, 7), squeeze=False
    )
    for idx, (cam_ax, box_ax), (xmin, ymin, xmax, ymax) in zip(
        query_ids, axs.T, bboxes_scaled.tolist()
    ):
        # `idx` is a 1-element index tensor; it is passed to the generator as-is.
        cam = _normalize_cam(gen.generate_ours(img, idx, use_lrp=True))
        cam_ax.imshow(cam.view(h, w).detach().cpu().numpy())
        cam_ax.axis('off')

        box_ax.imshow(im)
        box_ax.add_patch(
            plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                          fill=False, color='blue', linewidth=3)
        )
        box_ax.axis('off')
        box_ax.set_title(CLASSES[probas[idx].argmax()])

    fig.tight_layout()
    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig('detr.png')
    return fig