File size: 4,071 Bytes
1d6d0bd
 
 
 
 
 
 
 
 
 
 
 
0a4335a
1d6d0bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4748cb
 
1d6d0bd
 
 
c56dc17
1d6d0bd
 
 
 
 
c56dc17
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt

import numpy as np
import tensorflow as tf
import gradio as gr
from huggingface_hub import from_pretrained_keras


# Hugging Face Hub repository holding the pretrained attention-MIL Keras model.
MODEL_REPO_ID = 'keras-io/attention_mil'
# Loaded once at import time and shared by every inference request below.
model = from_pretrained_keras(MODEL_REPO_ID)

# Side length (pixels) of the square grayscale inputs; MNIST digits are 28x28.
IMG_SIZE = 28

# resize the image and it to a float between 0,1
def plot(input_images=None, predictions=None, attention_weights=None):
    bag_class = np.argmax(predictions)
    bag_class = 'This set of image does not contain number 8' if bag_class == 0 else 'This set of image contains number 8'

    # attention_weights = [round(i, 2) for i in attention_weights]
    prob_str = f"Each image probability: {attention_weights[0]:.2f}, {attention_weights[1]:.2f}, {attention_weights[2]:.2f}"

    if input_images is not None:
        figure = plt.figure(figsize=(8, 8))
        for j in range(len(input_images)):
            image = input_images[j]
            figure.add_subplot(1, len(input_images), j + 1)
            plt.grid(False)
            if attention_weights is not None:
                plt.title(f"prob={attention_weights[j]:.2f}")
            plt.imshow(np.squeeze(input_images[j]))
        return [bag_class, plt.gcf()]

    return [bag_class, prob_str]


def preprocess_image(image):
    """Scale raw pixel values into [0, 1] and prepend a batch axis.

    Args:
        image: Array of pixel values; presumably uint8 in 0..255 as delivered
            by ``gr.Image(type='numpy')`` (TODO confirm).

    Returns:
        ``image / 255.0`` with a new length-1 axis inserted at position 0.
    """
    normalized = image / 255.0
    batched = np.expand_dims(normalized, axis=0)
    return batched

# Cached sub-model exposing the "alpha" (attention) layer; built on first use.
_attention_model = None


def _get_attention_model():
    """Return a model mapping ``model``'s inputs to its "alpha" layer output.

    Built once and reused, so each request does not construct a fresh
    ``keras.Model`` (and recompile its predict function).
    """
    global _attention_model
    if _attention_model is None:
        _attention_model = keras.Model(model.input, model.get_layer("alpha").output)
    return _attention_model


def infer(input_images_1, input_images_2, input_images_3):
    """Classify a bag of three images and visualise per-image attention.

    Args:
        input_images_1, input_images_2, input_images_3: Image arrays from the
            three Gradio inputs, or None when an input was left empty.

    Returns:
        The ``[label, figure]`` pair produced by ``plot``. If any image is
        missing, ``["Please upload all three images.", None]`` so each of the
        two output components still receives a value.
    """
    raw_images = (input_images_1, input_images_2, input_images_3)
    if any(image is None for image in raw_images):
        # The interface declares two outputs; a bare None would not supply a
        # value for each of them.
        return ["Please upload all three images.", None]

    # Normalize input data: scale to [0, 1] and add a batch axis per image.
    bag = [preprocess_image(image) for image in raw_images]

    # Bag-level prediction; swap the first two axes, then drop singleton dims.
    prediction = np.squeeze(np.swapaxes(model.predict(bag), 1, 0))
    # Per-image attention weights from the cached "alpha" sub-model.
    intermediate_predictions = _get_attention_model().predict(bag)
    attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))

    return plot(bag, predictions=prediction, attention_weights=attention_weights)

# get the inputs
input1 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='First image', show_label=True, visible=True)
input2 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='Second image', show_label=True, visible=True)
input3 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='Third image', show_label=True, visible=True)
# the app outputs two segmented images
output = [gr.Label(), gr.Plot()]
# output = [gr.Plot()]
# it's good practice to pass examples, description and a title to guide users
title = 'Bag of Image Classification'
description = 'This is the demo for Keras Implementation of Classification using Attention-based Deep Multiple Instance Learning (MIL). The model will try to predict whether number 8 is within the set of input images. As it was trained on MNIST dataset, please use MNIST image for precise result.'

gr_interface = gr.Interface(
    infer, inputs=[input1, input2, input3], outputs=output, allow_flagging='never',
    analytics_enabled=False, title=title, description=description,
    # examples = [[f'{i}.png' for i in range(0,3)], [f'{i}.png' for i in range(3,6)], [f'{i}.png' for i in range(6,9)], '9.png']
    examples = [['samples/0.png', 'samples/6.png', 'samples/2.png'], ['samples/1.png','samples/2.png', 'samples/3.png'],
                ['samples/4.png', 'samples/8.png', 'samples/7.png'], ['samples/8.png', 'samples/0.png', 'samples/9.png'],
                ['samples/5.png', 'samples/6.png', 'samples/3.png'], ['samples/7.png', 'samples/8.png', 'samples/9.png']]
)
gr_interface.launch(enable_queue=True, debug=False)