audio_palette / lib /pace_model.py
manasch's picture
add sentiment analyser and refactor code
8a68e19 verified
raw
history blame contribute delete
No virus
2.07 kB
import numpy as np
import tensorflow as tf
import cv2
import keras
import PIL
from keras import Sequential
from keras.applications.resnet50 import ResNet50
from keras.layers import Flatten, Dense
class PaceModel:
    """Classify the pace of an image as "Fast", "Medium" or "Slow".

    The network is a frozen ResNet50 feature extractor (no top, global
    average pooling) followed by a small dense head:
    Flatten -> Dense(1024, relu) -> Dense(256, relu) -> Dense(3, softmax).
    Weights for both the base model and the full stacked model are loaded
    from the paths given to the constructor.
    """

    def __init__(self, height, width, channels, resnet50_tf_model_weights_path, pace_model_weights_path):
        """Build the network and load its weights.

        Args:
            height: Input image height, in pixels, expected by the network.
            width: Input image width, in pixels, expected by the network.
            channels: Number of colour channels expected by the network.
            resnet50_tf_model_weights_path: Weights file loaded into the
                ResNet50 base model.
            pace_model_weights_path: Weights file loaded into the complete
                stacked model (base + dense head).
        """
        self.resnet_model = Sequential()
        self.height = height
        self.width = width
        self.channels = channels
        # Index i of the softmax output is reported as class_names[i].
        self.class_names = ["Fast", "Medium", "Slow"]
        self.resnet50_tf_model_weights_path = resnet50_tf_model_weights_path
        self.pace_model_weights_path = pace_model_weights_path
        self.create_base_model()
        self.create_architecture()

    def create_base_model(self):
        """Create the frozen ResNet50 backbone and load its weights."""
        # NOTE(review): weights="imagenet" initialises the backbone with the
        # ImageNet weights (which Keras may need to download) right before
        # load_weights() is called on it; weights=None might be sufficient --
        # confirm that the weights file covers every layer before changing.
        # NOTE(review): `classes` looks unused because include_top=False --
        # confirm against the Keras ResNet50 documentation.
        self.base_model = ResNet50(
            include_top=False,
            input_shape=(self.height, self.width, self.channels),
            pooling="avg",
            classes=211,
            weights="imagenet",
        )
        self.base_model.load_weights(self.resnet50_tf_model_weights_path)

        # Freeze the backbone: none of its layers are trainable.
        for layer in self.base_model.layers:
            layer.trainable = False

    def create_architecture(self):
        """Stack the dense classification head on the backbone and load weights."""
        # The layer sequence is kept exactly as-is so that it keeps matching
        # the layout of the saved weights file loaded below.
        self.resnet_model.add(self.base_model)
        self.resnet_model.add(Flatten())
        self.resnet_model.add(Dense(1024, activation="relu"))
        self.resnet_model.add(Dense(256, activation="relu"))
        # One softmax output per entry of self.class_names.
        self.resnet_model.add(Dense(3, activation="softmax"))
        self.resnet_model.load_weights(self.pace_model_weights_path)

    def predict(self, input_image: PIL.Image.Image) -> str:
        """Return the predicted pace label for a single image.

        Args:
            input_image: Image to classify. It is converted to a NumPy array
                and resized to the (height, width) the network was built for.

        Returns:
            The entry of ``self.class_names`` with the highest predicted
            probability.
        """
        np_image = np.array(input_image)
        # cv2.resize expects dsize as (width, height); the result then has
        # shape (height, width, ...), matching the model's input_shape.
        resized_image = cv2.resize(np_image, (self.width, self.height))
        # Add the leading batch axis: the model predicts on batches.
        image = np.expand_dims(resized_image, axis=0)
        prediction = self.resnet_model.predict(image)
        return self.class_names[np.argmax(prediction)]