zihaoz96 commited on
Commit
5ecfb97
1 Parent(s): 17891ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -10
app.py CHANGED
@@ -1,29 +1,58 @@
1
  import gradio as gr
2
  import numpy as np
3
- import pandas as pd
4
  from tensorflow.keras import models
5
 
6
  import tensorflow as tf
7
 
 
 
 
 
 
 
8
  # open categories.txt in read mode
9
  categories = open("categories.txt", "r")
10
  labels = categories.readline().split(";")
11
 
12
- model = models.load_model('models/modelnet/best_model.h5')
 
 
 
 
13
 
 
 
 
 
 
 
 
14
 
15
- def predict_image(image):
16
- image = np.array(image) / 255
17
- image = np.expand_dims(image, axis=0)
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- pred = model.predict(image)
20
 
21
- acc = dict((labels[i], "%.2f" % pred[0][i]) for i in range(len(labels)))
22
- print(acc)
23
- return acc
24
 
 
 
25
 
26
- image = gr.inputs.Image(shape=(224, 224), label="Upload Your Image Here")
27
  label = gr.outputs.Label(num_top_classes=len(labels))
28
 
29
  samples = ['samples/basking.jpg', 'samples/blacktip.jpg', 'samples/blue.jpg', 'samples/bull.jpg', 'samples/hammerhead.jpg',
 
1
  import gradio as gr
2
  import numpy as np
 
3
  from tensorflow.keras import models
4
 
5
  import tensorflow as tf
6
 
7
+ models_name = [
8
+ "VGG16",
9
+ "mobilenet_v2",
10
+ "DenseNet"
11
+ ]
12
+
13
  # open categories.txt in read mode
14
  categories = open("categories.txt", "r")
15
  labels = categories.readline().split(";")
16
 
17
+ # create a radio
18
+ radio = gr.inputs.Radio(models_name, default="DenseNet", type="value")
19
+
20
+ def predict_image(image, model_name):
21
+
22
 
23
+ print("======================")
24
+ print(type(image))
25
+ print(type(model_name))
26
+ print("==========")
27
+ print(image)
28
+ print(model_name)
29
+ print("======================")
30
 
31
+ if model_name == "DenseNet":
32
+ image = np.array(image) / 255
33
+ image = np.expand_dims(image, axis=0)
34
+
35
+ model = "./models/" + model_name + "model.h5"
36
+ pred = model.predict(image)
37
+
38
+ pred = dict((labels[i], "%.2f" % pred[0][i]) for i in range(len(labels)))
39
+ else:
40
+
41
+ image = Image.fromarray(np.uint8(image)).convert('RGB')
42
+ classifier = TorchVisionClassifierInference(
43
+ model_path = "./models/" + model_name
44
+ )
45
 
46
+ pred = classifier.predict_image(img=image, return_str=False)
47
 
48
+ for key in pred.keys():
49
+ pred[key] = pred[key]/100
50
+
51
 
52
+ print(pred)
53
+ return pred
54
 
55
+ image = gr.inputs.Image(shape=(300, 300), label="Upload Your Image Here")
56
  label = gr.outputs.Label(num_top_classes=len(labels))
57
 
58
  samples = ['samples/basking.jpg', 'samples/blacktip.jpg', 'samples/blue.jpg', 'samples/bull.jpg', 'samples/hammerhead.jpg',