File size: 2,065 Bytes
3ba4276
 
 
 
 
5f6a9dc
3ba4276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f6a9dc
 
 
3ba4276
 
 
8a68e19
3ba4276
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import tensorflow as tf

import cv2
import keras
import PIL
from keras import Sequential
from keras.applications.resnet50 import ResNet50
from keras.layers import Flatten, Dense

class PaceModel:
    """
    Pace classifier built on a ResNet50 backbone.

    A frozen ResNet50 feature extractor (no ImageNet head, global-average
    pooled) is followed by a small fully-connected head ending in a 3-way
    softmax. Weights for the backbone and for the complete model are loaded
    from local files at construction time.

    Args:
        height: Input image height, in pixels, expected by the network.
        width: Input image width, in pixels, expected by the network.
        channels: Number of input channels expected by the network.
        resnet50_tf_model_weights_path: Path of the weights file loaded into
            the ResNet50 backbone via ``load_weights``.
        pace_model_weights_path: Path of the weights file loaded into the full
            model (backbone + head) after the architecture is assembled.
    """

    def __init__(self, height, width, channels, resnet50_tf_model_weights_path, pace_model_weights_path):
        self.resnet_model = Sequential()
        self.height = height
        self.width = width
        self.channels = channels
        # Index i corresponds to output unit i of the final softmax layer.
        self.class_names = ["Fast", "Medium", "Slow"]
        self.resnet50_tf_model_weights_path = resnet50_tf_model_weights_path
        self.pace_model_weights_path = pace_model_weights_path

        self.create_base_model()
        self.create_architecture()

    def create_base_model(self):
        """Build the ResNet50 backbone, load its local weights and freeze it."""
        self.base_model = ResNet50(
            include_top=False,
            input_shape=(self.height, self.width, self.channels),
            pooling="avg",
            # The local weights file loaded below replaces whatever weights the
            # constructor initialises, so fetching ImageNet weights first is a
            # wasted (and offline-failing) download. `classes` is ignored when
            # include_top=False, so it is not passed.
            weights=None,
        )
        self.base_model.load_weights(self.resnet50_tf_model_weights_path)

        # Freeze the backbone so only the head layers are trainable.
        for layer in self.base_model.layers:
            layer.trainable = False

    def create_architecture(self):
        """Stack the classification head on the backbone and load full-model weights."""
        self.resnet_model.add(self.base_model)
        # pooling="avg" already yields a 2-D (batch, features) tensor, so this
        # Flatten is effectively a no-op; it is kept so the layer stack is unchanged.
        self.resnet_model.add(Flatten())
        self.resnet_model.add(Dense(1024, activation="relu"))
        self.resnet_model.add(Dense(256, activation="relu"))
        self.resnet_model.add(Dense(3, activation="softmax"))

        self.resnet_model.load_weights(self.pace_model_weights_path)

    # String annotation: `import PIL` alone does not guarantee that the
    # `PIL.Image` submodule is loaded when this class body is executed.
    def predict(self, input_image: "PIL.Image.Image"):
        """
        Predict the pace label of a single image.

        Args:
            input_image: Image to classify. It is converted to a NumPy array
                and resized to (height, width) before inference.

        Returns:
            The entry of ``self.class_names`` with the highest predicted score.
        """
        # Images in modes such as "RGBA", "L" or "P" would produce the wrong
        # number of channels for a 3-channel network; normalise them to RGB.
        if self.channels == 3 and getattr(input_image, "mode", "RGB") != "RGB":
            input_image = input_image.convert("RGB")
        np_image = np.array(input_image)
        # cv2.resize expects dsize as (width, height); the resulting array has
        # shape (height, width, channels), matching the backbone's input_shape.
        resized_image = cv2.resize(np_image, (self.width, self.height))
        # NOTE(review): no keras.applications.resnet50.preprocess_input is
        # applied here; confirm this matches how the pace model was trained.
        image = np.expand_dims(resized_image, axis=0)

        prediction = self.resnet_model.predict(image)
        # Batch of one: take the argmax over the class scores of that row.
        return self.class_names[int(np.argmax(prediction[0]))]