geninhu committed on
Commit
1d6d0bd
1 Parent(s): 3458250

Add application file

Files changed (1): app.py (+88, -0)
app.py ADDED
@@ -0,0 +1,88 @@
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow import keras
+ from matplotlib import pyplot as plt
+
+ import gradio as gr
+ from huggingface_hub import from_pretrained_keras
+
+
+ # Load the pretrained attention-based MIL model from the Hugging Face Hub.
+ model = from_pretrained_keras('geninhu/attention_mil')
+
+ # Constants for inference.
+ IMG_SIZE = 28
+
+ # Plot each image in the bag with its attention weight and return the bag-level prediction.
+ def plot(input_images=None, predictions=None, attention_weights=None):
+     bag_class = np.argmax(predictions)
+     bag_class = ('This set of images does not contain the number 8' if bag_class == 0
+                  else 'This set of images contains the number 8')
+
+     weight_str = (f"Attention weight per image: {attention_weights[0]:.2f}, "
+                   f"{attention_weights[1]:.2f}, {attention_weights[2]:.2f}")
+
+     if input_images is not None:
+         figure = plt.figure(figsize=(8, 8))
+         for j, image in enumerate(input_images):
+             figure.add_subplot(1, len(input_images), j + 1)
+             plt.grid(False)
+             if attention_weights is not None:
+                 plt.title(f"attention={attention_weights[j]:.2f}")
+             plt.imshow(np.squeeze(image))
+         return [bag_class, figure]
+
+     return [bag_class, weight_str]
+
+
+ def preprocess_image(image):
+     # Rescale pixel values to [0, 1] and add a leading batch dimension.
+     image = image / 255.0
+     image = np.expand_dims(image, axis=0)
+     return image
+
+
+ def infer(input_images_1, input_images_2, input_images_3):
+     if (input_images_1 is not None) and (input_images_2 is not None) and (input_images_3 is not None):
+         # Normalize the input images.
+         input_images_1 = preprocess_image(input_images_1)
+         input_images_2 = preprocess_image(input_images_2)
+         input_images_3 = preprocess_image(input_images_3)
+
+         # Bag-level prediction over the three images.
+         prediction = model.predict([input_images_1, input_images_2, input_images_3])
+         prediction = np.squeeze(np.swapaxes(prediction, 1, 0))
+
+         # The "alpha" layer holds the per-image attention weights.
+         intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)
+         intermediate_predictions = intermediate_model.predict([input_images_1, input_images_2, input_images_3])
+         attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
+
+         return plot(
+             [input_images_1, input_images_2, input_images_3],
+             predictions=prediction,
+             attention_weights=attention_weights,
+         )
+
+
+ # Three image inputs, one per instance in the bag.
+ input1 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='First image', show_label=True, visible=True)
+ input2 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='Second image', show_label=True, visible=True)
+ input3 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='Third image', show_label=True, visible=True)
+
+ # The app outputs the bag-level prediction and a plot of the per-image attention weights.
+ output = [gr.Label(), gr.Plot()]
+
+ # Examples, a description, and a title help guide users.
+ title = 'Attention-based MIL image classification'
+ description = 'Upload three images of handwritten digits. The model predicts whether the set contains the number 8.'
+
+ gr_interface = gr.Interface(
+     infer, inputs=[input1, input2, input3], outputs=output, allow_flagging='never',
+     analytics_enabled=False, title=title, description=description, live=True,
+     examples=[['samples/0.png', 'samples/6.png', 'samples/2.png'], ['samples/1.png', 'samples/2.png', 'samples/3.png'],
+               ['samples/4.png', 'samples/8.png', 'samples/7.png'], ['samples/8.png', 'samples/0.png', 'samples/9.png'],
+               ['samples/5.png', 'samples/6.png', 'samples/3.png'], ['samples/7.png', 'samples/8.png', 'samples/9.png']],
+ )
+ gr_interface.launch(enable_queue=True, debug=True, inbrowser=True)
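
For a quick sanity check outside the Gradio UI, the same checkpoint can be queried directly. The snippet below is a minimal sketch, not part of this commit: it assumes the samples/*.png files referenced above, a Pillow install, and that class index 1 means "contains an 8", mirroring the logic in plot().

import numpy as np
from PIL import Image
from huggingface_hub import from_pretrained_keras

model = from_pretrained_keras('geninhu/attention_mil')

def load_digit(path):
    # Grayscale, 28x28, scaled to [0, 1], with a leading batch axis,
    # matching preprocess_image() in app.py.
    img = Image.open(path).convert('L').resize((28, 28))
    arr = np.asarray(img, dtype=np.float32) / 255.0
    return np.expand_dims(arr, axis=0)

bag = [load_digit(f'samples/{i}.png') for i in (0, 6, 2)]  # a hypothetical bag of three digits
pred = np.squeeze(model.predict(bag))
print('Contains an 8' if np.argmax(pred) == 1 else 'No 8 in this bag')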