import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import DetrImageProcessor
from transformers import DetrForObjectDetection
import matplotlib.pyplot as plt
import io
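
# Load the image processor and the pretrained DETR weights once at startup so that
# every request reuses the same model instance.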
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
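
# Color palette cycled over when drawing bounding boxes.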
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
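
# Draw the post-processed detections (boxes, class labels, confidence scores) on top
# of the input image and return the resulting matplotlib figure.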
def get_output_figure(pil_img, scores, labels, boxes):
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for score, label, (xmin, ymin, xmax, ymax), c in zip(scores.tolist(), labels.tolist(), boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    return plt.gcf()
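
# Visualize where the decoder attends for each detected object: re-run the model with
# output_attentions=True, hook the CNN backbone to recover the feature-map size, and
# plot each query's cross-attention map above its detection. Assumes at least one
# detection above the 0.9 threshold.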
def get_output_attn_figure(image, encoding, results, outputs):
    # keep only predictions of queries with > 0.9 confidence (excluding the no-object class)
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9
    bboxes_scaled = results['boxes']

    # use a list to store the hook output via an up-value
    conv_features = []
    hooks = [
        model.model.backbone.conv_encoder.register_forward_hook(
            lambda self, input, output: conv_features.append(output)
        )
    ]

    # propagate through the model again, this time requesting the attention weights
    with torch.no_grad():
        outputs = model(**encoding, output_attentions=True)

    for hook in hooks:
        hook.remove()

    # don't need the list anymore
    conv_features = conv_features[0]

    # cross-attention weights of the last decoder layer,
    # of shape (batch_size, num_heads, num_queries, width * height)
    dec_attn_weights = outputs.cross_attentions[-1]
    # average them over the 8 heads and detach from the graph
    dec_attn_weights = torch.mean(dec_attn_weights, dim=1).detach()

    # get the shape of the feature map of the last backbone stage
    h, w = conv_features[-1][0].shape[-2:]
    # squeeze=False keeps axs two-dimensional even when a single object is detected
    fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=2, figsize=(22, 7), squeeze=False)
    colors = COLORS * 100
    for idx, ax_i, box in zip(keep.nonzero(), axs.T, bboxes_scaled):
        xmin, ymin, xmax, ymax = box.detach().numpy()
        ax = ax_i[0]
        ax.imshow(dec_attn_weights[0, idx].view(h, w))
        ax.axis('off')
        ax.set_title(f'query id: {idx.item()}')
        ax = ax_i[1]
        ax.imshow(image)
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False,
                                   color='blue', linewidth=3))
        ax.axis('off')
        ax.set_title(model.config.id2label[probas[idx].argmax().item()])
    fig.tight_layout()
    return fig
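
# Gradio handler: run DETR on the uploaded image, keep detections with confidence
# above 0.9, and render both figures to PIL images via in-memory PNG buffers.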
def detect(image):
    encoding = processor(image, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**encoding)

    width, height = image.size
    postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
    results = postprocessed_outputs[0]

    output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])

    buf = io.BytesIO()
    output_figure.savefig(buf, bbox_inches='tight')
    buf.seek(0)
    output_pil_img = Image.open(buf)

    output_figure_attn = get_output_attn_figure(image, encoding, results, outputs)

    buf = io.BytesIO()
    output_figure_attn.savefig(buf, bbox_inches='tight')
    buf.seek(0)
    output_pil_img_attn = Image.open(buf)

    # close the figures so repeated requests don't leak matplotlib state
    plt.close('all')

    return output_pil_img, output_pil_img_attn
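
# Minimal UI: a Markdown header plus a gr.Interface embedded in the Blocks layout.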
with gr.Blocks() as demo:
    gr.Markdown("# Object detection with DETR")
    gr.Markdown(
        """
        This application uses DETR (DEtection TRansformer) to detect objects in images.
        Load an image to see the predictions for the detected objects along with the
        decoder's attention weights.
        """
    )
    gr.Interface(
        fn=detect,
        inputs=gr.Image(label="Input image", type="pil"),
        outputs=[
            gr.Image(label="Output prediction", type="pil"),
            gr.Image(label="Attention weights", type="pil"),
        ],
    )

demo.launch(show_error=True)
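
# Run locally with `python app.py`; Gradio serves the demo on http://localhost:7860
# by default.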