# geninhu - "Update app.py" (commit 612f6fc)
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
import gradio as gr
from huggingface_hub import from_pretrained_keras
# Pretrained attention-based MIL classifier, fetched from the Hugging Face Hub.
model = from_pretrained_keras("keras-io/attention_mil")

# Side length in pixels of the square grayscale inputs (matches the 28x28
# gr.Image components below).
# NOTE(review): not referenced elsewhere in this file's visible code.
IMG_SIZE = 28
def plot(input_images=None, predictions=None, attention_weights=None):
    """Turn a bag prediction into user-facing output.

    Args:
        input_images: sequence of images to display side by side (each is
            squeezed before ``imshow``), or None for a text-only summary.
        predictions: class scores for the whole bag; an argmax of 0 means
            "does not contain number 8", any other argmax means "contains".
        attention_weights: one weight per image (a scalar is treated as a
            single weight), or None.

    Returns:
        ``[bag_class, figure]`` when ``input_images`` is given, otherwise
        ``[bag_class, prob_str]`` where ``prob_str`` lists every attention
        weight to two decimals.
    """
    bag_class = np.argmax(predictions)
    bag_class = 'This set of image does not contain number 8' if bag_class == 0 else 'This set of image contains number 8'

    # Normalize once so both branches can index/iterate, even for a single
    # (0-d) weight; keep None meaning "no weights available".
    weights = None if attention_weights is None else np.atleast_1d(attention_weights)

    if input_images is None:
        # Built only on this path and from however many weights exist; the
        # previous version formatted exactly three weights up front and so
        # crashed whenever attention_weights was None or shorter than three.
        if weights is None:
            return [bag_class, 'Each image probability: unavailable']
        weights_str = ', '.join(f'{weight:.2f}' for weight in weights)
        return [bag_class, f'Each image probability: {weights_str}']

    figure = plt.figure(figsize=(8, 8))
    for j, image in enumerate(input_images):
        ax = figure.add_subplot(1, len(input_images), j + 1)
        ax.grid(False)
        if weights is not None:
            ax.set_title(f"prob={weights[j]:.2f}")
        ax.imshow(np.squeeze(image))
    # Drop the figure from pyplot's global registry; otherwise every request
    # leaves an open figure behind in this long-running app. The Figure
    # object itself stays valid for the caller to render.
    plt.close(figure)
    return [bag_class, figure]
def preprocess_image(image):
    """Prepare one image for the model.

    Divides pixel values by 255 (so 8-bit 0-255 intensities land in [0, 1])
    and prepends a leading axis of length 1, i.e. a batch of one image.
    """
    normalized = image / 255.0
    return np.expand_dims(normalized, 0)
# Sub-model exposing the "alpha" layer output of `model`; built on first use
# and reused, instead of constructing a new keras.Model on every request.
_attention_model = None


def _get_attention_model():
    """Return the cached model mapping `model`'s inputs to its "alpha" layer output."""
    global _attention_model
    if _attention_model is None:
        _attention_model = keras.Model(model.input, model.get_layer("alpha").output)
    return _attention_model


def infer(input_images_1, input_images_2, input_images_3):
    """Classify a bag of three images and report per-image attention weights.

    Args:
        input_images_1, input_images_2, input_images_3: images from the three
            gr.Image inputs (numpy arrays), or None for an empty slot.

    Returns:
        A two-item list for the (Label, Plot) outputs: the result of ``plot``
        when all three images are present, otherwise a prompt message and
        None for the plot.
    """
    images = (input_images_1, input_images_2, input_images_3)
    if any(image is None for image in images):
        # One value per output component; a bare None (the previous
        # fall-through result) cannot be mapped onto two outputs.
        return ['Please provide all three images.', None]

    # Normalize input data; each entry gains a leading batch axis.
    batch = [preprocess_image(image) for image in images]

    # Bag-level class scores.
    prediction = model.predict(batch)
    prediction = np.squeeze(np.swapaxes(prediction, 1, 0))

    # Per-image weights from the "alpha" layer.
    intermediate_predictions = _get_attention_model().predict(batch)
    attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))

    return plot(
        batch,
        predictions=prediction,
        attention_weights=attention_weights,
    )
# Three grayscale ('L' mode) 28x28 image inputs, passed to `infer` as numpy
# arrays in this order.
input1, input2, input3 = [
    gr.Image(shape=(28, 28), type='numpy', image_mode='L', label=label, show_label=True, visible=True)
    for label in ('First image', 'Second image', 'Third image')
]

# Two outputs: a text label with the bag-level verdict and a plot of the
# input images.
output = [gr.Label(), gr.Plot()]
# Text shown around the interface: page title, short description, and an
# HTML credits footer.
title = 'Bag of Image Classification'
description = 'This is the demo for Keras Implementation of Classification using Attention-based Deep Multiple Instance Learning (MIL). The model will try to predict whether number 8 is within the set of input images. As it was trained on MNIST dataset, please use MNIST image for precise result.'
article = "Author: <a href=\"https://huggingface.co/geninhu\">Nhu Hoang</a>. Based on the following Keras example <a href=\"https://keras.io/examples/vision/attention_mil_classification\"> Classification using Attention-based Deep Multiple Instance Learning (MIL)</a> by <a href=\"https://www.linkedin.com/in/mohamadjaber1\">Mohamad Jaber.</a> <br> Check out the model <a href=\"https://huggingface.co/keras-io/attention_mil\">here</a>"

# Wire `infer` to the three image inputs and the (label, plot) outputs.
# Each example row is one bag of three sample images under samples/.
gr_interface = gr.Interface(
    infer,
    inputs=[input1, input2, input3],
    outputs=output,
    allow_flagging='never',
    analytics_enabled=False,
    title=title,
    description=description,
    article=article,
    examples=[
        ['samples/0.png', 'samples/6.png', 'samples/2.png'],
        ['samples/1.png', 'samples/2.png', 'samples/3.png'],
        ['samples/4.png', 'samples/8.png', 'samples/7.png'],
        ['samples/8.png', 'samples/0.png', 'samples/9.png'],
        ['samples/5.png', 'samples/6.png', 'samples/3.png'],
        ['samples/7.png', 'samples/8.png', 'samples/9.png'],
    ],
)
gr_interface.launch(enable_queue=True, debug=False)