# import the necessary packages
from utilities import config
from utilities import load_model
from tensorflow.keras import layers
import tensorflow as tf
import matplotlib.pyplot as plt
import math
import gradio as gr

# load the models from disk
(conv_stem, conv_trunk, conv_attn) = load_model.loader(
    stem=config.IMAGENETTE_STEM_PATH,
    trunk=config.IMAGENETTE_TRUNK_PATH,
    attn=config.IMAGENETTE_ATTN_PATH,
)

# load labels
labels = [
    'tench',
    'english springer',
    'cassette player',
    'chain saw',
    'church',
    'french horn',
    'garbage truck',
    'gas pump',
    'golf ball',
    'parachute',
]


def get_results(image):
    """Classify an image and overlay the attention-pooling weights on it.

    The image is converted to float32, resized to 224x224 and given a
    leading batch axis, then passed through the stem, the trunk and the
    attention pooling block. The weights returned by the attention block
    are reshaped into a square map and drawn over the resized input.

    Args:
        image: the input image as handed over by the Gradio image input.
            NOTE(review): presumably an H x W x C array -- confirm against
            the installed Gradio version.

    Returns:
        A tuple ``(plt, scores)``: the ``matplotlib.pyplot`` module holding
        the rendered overlay as its current figure, and a dict mapping each
        entry of ``labels`` to its softmax score as a Python float.
    """
    # start from an empty current figure so that nothing drawn by an
    # earlier call can show through underneath this overlay
    plt.clf()

    # convert to float32 and resize the image to a 224, 224 dim, then add
    # the batch axis the models are called with
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, (224, 224))
    image = image[tf.newaxis, ...]

    # pass through the stem
    test_x = conv_stem(image)

    # pass through the trunk
    test_x = conv_trunk(test_x)

    # pass through the attention pooling block
    logits, test_viz_weights = conv_attn(test_x)
    test_viz_weights = test_viz_weights[tf.newaxis, ...]

    # reshape the visualization weights into a square (height, width) map;
    # the last axis is taken as the patch count
    # NOTE(review): assumes the patch count is a perfect square -- confirm
    num_patches = tf.shape(test_viz_weights)[-1]
    height = width = int(math.sqrt(num_patches))
    test_viz_weights = layers.Reshape((height, width))(test_viz_weights)

    # only one image is ever in the batch, so pick the first entry
    index = 0
    selected_image = image[index]
    selected_weight = test_viz_weights[index]

    # draw the input first, then stretch the (smaller) weight map over the
    # same extent so the two line up
    img = plt.imshow(selected_image)
    plt.imshow(
        selected_weight,
        cmap="inferno",
        alpha=0.6,
        extent=img.get_extent(),
    )
    plt.axis("off")

    prediction = tf.nn.softmax(logits, axis=-1)
    prediction = prediction.numpy()[0]

    # follow the label list instead of a hard-coded class count; indexing
    # still fails loudly if the model emits fewer scores than labels
    scores = {
        label: float(prediction[i]) for i, label in enumerate(labels)
    }

    # NOTE(review): the pyplot module itself is returned; presumably the
    # legacy Gradio image output renders its current figure -- confirm
    # for the installed Gradio version before changing this
    return plt, scores


# text shown below the demo; triple-quoted because it spans several lines
# NOTE(review): this looks like it may once have carried markup (links) --
# confirm against the published demo
article = """

Augmenting Convolutional networks with attention-based aggregation

Contributors: Aritra Roy Gosthipaty|Ritwik Raha|Devjyoti Chakraborty
"""

# build the interface first so that `iface` names the Interface object
# rather than whatever launch() returns
iface = gr.Interface(
    fn=get_results,
    title="Patch ConvNet Demo",
    description="This space is a demo of the paper 'Augmenting Convolutional networks with attention-based aggregation' 👀",
    article=article,
    inputs=gr.inputs.Image(label="Input Image"),
    outputs=[
        gr.outputs.Image(label="Attention Map"),
        gr.outputs.Label(num_top_classes=3, label="Prediction"),
    ],
    examples=[["examples/chainsaw.jpeg"], ["examples/church.jpeg"]],
)
iface.launch()