# pylint: disable=no-name-in-module
# pylint: disable=no-member
"""
Author : Bastien GUILLAUME
Version : 0.0.1
Date : 2023-03-16
Title : Inference with Gradio running an ONNX Runtime backend
"""
import logging
from pathlib import Path

import numpy as np
import onnxruntime as ort
import requests
from torchvision import transforms

from config_parser import *  # provides the parsed `config` dictionary


def make_func(task, product, model_number):
    """Build a Gradio-compatible callback bound to one (task, product, model)."""

    def _analysis(image):
        """
        Run pre-processing and inference on an image and return the result string.
        Args:
            - image: input image (numpy array)
        Returns:
            - String including the label and the confidence of the model
        """
        input_image = pre_process_all(
            task=task, product=product, model_number=model_number, image=image
        )
        result = inference(task, product, input_image, model_number=model_number)
        logging.log(level=logging.DEBUG, msg=result)
        return result

    return _analysis
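
# Illustrative sketch (not part of the app wiring): `make_func` returns a
# closure that Gradio can call with just an image; task, product, and model
# index are captured at build time. The "task1"/"bottle" keys below are
# assumed config entries, not values confirmed by this file:
#
#   analyse_bottle = make_func("task1", "bottle", 0)
#   result = analyse_bottle(image)  # -> "<label> with a confidence of NN %"
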
def download_models(product, model, model_uuid):
    """Download an ONNX model into models/<product>/<model_uuid>/ and return its path."""
    logging.log(level=logging.DEBUG, msg=model)
    models_folder = Path(f"models/{product}/{model_uuid}")
    models_folder.mkdir(parents=True, exist_ok=True)
    filepath = models_folder / model.split("/")[-1]
    logging.log(level=logging.DEBUG, msg=filepath)
    # Only hit the network when the model is not already cached on disk
    if not filepath.exists():
        response = requests.get(model, stream=True).content
        with open(filepath, "xb") as f:
            f.write(response)
    return filepath
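
# Example usage (the URL and UUID below are hypothetical, shown only to
# illustrate the on-disk layout): the file lands in models/<product>/<uuid>/
# and the cached copy is reused on later calls.
#
#   path = download_models(
#       "bottle",
#       "https://example.com/artifacts/abc123/model.onnx",
#       "abc123",
#   )  # -> Path("models/bottle/abc123/model.onnx")
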
# Metadata for the cork-screwing anomaly-detection model: decision thresholds
# and normalization bounds exported by its training run.
cork_screwing_metadata = {
    "image_threshold": 0.9247307181358337,
    "pixel_threshold": 0.9247307181358337,
    "min": 4.034907666260021e-26,
    "max": 0.998478353023529,
}

inferencer_arr = {}
logging.log(level=logging.INFO, msg="Loading models...")
for task in config["tasks"].keys():
    inferencer_arr[task] = {}
    for product in config["tasks"][task]["models"]:
        inferencer_arr[task][product] = {}
        for model_number in range(len(config["tasks"][task]["models"][product])):
            model = config["tasks"][task]["models"][product][model_number]
            model_path = model["path"]
            model_uuid = model_path.split("/")[-2]
            logging.log(
                level=logging.INFO,
                msg=f"Loading model for product {product}, version {model_number}",
            )
            logging.log(level=logging.INFO, msg=f"Model UUID {model_uuid}")
            # Remote models are downloaded (and cached) before being loaded
            if model_path.startswith("http"):
                model_path = download_models(product, model_path, model_uuid)
            entry = {
                "model": ort.InferenceSession(str(model_path)),
                "function": make_func(task, product, model_number),
            }
            entry["input_name"] = entry["model"].get_inputs()[0].name
            entry["output_name"] = entry["model"].get_outputs()[0].name
            inferencer_arr[task][product][str(model_number)] = {model_uuid: entry}
            logging.log(level=logging.INFO, msg=f"Model for {product} loaded\n")
logging.log(level=logging.INFO, msg="All models loaded...\n\n")
logging.log(level=logging.DEBUG, msg=f"Inferencer Array : {inferencer_arr}")

def softmax(x):
    """Numerically stable softmax over a 1-D logits vector."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)
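
# Quick sanity check: softmax(np.array([1.0, 2.0, 3.0])) is roughly
# [0.090, 0.245, 0.665], and the outputs always sum to 1. Subtracting
# np.max(x) first leaves the result unchanged while avoiding overflow.
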
def is_anomalous_classification(task, product, model_number, prediction, meta_data):
    """Threshold an anomaly score against the model metadata and return (label, score)."""
    pred_label = None
    pred_score = prediction.reshape(-1).max()
    logging.log(level=logging.INFO, msg=f"Task {task}")
    logging.log(level=logging.INFO, msg=f"Product {product}")
    class_names = config["tasks"][task]["models"][product][model_number]["class_names"]
    if "image_threshold" in meta_data:
        # Scores at or above the image threshold map to the first class name
        pred_label = (
            class_names[0]
            if pred_score >= meta_data["image_threshold"]
            else class_names[1]
        )
    logging.log(
        level=logging.INFO,
        msg=f"Predicted label {pred_label} with a confidence of {pred_score}",
    )
    return pred_label, pred_score
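
# Worked example of the decision rule with cork_screwing_metadata above,
# assuming class_names order ["anomalous", "normal"] in the config (the
# actual order is defined there, not here): a score of 0.95 >= 0.9247...
# selects class_names[0], while a score of 0.50 selects class_names[1].
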
def pre_process_all(task, product, model_number, image):
    """Resize, crop, and normalize an image into an NCHW float batch for ONNX Runtime."""
    logging.log(level=logging.INFO, msg=f"Task {task}")
    logging.log(level=logging.INFO, msg=f"Product {product}")
    logging.log(level=logging.INFO, msg=f"Model Number {model_number}")
    model_config = config["tasks"][task]["models"][product][int(model_number)]
    logging.log(
        level=logging.DEBUG,
        msg=f'Product Array {config["tasks"][task]["models"][product]}',
    )
    logging.log(level=logging.DEBUG, msg=f"Model Array {model_config}")
    input_shape = model_config["input_shape"]
    mean = model_config["mean"]
    std = model_config["std"]
    logging.log(level=logging.DEBUG, msg=f"Shape {input_shape}")
    logging.log(level=logging.DEBUG, msg=f"Mean {mean}")
    logging.log(level=logging.DEBUG, msg=f"Std {std}")
    data_transforms = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize(input_shape),
            transforms.CenterCrop(input_shape),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    # Add the batch dimension expected by the ONNX model (1, C, H, W)
    preprocessed_image = data_transforms(image).detach().numpy()
    preprocessed_image = np.expand_dims(preprocessed_image, axis=0)
    logging.log(level=logging.DEBUG, msg=preprocessed_image)
    return preprocessed_image
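
# Shape check (illustrative; "task1"/"bottle" are assumed config keys): with
# input_shape = 224 and an HxWx3 uint8 image, the returned array is float32
# with shape (1, 3, 224, 224), the NCHW layout the session expects.
#
#   dummy = np.zeros((480, 640, 3), dtype=np.uint8)
#   batch = pre_process_all("task1", "bottle", 0, dummy)
#   assert batch.shape == (1, 3, 224, 224)
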
def inference(task, product, input_image, model_number):
    """
    Run ONNX Runtime inference for one model and format the result.
    Args:
        - task
        - product
        - input_image: pre-processed image batch
        - model_number: index of the model to use
    Returns:
        - String including the label and the confidence of the model
    """
    logging.log(level=logging.DEBUG, msg=f"Inferencer {inferencer_arr}")
    logging.log(level=logging.INFO, msg=f"Task {task}")
    logging.log(level=logging.INFO, msg=f"Product {product}")
    logging.log(level=logging.INFO, msg=f"Model {model_number}")
    model_config = config["tasks"][task]["models"][product][int(model_number)]
    model_uuid = model_config["path"].split("/")[-2]
    entry = inferencer_arr[task][product][str(model_number)][model_uuid]
    result = "Algorithm not yet supported"
    prediction = entry["model"].run(
        [entry["output_name"]],
        {entry["input_name"]: input_image},
    )
    prediction = prediction[0].squeeze()
    model_type = model_config["type"]
    class_names = model_config["class_names"]
    if model_type == "classification":
        logging.log(
            level=logging.INFO, msg=f"Softmaxed prediction {softmax(prediction)}"
        )
        confidence = round(softmax(prediction)[np.argmax(prediction)] * 100)
        result = f"{class_names[np.argmax(prediction)]} with a confidence of {confidence} %"
    elif model_type == "anomaly_detection-classification":
        label, score = is_anomalous_classification(
            task=task,
            product=product,
            model_number=int(model_number),
            prediction=prediction,
            meta_data=cork_screwing_metadata,
        )
        result = f"{label} with a confidence of {round(score * 100)} %"
    logging.log(level=logging.DEBUG, msg=result)
    return result
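
# End-to-end sketch (hypothetical config keys "task1"/"bottle" and a local
# sample.jpg): the Gradio UI normally calls the closures stored in
# inferencer_arr rather than inference() directly.
#
#   from PIL import Image
#   img = np.asarray(Image.open("sample.jpg"))
#   uuid = next(iter(inferencer_arr["task1"]["bottle"]["0"]))
#   print(inferencer_arr["task1"]["bottle"]["0"][uuid]["function"](img))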