# Space status at time of capture: Runtime error
# import the necessary packages | |
from utilities import config | |
from utilities import load_model | |
from tensorflow.keras import layers | |
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
import math | |
import gradio as gr | |
# Load the stem, trunk and attention-pooling sub-models from disk; the
# checkpoint locations are taken from utilities.config.
conv_stem, conv_trunk, conv_attn = load_model.loader(
    stem=config.IMAGENETTE_STEM_PATH,
    trunk=config.IMAGENETTE_TRUNK_PATH,
    attn=config.IMAGENETTE_ATTN_PATH,
)
# Display names for the classifier's outputs: entry i names the class whose
# probability is reported at index i of the model's prediction.
labels = [
    "tench", "english springer", "cassette player", "chain saw", "church",
    "french horn", "garbage truck", "gas pump", "golf ball", "parachute",
]
def _attention_grid(viz_weights):
    """Reshape attention-pooling weights into a square ``(side, side)`` grid.

    A leading axis is added first, then ``layers.Reshape`` folds every axis
    after the first into ``(side, side)``, where ``side`` is the integer
    square root of the trailing (patch) dimension.

    Raises:
        ValueError: if the number of patches is not a perfect square, so the
            weights cannot be laid out as a square map.
    """
    viz_weights = viz_weights[tf.newaxis, ...]
    num_patches = int(tf.shape(viz_weights)[-1])
    # isqrt is exact on integers; float sqrt followed by int() can be off by
    # one for very large values.
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError(
            f"cannot arrange {num_patches} attention weights into a square grid"
        )
    return layers.Reshape((side, side))(viz_weights)


def _draw_overlay(base_image, attention_map):
    """Draw *attention_map* (inferno colormap, alpha 0.6) over *base_image*.

    Drawing happens on pyplot's current figure, which is cleared first.
    """
    # pyplot state is process-global and this runs once per request. Without
    # clearing, every call adds two more images to the same figure, so memory
    # use and render time keep growing for as long as the process runs.
    plt.clf()
    base = plt.imshow(base_image)
    plt.imshow(
        attention_map,
        cmap="inferno",
        alpha=0.6,
        extent=base.get_extent(),  # stretch the coarse map over the image
    )
    plt.axis("off")


def get_results(image):
    """Classify *image* and overlay the model's attention map on it.

    The image is converted to float32, resized to 224x224, and passed
    through the stem, trunk and attention-pooling models. The attention
    weights returned by the pooling block are reshaped into a square grid
    and drawn over the resized image with matplotlib.

    Args:
        image: the picture supplied by the Gradio ``Image`` input
            (presumably an ``(H, W, 3)`` uint8 array -- TODO confirm).

    Returns:
        A ``(plt, scores)`` pair. ``plt`` is the ``matplotlib.pyplot`` module
        whose current figure holds the overlay. ``scores`` maps each entry of
        ``labels`` to its softmax probability.

    Raises:
        ValueError: if the attention weights do not form a square grid.
    """
    # For integer input, convert_image_dtype also rescales values to [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, (224, 224))
    batch = image[tf.newaxis, ...]  # batch of one

    features = conv_trunk(conv_stem(batch))
    logits, viz_weights = conv_attn(features)

    # The batch holds a single image, so take element 0 of both tensors.
    _draw_overlay(batch[0], _attention_grid(viz_weights)[0])

    probabilities = tf.nn.softmax(logits, axis=-1).numpy()[0]
    # Pair names with scores positionally instead of a hard-coded range(10),
    # so the mapping follows the label list.
    return plt, {label: float(p) for label, p in zip(labels, probabilities)}
# HTML footer for the demo page: a link to the paper plus contributor links.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2112.13692' target='_blank'>"
    "Augmenting Convolutional networks with attention-based aggregation"
    "</a></p> "
    "<center>Contributors: "
    "<a href='https://twitter.com/ariG23498'>Aritra Roy Gosthipaty</a>|"
    "<a href='https://twitter.com/ritwik_raha'>Ritwik Raha</a>|"
    "<a href='https://twitter.com/Cr0wley_zz'>Devjyoti Chakraborty</a>"
    "</center>"
)
# Build the Gradio UI: one image in; an attention-map image and the top-3
# class probabilities out (both produced by get_results).
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (deprecated in 3.x, removed in 4.x). If the Space's runtime
# error comes from a newer Gradio, these need porting -- confirm against the
# pinned gradio version.
# NOTE(review): the trailing character in `description` looks like a
# mis-decoded emoji; confirm the intended text.
iface = gr.Interface(
    fn=get_results,
    title="Patch ConvNet Demo",
    description="This space is a demo of the paper 'Augmenting Convolutional networks with attention-based aggregation' π",
    article=article,
    inputs=gr.inputs.Image(label="Input Image"),
    outputs=[
        gr.outputs.Image(label="Attention Map"),
        gr.outputs.Label(num_top_classes=3, label="Prediction"),
    ],
    examples=[["examples/chainsaw.jpeg"], ["examples/church.jpeg"]],
)
# Launch separately so `iface` stays bound to the Interface object rather
# than to launch()'s return value.
iface.launch()