# Provenance: Hugging Face upload by user "osheina" ("Upload 5 files", commit a7b063a, verified).
from sys import platform
import onnxruntime as rt
from einops import rearrange
import numpy as np
if platform in {"win32", "win64"}:
import onnxruntime.tools.add_openvino_win_libs as utils
utils.add_openvino_libs_to_path()
class Predictor:
    """ONNX-Runtime video-clip classifier.

    Loads an ONNX model and a tab-separated class list, and classifies a
    clip of frames, returning the top-k labels above a confidence threshold.
    """

    def __init__(self, model_config):
        """
        Initialize the Predictor.

        Args:
            model_config (dict): Model configuration containing
                path_to_model, path_to_class_list, provider, threshold,
                and topk values.
        """
        self.config = model_config
        self.provider = self.config["provider"]
        self.threshold = self.config["threshold"]
        self.labels = {}
        self.model_init(self.config["path_to_model"])
        self.create_labels()

    def create_labels(self):
        """
        Build the ``{class_index: label}`` mapping from path_to_class_list.

        The file is expected to contain one ``index\\tlabel`` pair per line.
        """
        # Read as UTF-8 explicitly so the result is identical on every
        # platform. (Previously the file was read with the locale default
        # encoding and the Windows cp1251 mojibake was undone afterwards
        # via decode_preds; reading UTF-8 directly makes that unnecessary.)
        with open(self.config["path_to_class_list"], "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) which would
            # otherwise crash int("") below.
            lines = [line.strip() for line in f if line.strip()]
        idx_lbl_pairs = [line.split("\t") for line in lines]
        # Keep only the first two tab-separated fields, as before.
        self.labels = {int(pair[0]): pair[1] for pair in idx_lbl_pairs}

    def softmax(self, x):
        """Return row-wise softmax of ``x`` (numerically stable)."""
        # Subtract the row max before exponentiating to avoid overflow.
        exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=1, keepdims=True)

    def predict(self, x):
        """
        Classify a clip of frames.

        Args:
            x (list): List of input frames, each an H x W x C uint8 array.

        Returns:
            dict | None: ``{"labels": {rank: label}, "confidence":
            {rank: float}}`` for the top-k classes, or ``None`` when the
            best confidence is below the threshold.
        """
        # Normalize to [0, 1] and reorder to the model's NCTHW layout
        # (batch of 1, channels, time, height, width).
        clip = np.array(x).astype(np.float32) / 255.0
        clip = rearrange(clip, "t h w c -> 1 c t h w")
        prediction = self.model([self.output_name], {self.input_name: clip})[0]
        prediction = self.softmax(prediction)
        prediction = np.squeeze(prediction)
        # Indices of the top-k scores, highest first.
        topk_labels = prediction.argsort()[-self.config["topk"] :][::-1]
        topk_confidence = prediction[topk_labels]
        result = [self.labels[lbl_idx] for lbl_idx in topk_labels]
        if np.max(topk_confidence) < self.threshold:
            return None
        return {
            "labels": dict(enumerate(result)),
            # Cast np.float32 -> float so the result is JSON-serializable.
            "confidence": {i: float(c) for i, c in enumerate(topk_confidence)},
        }

    def model_init(self, path_to_model: str) -> None:
        """
        Load and initialize the ONNX model from the provided path.

        Args:
            path_to_model (str): Path to the ONNX model file.

        Returns:
            None
        """
        session = rt.InferenceSession(path_to_model, providers=[self.provider])
        # Keep the session itself (not just the bound run method) so it is
        # inspectable and its lifetime is explicit.
        self.session = session
        self.input_name = session.get_inputs()[0].name
        self.output_name = session.get_outputs()[0].name
        self.model = session.run

    def decode_preds(self, data):
        """
        Undo cp1251 mojibake on Windows; return ``data`` unchanged elsewhere.

        Retained for backward compatibility with external callers;
        create_labels no longer needs it since it reads UTF-8 directly.
        """
        if platform in {"win32", "win64"}:
            data = [i.encode("cp1251").decode("utf-8") for i in data]
        return data