Spaces:
Sleeping
Sleeping
from PIL import Image | |
import requests | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torchvision.transforms as T | |
import os | |
import random | |
import cv2 | |
import DETR.util.misc as utils | |
from DETR.models import build_model | |
from DETR.modules.ExplanationGenerator import Generator | |
import argparse | |
class Namespace:
    """Plain attribute container: each keyword argument becomes an attribute.

    This script uses it to build the ``args`` object handed to ``build_model``.
    """

    def __init__(self, **kwargs):
        # vars(self) is the instance __dict__, so this copies kwargs verbatim.
        vars(self).update(kwargs)
# COCO category names, indexed by the class id the model predicts (91 ids).
# 'N/A' marks ids that have no category name.
CLASSES = [
    # 0-9
    'N/A', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]

# RGB colors (floats in [0, 1]) used, in repeating order, for box outlines.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
# Input preprocessing: resize (int size -> shorter side becomes 800 px),
# convert the image to a tensor, then apply the standard ImageNet mean/std
# normalization.
transform = T.Compose(
    [
        T.Resize(800),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from center/size form to corner form.

    Args:
        x: tensor whose last dimension has size 4 and holds
            ``(center_x, center_y, width, height)``. Any number of leading
            dimensions is accepted, e.g. ``(N, 4)``, ``(B, N, 4)`` or ``(4,)``.

    Returns:
        Tensor of the same shape holding ``(x_min, y_min, x_max, y_max)``.
    """
    # Unbind/stack along the last axis so the conversion works for any batch
    # layout; for 2-D (N, 4) input this is identical to axis 1.
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    return torch.stack(
        [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h], dim=-1
    )
def rescale_bboxes(out_bbox, size):
    """Turn center/size boxes into pixel-space corner boxes.

    Args:
        out_bbox: tensor of ``(center_x, center_y, width, height)`` boxes,
            presumably normalized to ``[0, 1]`` as the model predicts them
            (the caller in this file passes ``pred_boxes`` directly).
        size: ``(width, height)`` of the target image, e.g. ``PIL.Image.size``.

    Returns:
        float32 tensor of ``(x_min, y_min, x_max, y_max)`` boxes scaled to
        ``size``, on the same device as ``out_bbox``.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale factors on the boxes' device so CUDA tensors can be
    # rescaled without a device-mismatch error.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h], dtype=torch.float32, device=b.device
    )
    return b * scale
def plot_results(pil_img, prob, boxes):
    """Display ``pil_img`` with every box outlined and labelled.

    Args:
        pil_img: image to draw on (anything ``plt.imshow`` accepts).
        prob: iterable of per-box class-probability rows; each row's argmax
            selects the label from ``CLASSES`` and its value is shown as score.
        boxes: tensor/array of ``(x_min, y_min, x_max, y_max)`` pixel boxes,
            one row per entry of ``prob`` (converted with ``.tolist()``).

    Box colors repeat the ``COLORS`` palette; since the palette is repeated
    100 times and ``zip`` stops at the shortest input, at most
    ``len(COLORS) * 100`` boxes are drawn. The figure is shown via
    ``plt.show()``.
    """
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    axes = plt.gca()
    palette = COLORS * 100
    for row, corners, color in zip(prob, boxes.tolist(), palette):
        x0, y0, x1, y1 = corners
        outline = plt.Rectangle(
            (x0, y0), x1 - x0, y1 - y0, fill=False, color=color, linewidth=3
        )
        axes.add_patch(outline)
        best = row.argmax()
        label = f'{CLASSES[best]}: {row[best]:0.2f}'
        axes.text(
            x0, y0, label, fontsize=15,
            bbox=dict(facecolor='yellow', alpha=0.5),
        )
    plt.axis('off')
    plt.show()
# Device the model and inputs are placed on; args.device below uses the same value.
device = 'cpu'
# Hard-coded configuration object consumed by build_model. ``resume`` is the
# URL of the checkpoint whose weights are loaded below.
args = Namespace(aux_loss=True, backbone='resnet50', batch_size=2, bbox_loss_coef=5, clip_max_norm=0.1,
                 coco_panoptic_path=None, coco_path=None, dataset_file='coco', dec_layers=6, device='cpu',
                 dice_loss_coef=1, dilation=False, dim_feedforward=2048, dist_url='env://', distributed=False,
                 dropout=0.1, enc_layers=6, eos_coef=0.1, epochs=300, eval=False, frozen_weights=None, giou_loss_coef=2,
                 hidden_dim=256, lr=0.0001, lr_backbone=1e-05, lr_drop=200, mask_loss_coef=1, masks=False, nheads=8,
                 num_queries=100, num_workers=2, output_dir='', position_embedding='sine', pre_norm=False,
                 remove_difficult=False, resume='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth', seed=42,
                 set_cost_bbox=5, set_cost_class=1, set_cost_giou=2, start_epoch=0, weight_decay=0.0001, world_size=1)
model, criterion, postprocessors = build_model(args)
model.to(device)
checkpoint = torch.hub.load_state_dict_from_url(
    args.resume, map_location='cpu', check_hash=True)
# strict=False: state-dict keys that do not match the model are skipped
# without raising.
model.load_state_dict(checkpoint['model'], strict=False)
# Inference only: switch layers such as dropout (dropout=0.1 above) out of
# training mode so predictions are deterministic. eval() does not disable
# autograd, so gradient-based explanation still works.
model.eval()
# Explanation generator wrapping the loaded model; used by evaluate().
gen = Generator(model)
def evaluate(im, device, image_id=None, threshold=0.9):
    """Detect objects in ``im`` and plot an explanation map for each detection.

    The module-level ``model`` is run once on ``im``. Every query whose highest
    class probability exceeds ``threshold`` gets one figure column. The top
    row shows the map returned by ``gen.generate_ours``, min-max normalized
    and reshaped to the backbone feature-map grid. The bottom row shows ``im``
    with the predicted box and class name.

    Args:
        im: input PIL image; ``im.size`` is read as ``(width, height)``.
        device: device the input tensor is moved to (should match the model's).
        image_id: accepted for caller compatibility; not used.
        threshold: minimum class probability for a query to be kept
            (default 0.9).

    Returns:
        The matplotlib figure, which is also saved to ``detr.png``, or ``None``
        when no query passes ``threshold``.
    """
    # Mean-std normalize the input image and add a batch dimension (size 1).
    img = transform(im).unsqueeze(0).to(device)

    # Capture the backbone output during the single detection pass; only its
    # spatial size is needed, to reshape the explanation maps below.
    conv_features = []
    hook = model.backbone[-2].register_forward_hook(
        lambda module, inputs, output: conv_features.append(output)
    )
    try:
        outputs = model(img)
    finally:
        hook.remove()

    # Per-query class probabilities; the last logit column is dropped (it has
    # no CLASSES entry). Keep queries above the confidence threshold.
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = (probas.max(-1).values > threshold).cpu()
    kept_indices = keep.nonzero()
    if kept_indices.shape[0] == 0:
        return None

    # Convert boxes from [0; 1] to image scales (plain floats for plotting).
    bboxes_scaled = rescale_bboxes(
        outputs['pred_boxes'].detach().cpu()[0, keep], im.size
    ).tolist()

    # Spatial size of the backbone feature map.
    h, w = conv_features[0]['0'].tensors.shape[-2:]

    # squeeze=False keeps ``axs`` 2-D even for a single detection, so
    # ``axs.T`` always yields one (map, image) column per detection.
    fig, axs = plt.subplots(
        ncols=len(bboxes_scaled), nrows=2, figsize=(22, 7), squeeze=False
    )
    for idx, (ax_cam, ax_img), (xmin, ymin, xmax, ymax) in zip(
            kept_indices, axs.T, bboxes_scaled):
        cam = gen.generate_ours(img, idx, use_lrp=True)
        # Min-max normalize; a constant map would divide by zero, so it is
        # mapped to all zeros instead.
        cam_range = cam.max() - cam.min()
        if cam_range > 0:
            cam = (cam - cam.min()) / cam_range
        else:
            cam = torch.zeros_like(cam)
        # Assumes the map has h * w elements (one per feature-map cell).
        ax_cam.imshow(cam.view(h, w).detach().cpu().numpy())
        ax_cam.axis('off')

        ax_img.imshow(im)
        ax_img.add_patch(
            plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                          fill=False, color='blue', linewidth=3))
        ax_img.axis('off')
        ax_img.set_title(CLASSES[probas[idx].argmax().item()])

    fig.tight_layout()
    fig.savefig('detr.png')
    return fig