File size: 863 Bytes
46cc8e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import keras
import numpy as np
class ModelInference:
    """Inference module to predict an image class.

    A saved Keras model is loaded once at construction time. ``predict`` runs
    the model and turns its first output score into a boolean decision by
    comparing it (inclusively) against ``threshold``.
    """

    def __init__(self, weights_path='./Weights/model.h5', threshold=0.8):
        """Store configuration and load the model from disk.

        Args:
            weights_path: Path to a saved Keras model, passed unchanged to
                ``keras.models.load_model``.
            threshold: Minimum score (inclusive) for a positive prediction.
        """
        self.weights_path = weights_path
        self.threshold = threshold
        self.model = self.load_model()

    def load_model(self):
        """Load and return the Keras model stored at ``self.weights_path``."""
        return keras.models.load_model(self.weights_path)

    def predict(self, image_array: np.ndarray) -> bool:
        """Run the model on ``image_array`` and return the thresholded decision.

        Args:
            image_array: Input passed unchanged to ``self.model.predict``;
                presumably a batch of preprocessed images (TODO confirm the
                expected shape against the training pipeline).

        Returns:
            ``True`` if the model's first score is >= ``self.threshold``.
        """
        model_output = self.model.predict(image_array)
        return self.parse_model_output(model_output)

    def parse_model_output(self, model_output: np.ndarray) -> bool:
        """Convert raw model output into a boolean decision.

        Only ``model_output[0][0]`` is read (the first output value of the
        first sample). Any further samples in a batch are ignored.

        Args:
            model_output: Indexable 2-level output, e.g. the array returned by
                ``Model.predict``.

        Returns:
            A plain Python ``bool``: ``True`` when that score is
            >= ``self.threshold``.

        Raises:
            IndexError: If ``model_output`` (or its first row) is empty.
        """
        confidence_score = model_output[0][0]
        # bool() so callers get a built-in bool rather than numpy.bool_.
        return bool(confidence_score >= self.threshold)