liewchooichin commited on
Commit
5e033e4
·
verified ·
1 Parent(s): 72c83cc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Jan 28 18:48:07 2024
4
+
5
+ @author: liewchooichin
6
+ """
7
+ import os
8
+ import pathlib
9
+ import gradio as gr
10
+ import pandas as pd
11
+ # my own py to make predictions
12
+ import image_pretrained
13
+
14
+ # global variables
15
+ # predictions from:
16
+ pred_eff = pd.DataFrame() # Efficient Net
17
+ pred_mob = pd.DataFrame() # Mobile Net
18
+ pred_xcept = pd.DataFrame() # Xception
19
+
20
+
21
+ def get_prediction(img_path):
22
+ pred_eff, pred_mob, pred_xcept = \
23
+ image_pretrained.predict(img_path)
24
+ print(pred_eff)
25
+ return pred_eff, pred_mob, pred_xcept
26
+
27
+
28
+ def clear_image(img):
29
+ # Clear the previous output result
30
+ return pred_eff, pred_mob, pred_xcept
31
+
32
+
33
+ with gr.Blocks() as demo:
34
+ image_width = 256
35
+ image_height = 256
36
+
37
+ gr.Markdown(
38
+ """
39
+ # Image classfication
40
+
41
+ Predict the class of the image with pretrained model.
42
+
43
+ Models: Xception, MobileNet V3 Small, \
44
+ EfficientNet V2 Small.
45
+
46
+ Top three predictions of classes are shown for each \
47
+ of the model.
48
+
49
+ Upload an image for predictions of its class and \
50
+ its probabilities.
51
+ """
52
+ )
53
+ with gr.Row():
54
+ with gr.Column():
55
+ img = gr.Image(height=image_height,
56
+ width=image_width,
57
+ sources=["upload", "clipboard"],
58
+ interactive=True,
59
+ type="filepath")
60
+
61
+ # label_1 = gr.Label(label="Efficient net")
62
+ # label_2 = gr.Label(label="Mobile net")
63
+ # label_3 = gr.Label(label="Xception")
64
+ with gr.Column():
65
+ text_1 = gr.Text(label="Efficient net v2")
66
+ text_2 = gr.Text(label="Mobile net v3")
67
+ text_3 = gr.Text(label="Xception")
68
+
69
+ # load the images directory
70
+ data_dir = "images"
71
+ img_path = pathlib.Path(data_dir)
72
+ image_list = [[i] for i in list(img_path.glob("*.jpg"))]
73
+ print(f"List of examples: {image_list}")
74
+ examples = gr.Examples(
75
+ examples=[
76
+ os.path.join(os.path.dirname(__file__), "images",
77
+ "cat.jpg"),
78
+ os.path.join(os.path.dirname(__file__), "images",
79
+ "mrt_train.jpg"),
80
+ os.path.join(os.path.dirname(__file__), "images",
81
+ "duck.jpg"),
82
+ os.path.join(os.path.dirname(__file__), "images",
83
+ "daisy.jpg"),
84
+ os.path.join(os.path.dirname(__file__), "images",
85
+ "apples.jpg"),
86
+ os.path.join(os.path.dirname(__file__), "images",
87
+ "bus.jpg"),
88
+ os.path.join(os.path.dirname(__file__), "images",
89
+ "butterfly.jpg"),
90
+ ],
91
+ inputs=[img],
92
+ outputs=[text_1, text_2, text_3],
93
+ run_on_click=True,
94
+ fn=get_prediction
95
+ )
96
+ # prediction when a file is uploaded
97
+ img.upload(fn=get_prediction, inputs=[img],
98
+ outputs=[text_1, text_2, text_3])
99
+ # when an example is clicked
100
+ img.change(fn=get_prediction, inputs=[img],
101
+ outputs=[text_1, text_2, text_3])
102
+ # when an image is cleared
103
+ img.clear(fn=clear_image, inputs=[img],
104
+ outputs=[text_1, text_2, text_3])
105
+
106
+ if __name__ == "__main__":
107
+ demo.launch()