manasch commited on
Commit
ed8c121
1 Parent(s): d3ad6e4

update app.py and add resnet50 model

Browse files
app.py CHANGED
@@ -9,7 +9,8 @@ import keras
9
  from keras import Sequential
10
  from keras.layers import Flatten, Dense
11
 
12
- model_weights_path = (Path.cwd() / "pace_model_weights.h5").resolve()
 
13
  height, width, channels = (224, 224, 3)
14
 
15
  class PaceModel:
@@ -20,29 +21,30 @@ class PaceModel:
20
  self.channels = channels
21
  self.class_names = ["Fast", "Medium", "Slow"]
22
 
23
- self.create_pretrained()
24
  self.create_architecture()
25
 
26
- def create_pretrained(self):
27
- self.pretrained_model = tf.keras.applications.ResNet50(
28
  include_top=False,
29
  input_shape=(self.height, self.width, self.channels),
30
  pooling="avg",
31
  classes=211,
32
  weights="imagenet"
33
  )
 
34
 
35
- for layer in self.pretrained_model.layers:
36
  layer.trainable = False
37
 
38
  def create_architecture(self):
39
- self.resnet_model.add(self.pretrained_model)
40
  self.resnet_model.add(Flatten())
41
  self.resnet_model.add(Dense(1024, activation="relu"))
42
  self.resnet_model.add(Dense(256, activation="relu"))
43
  self.resnet_model.add(Dense(3, activation="softmax"))
44
 
45
- self.resnet_model.load_weights(model_weights_path)
46
 
47
  def predict(self, input_image: np.ndarray):
48
  resized_image = cv2.resize(input_image, (self.height, self.width))
 
9
  from keras import Sequential
10
  from keras.layers import Flatten, Dense
11
 
12
+ pace_model_weights_path = (Path.cwd() / "models" / "pace_model_weights.h5").resolve()
13
+ resnet50_tf_model_weights_path = (Path.cwd() / "models" / "resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5")
14
  height, width, channels = (224, 224, 3)
15
 
16
  class PaceModel:
 
21
  self.channels = channels
22
  self.class_names = ["Fast", "Medium", "Slow"]
23
 
24
+ self.create_base_model()
25
  self.create_architecture()
26
 
27
+ def create_base_model(self):
28
+ self.base_model = tf.keras.applications.ResNet50(
29
  include_top=False,
30
  input_shape=(self.height, self.width, self.channels),
31
  pooling="avg",
32
  classes=211,
33
  weights="imagenet"
34
  )
35
+ self.base_model.load_weights(resnet50_tf_model_weights_path)
36
 
37
+ for layer in self.base_model.layers:
38
  layer.trainable = False
39
 
40
  def create_architecture(self):
41
+ self.resnet_model.add(self.base_model)
42
  self.resnet_model.add(Flatten())
43
  self.resnet_model.add(Dense(1024, activation="relu"))
44
  self.resnet_model.add(Dense(256, activation="relu"))
45
  self.resnet_model.add(Dense(3, activation="softmax"))
46
 
47
+ self.resnet_model.load_weights(pace_model_weights_path)
48
 
49
  def predict(self, input_image: np.ndarray):
50
  resized_image = cv2.resize(input_image, (self.height, self.width))
pace_model_weights.h5 → models/pace_model_weights.h5 RENAMED
File without changes
models/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66c8b43daff3fcc15bc4f30e3d2a167e21a14d9c9598a5394e5516471f4af504
3
+ size 94765736