manasch commited on
Commit
fb37ffb
·
verified ·
1 Parent(s): 40a4417

add basic app for pace prediction

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ import gradio as gr
6
+
7
+ import cv2
8
+ import keras
9
+ from keras import Sequential
10
+ from keras.layers import Flatten, Dense
11
+
12
+ model_weights_path = (Path.cwd() / "pace_model_weights.h5").resolve()
13
+ height, width, channels = (224, 224, 3)
14
+
15
+ class PaceModel:
16
+ def __init__(self, height, width, channels):
17
+ self.resnet_model = Sequential()
18
+ self.height = height
19
+ self.width = width
20
+ self.channels = channels
21
+ self.class_names = ["Fast", "Medium", "Slow"]
22
+
23
+ self.create_pretrained()
24
+ self.create_architecture()
25
+
26
+ def create_pretrained(self):
27
+ self.pretrained_model = tf.keras.applications.ResNet50(
28
+ include_top=False,
29
+ input_shape=(self.height, self.width, self.channels),
30
+ pooling="avg",
31
+ classes=211,
32
+ weights="imagenet"
33
+ )
34
+
35
+ for layer in self.pretrained_model.layers:
36
+ layer.trainable = False
37
+
38
+ def create_architecture(self):
39
+ self.resnet_model.add(self.pretrained_model)
40
+ self.resnet_model.add(Flatten())
41
+ self.resnet_model.add(Dense(1024, activation="relu"))
42
+ self.resnet_model.add(Dense(256, activation="relu"))
43
+ self.resnet_model.add(Dense(3, activation="softmax"))
44
+
45
+ self.resnet_model.load_weights(model_weights_path)
46
+
47
+ def predict(self, input_image: np.ndarray):
48
+ resized_image = cv2.resize(input_image, (self.height, self.width))
49
+ image = np.expand_dims(resized_image, axis=0)
50
+
51
+ prediction = self.resnet_model.predict(image)
52
+ print(prediction, np.argmax(prediction))
53
+ return self.class_names[np.argmax(prediction)]
54
+
55
+ def main():
56
+ model = PaceModel(height, width, channels)
57
+
58
+ demo = gr.Interface(
59
+ fn=model.predict,
60
+ inputs=gr.Image(
61
+ type="numpy",
62
+ label="Upload an image",
63
+ show_label=True,
64
+ container=True
65
+ ),
66
+ outputs=gr.Textbox(
67
+ lines=1,
68
+ placeholder="Fast | Medium | Slow",
69
+ label="Pace of the image",
70
+ show_label=True,
71
+ container=True,
72
+ type="text"
73
+ ),
74
+ cache_examples=False,
75
+ live=False,
76
+ title="Predict Pace",
77
+ description="Provide an image to determine the pace of the image",
78
+ )
79
+
80
+ demo.queue().launch()
81
+
82
+ if __name__ == "__main__":
83
+ main()