Spaces:
Runtime error
Runtime error
| # import the necessary packages | |
| from utilities import config | |
| from utilities import load_model | |
| from tensorflow.keras import layers | |
| import tensorflow as tf | |
| import matplotlib.pyplot as plt | |
| import math | |
| import gradio as gr | |
# Restore the three saved sub-models -- stem, trunk and attention-pooling
# head -- from the paths declared in the project config.
conv_stem, conv_trunk, conv_attn = load_model.loader(
    stem=config.IMAGENETTE_STEM_PATH,
    trunk=config.IMAGENETTE_TRUNK_PATH,
    attn=config.IMAGENETTE_ATTN_PATH,
)
# Human-readable class names, indexed by the position of the matching
# entry in the model's prediction vector.
labels = [
    "tench",
    "english springer",
    "cassette player",
    "chain saw",
    "church",
    "french horn",
    "garbage truck",
    "gas pump",
    "golf ball",
    "parachute",
]
def get_results(image):
    """Classify an image and overlay the attention-pooling weights on it.

    Args:
        image: The input picture as delivered by the Gradio image input
            (presumably an H x W x C array -- confirm against the
            installed Gradio version).

    Returns:
        A 2-tuple ``(plt, scores)``: the ``matplotlib.pyplot`` module, whose
        current figure holds the resized input with the attention map drawn
        on top of it, and a dict mapping each entry of ``labels`` to its
        softmax probability as a Python float.
    """
    # convert to float32, resize to a 224, 224 dim and add a leading
    # batch axis
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, (224, 224))
    image = image[tf.newaxis, ...]

    # pass through the stem, then the trunk
    test_x = conv_stem(image)
    test_x = conv_trunk(test_x)

    # pass through the attention pooling block
    logits, test_viz_weights = conv_attn(test_x)
    test_viz_weights = test_viz_weights[tf.newaxis, ...]

    # reshape the visualization weights: the last axis is treated as a
    # flattened square grid of patches
    num_patches = tf.shape(test_viz_weights)[-1]
    height = width = int(math.sqrt(num_patches))
    test_viz_weights = layers.Reshape((height, width))(test_viz_weights)

    index = 0
    selected_image = image[index]
    selected_weight = test_viz_weights[index]

    # pyplot draws into one implicit, module-global figure. Clear it first so
    # this call never paints over whatever an earlier call left behind;
    # the function no longer depends on the caller closing the figure.
    plt.clf()
    img = plt.imshow(selected_image)
    plt.imshow(
        selected_weight,
        cmap="inferno",
        alpha=0.6,
        # stretch the coarse attention grid over the full image area
        extent=img.get_extent(),
    )
    plt.axis("off")

    prediction = tf.nn.softmax(logits, axis=-1)
    prediction = prediction.numpy()[0]

    # pair labels with scores directly instead of indexing a hard-coded range
    scores = {label: float(score) for label, score in zip(labels, prediction)}
    return plt, scores
# HTML footer shown under the demo: a link to the paper followed by the
# contributor credits.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2112.13692' target='_blank'>"
    "Augmenting Convolutional networks with attention-based aggregation"
    "</a></p> "
    "<center>Contributors: "
    "<a href='https://twitter.com/ariG23498'>Aritra Roy Gosthipaty</a>|"
    "<a href='https://twitter.com/ritwik_raha'>Ritwik Raha</a>|"
    "<a href='https://twitter.com/Cr0wley_zz'>Devjyoti Chakraborty</a>"
    "</center>"
)
# Build the demo UI: one image in; the attention overlay and the top-3 class
# probabilities out.
# NOTE(review): the `gr.inputs` / `gr.outputs` namespaces are deprecated in
# newer Gradio releases -- confirm they exist in the pinned gradio version.
iface = gr.Interface(
    fn=get_results,
    title="Patch ConvNet Demo",
    description="This space is a demo of the paper 'Augmenting Convolutional networks with attention-based aggregation' π",
    article=article,
    inputs=gr.inputs.Image(label="Input Image"),
    outputs=[
        gr.outputs.Image(label="Attention Map"),
        gr.outputs.Label(num_top_classes=3, label="Prediction"),
    ],
    examples=[["examples/chainsaw.jpeg"], ["examples/church.jpeg"]],
)

# Launch as a separate statement so `iface` stays bound to the Interface
# object rather than to whatever launch() returns.
iface.launch()