from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
import os
import random
import cv2
import DETR.util.misc as utils
from DETR.models import build_model
from DETR.modules.ExplanationGenerator import Generator
import argparse


class Namespace:
    """Minimal attribute container used in place of an argparse namespace for the model config."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


# COCO classes, indexed by the model's class id ('N/A' entries are placeholders)
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# colors for visualization (RGB in [0, 1]); cycled by plot_results
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to corner form (x0, y0, x1, y1).

    Args:
        x: tensor whose dim 1 has size 4 holding (cx, cy, w, h) per row.

    Returns:
        Tensor of the same shape holding (x0, y0, x1, y1) per row.
    """
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    """Scale boxes from relative [0, 1] (cx, cy, w, h) to absolute pixel corners.

    Args:
        out_bbox: (N, 4) tensor of relative (cx, cy, w, h) boxes.
        size: (width, height) of the target image, e.g. ``PIL.Image.size``.

    Returns:
        (N, 4) tensor of (x0, y0, x1, y1) pixel coordinates.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # build the scale on the boxes' own device so GPU inputs also work
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32,
                         device=out_bbox.device)
    return b * scale


def plot_results(pil_img, prob, boxes):
    """Show ``pil_img`` with each box outlined and labelled with its top class.

    Args:
        pil_img: image to draw on.
        prob: per-box class-probability rows; the argmax gives the label.
        boxes: (N, 4) tensor of pixel (x0, y0, x1, y1) boxes.
    """
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()


device = 'cpu'

args = Namespace(aux_loss=True, backbone='resnet50', batch_size=2, bbox_loss_coef=5,
                 clip_max_norm=0.1, coco_panoptic_path=None, coco_path=None,
                 dataset_file='coco', dec_layers=6, device='cpu', dice_loss_coef=1,
                 dilation=False, dim_feedforward=2048, dist_url='env://',
                 distributed=False, dropout=0.1, enc_layers=6, eos_coef=0.1,
                 epochs=300, eval=False, frozen_weights=None, giou_loss_coef=2,
                 hidden_dim=256, lr=0.0001, lr_backbone=1e-05, lr_drop=200,
                 mask_loss_coef=1, masks=False, nheads=8, num_queries=100,
                 num_workers=2, output_dir='', position_embedding='sine',
                 pre_norm=False, remove_difficult=False,
                 resume='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth',
                 seed=42, set_cost_bbox=5, set_cost_class=1, set_cost_giou=2,
                 start_epoch=0, weight_decay=0.0001, world_size=1)

model, criterion, postprocessors = build_model(args)
model.to(device)
checkpoint = torch.hub.load_state_dict_from_url(
    args.resume, map_location='cpu', check_hash=True)
# NOTE(review): strict=False silently ignores missing/unexpected keys in the
# checkpoint; confirm that is intended for this model definition.
model.load_state_dict(checkpoint['model'], strict=False)
# Inference only: disable dropout (args.dropout=0.1) so detections and the
# relevancy maps are deterministic. Gradients remain enabled for the Generator.
model.eval()

gen = Generator(model)


def evaluate(im, device, image_id=None, threshold=0.9):
    """Detect objects in ``im`` and plot a relevancy map for each confident query.

    Runs one forward pass of the module-level ``model`` and keeps the queries
    whose best class probability (excluding the trailing "no object" column)
    exceeds ``threshold``. For each kept query it draws the map returned by
    ``gen.generate_ours`` (top row) and the image with the predicted box and
    class name (bottom row). The figure is also saved to ``detr.png``.

    Args:
        im: PIL image to explain.
        device: device the input tensor is moved to. It must match the device
            the module-level ``model`` was moved to.
        image_id: unused; kept so existing callers keep working.
        threshold: minimum class probability for a query to be plotted
            (default 0.9, the previously hard-coded value).

    Returns:
        The matplotlib figure, or ``None`` when at most one query passes the
        threshold (the column-wise two-row grid below needs >= 2 columns).
    """
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0).to(device)

    # Capture the backbone output during the single forward pass; only its
    # spatial size is needed to reshape the relevancy maps. The hook is removed
    # even if the forward pass raises.
    conv_features = []
    hook = model.backbone[-2].register_forward_hook(
        lambda module, inputs, output: conv_features.append(output)
    )
    try:
        outputs = model(img)
    finally:
        hook.remove()

    # drop the last ("no object") column, then keep confident queries
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = (probas.max(-1).values > threshold).cpu()
    # With a single column, plt.subplots returns 1-D axes that the loop below
    # cannot index as [row], so (as before) skip images with <= 1 detection.
    if keep.nonzero().shape[0] <= 1:
        return None

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size)

    # feature-map shape the flat relevancy maps are reshaped to
    h, w = conv_features[0]['0'].tensors.shape[-2:]

    fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=2, figsize=(22, 7))
    for idx, ax_i, (xmin, ymin, xmax, ymax) in zip(
            keep.nonzero(), axs.T, bboxes_scaled.detach().tolist()):
        # top row: min-max normalised relevancy map for this query
        ax = ax_i[0]
        cam = gen.generate_ours(img, idx, use_lrp=True)
        cam = cam - cam.min()
        cam_max = cam.max()
        if cam_max > 0:  # a constant map would otherwise become 0/0 = NaN
            cam = cam / cam_max
        ax.imshow(cam.view(h, w).detach().cpu().numpy())
        ax.axis('off')

        # bottom row: the image with this query's predicted box and class
        ax = ax_i[1]
        ax.imshow(im)
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color='blue', linewidth=3))
        ax.axis('off')
        ax.set_title(CLASSES[probas[idx.item()].argmax().item()])

    fig.tight_layout()
    fig.savefig('detr.png')
    return fig