Spaces:
Runtime error
Runtime error
# pylint: disable=no-name-in-module | |
# pylint: disable=no-member | |
""" | |
Author : Bastien GUILLAUME | |
Version : 0.0.1 | |
Date : 2023-03-16 | |
Title : Inference With Gradio running an onnxruntime backend | |
""" | |
from pathlib import Path | |
import numpy as np | |
import onnxruntime as ort | |
import requests | |
from config_parser import * | |
from torchvision import transforms | |
def make_func(task, product, model_number):
    """Build a one-argument inference callback bound to a single configured model.

    The returned closure captures ``task``, ``product`` and ``model_number`` so
    callers only have to supply the image.

    Args:
        task: Key of the task in ``config["tasks"]``.
        product: Key of the product under that task's ``"models"`` mapping.
        model_number: Index of the model in the product's model list.

    Returns:
        A callable ``_analysis(image)`` that preprocesses ``image`` with
        ``pre_process_all``, runs ``inference`` on the result and returns the
        string produced by ``inference``.
    """

    def _analysis(image):
        """Preprocess ``image`` and run inference with the bound model.

        Returns:
            The result string from ``inference`` (label and confidence).
        """
        batch = pre_process_all(
            task=task, product=product, model_number=model_number, image=image
        )
        outcome = inference(task, product, batch, model_number=model_number)
        logging.log(level=logging.DEBUG, msg=outcome)
        return outcome

    return _analysis
def download_models(product, model, model_uuid):
    """Download a model file into the local cache unless it is already there.

    The file is stored as ``models/<product>/<model_uuid>/<last URL segment>``
    relative to the current working directory.

    Args:
        product: Product name, used as a sub-directory.
        model: URL of the model file; its last ``/``-separated segment becomes
            the local file name.
        model_uuid: Identifier used as a sub-directory under ``product``.

    Returns:
        pathlib.Path of the cached model file.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never stored as a model file.
    """
    logging.log(level=logging.DEBUG, msg=model)
    models_folder = Path(f"models/{product}/{model_uuid}")
    models_folder.mkdir(parents=True, exist_ok=True)
    filepath = models_folder / model.split("/")[-1]
    logging.log(level=logging.DEBUG, msg=filepath)
    if filepath.exists():
        # Cache hit: skip the network round-trip entirely.
        return filepath
    # Stream into a temporary sibling and rename only once complete, so an
    # interrupted download never leaves a truncated file that a later run
    # would mistake for a valid cached model.
    partial = filepath.with_name(filepath.name + ".part")
    with requests.get(model, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partial, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    partial.replace(filepath)
    return filepath
# Fixed score statistics passed by ``inference`` to
# ``is_anomalous_classification`` for "anomaly_detection-classification" models.
# Only "image_threshold" is read by the code visible here; "pixel_threshold",
# "min" and "max" are carried along unused.
# NOTE(review): these look like metadata exported with the anomaly model —
# confirm their origin before changing them.
corck_screwing_metadata = dict(
    image_threshold=0.9247307181358337,
    pixel_threshold=0.9247307181358337,
    min=4.034907666260021e-26,
    max=0.998478353023529,
)
# Load every model declared in ``config["tasks"]`` once, at import time.
# Resulting layout:
#   inferencer_arr[task][product][str(model_number)][model_uuid] = {
#       "model": onnxruntime.InferenceSession,
#       "function": callable(image) -> result string (built by make_func),
#       "input_name": name of the session's first input,
#       "output_name": name of the session's first output,
#   }
# model_uuid is the second-to-last "/"-separated segment of the model path.
inferencer_arr = {}
logging.log(level=logging.INFO, msg="Loading models...")
for task, task_config in config["tasks"].items():
    inferencer_arr[task] = {}
    task_models = task_config["models"]
    for product in task_models:
        inferencer_arr[task][product] = {}
        for model_number, model in enumerate(task_models[product]):
            model_path = model["path"]
            model_uuid = model_path.split("/")[-2]
            logging.log(
                level=logging.INFO,
                msg=f"Loading model for product {product}, version {model_number}",
            )
            logging.log(
                level=logging.INFO,
                msg=f"Model UUID {model_uuid}",
            )
            if model_path.startswith("http"):
                model_path = download_models(product, model_path, model_uuid)
            # Local paths are plain str while downloads return a Path; wrap
            # both, since a str has no ``as_posix`` and local models would
            # otherwise crash here.
            session = ort.InferenceSession(Path(model_path).as_posix())
            inferencer_arr[task][product][str(model_number)] = {
                model_uuid: {
                    "model": session,
                    "function": make_func(task, product, model_number),
                    "input_name": session.get_inputs()[0].name,
                    "output_name": session.get_outputs()[0].name,
                }
            }
            logging.log(level=logging.INFO, msg=f"Model for {product} loaded\n")
logging.log(level=logging.INFO, msg="All models loaded...\n\n")
logging.log(level=logging.DEBUG, msg=f"Inferencer Array : {inferencer_arr}")
def softmax(x, axis=0):
    """Return the softmax of ``x`` along ``axis``.

    The maximum is subtracted per slice along ``axis`` (rather than over the
    whole array) before exponentiating. For 1-D input this gives the same
    result as a global shift. For N-D input it prevents slices whose logits
    lie far below the global maximum from underflowing to ``0 / 0 = nan``.

    Args:
        x: Array-like of logits.
        axis: Axis along which the probabilities sum to 1 (default 0).

    Returns:
        numpy.ndarray with the same shape as ``x``.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=axis, keepdims=True)
def is_anomalous_classification(task, product, model_number, prediction, meta_data):
    """Map an anomaly-model output to a ``(label, score)`` pair.

    The score is the largest element of the flattened ``prediction``. When
    ``meta_data`` contains ``"image_threshold"``, a score at or above it maps to
    the model's first configured class name and anything else to the second.
    Without a threshold the label is ``None``.

    Args:
        task: Key of the task in ``config["tasks"]``.
        product: Product key under that task's ``"models"`` mapping.
        model_number: Integer index into the product's model list.
        prediction: numpy array produced by the model.
        meta_data: Mapping that may hold ``"image_threshold"``.

    Returns:
        ``(label, score)``. ``score`` is the raw maximum; it is not rescaled
        with any ``"min"``/``"max"`` entries of ``meta_data``.
    """
    score = prediction.reshape(-1).max()
    for message in (f"Task {task}", f"Product {product}"):
        logging.log(level=logging.INFO, msg=message)
    names = config["tasks"][task]["models"][product][model_number]["class_names"]

    if "image_threshold" not in meta_data:
        label = None
    elif score >= meta_data["image_threshold"]:
        label = names[0]
    else:
        label = names[1]

    logging.log(
        level=logging.INFO,
        msg=f"Predicted label {label} with a confidence of {score}",
    )
    return label, score
def pre_process_all(task, product, model_number, image):
    """Resize, crop and normalise ``image`` for the configured model.

    Reads ``input_shape``, ``mean`` and ``std`` from the model's entry in
    ``config``. It applies ToPILImage -> Resize -> CenterCrop -> ToTensor ->
    Normalize and then adds a leading batch axis of size 1.

    Args:
        task: Key of the task in ``config["tasks"]``.
        product: Product key under that task's ``"models"`` mapping.
        model_number: Index of the model in the product's list; int-like
            strings are accepted as well.
        image: Input image. NOTE(review): presumably anything
            ``transforms.ToPILImage`` accepts (e.g. an HxWxC uint8 numpy
            array) — confirm against the caller.

    Returns:
        numpy.ndarray: the transformed image with an extra leading batch axis.
    """
    logging.log(level=logging.INFO, msg=f"Task {task}")
    logging.log(level=logging.INFO, msg=f"Product {product}")
    logging.log(level=logging.INFO, msg=f"Model Number {model_number}")
    # Coerce once so every config lookup below uses the same integer index.
    index = int(model_number)
    product_models = config["tasks"][task]["models"][product]
    logging.log(level=logging.DEBUG, msg=f"Product Array {product_models}")
    model_cfg = product_models[index]
    logging.log(level=logging.DEBUG, msg=f"Model Array {model_cfg}")

    input_shape = model_cfg["input_shape"]
    mean = model_cfg["mean"]
    std = model_cfg["std"]
    logging.log(level=logging.DEBUG, msg=f"Shape {input_shape}")
    logging.log(level=logging.DEBUG, msg=f"Mean {mean}")
    logging.log(level=logging.DEBUG, msg=f"Std {std}")

    data_transforms = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize(input_shape),
            transforms.CenterCrop(input_shape),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    tensor = data_transforms(image).detach().numpy()
    # Prepend a batch dimension of size 1 for the ONNX session.
    preprocessed_image = np.expand_dims(tensor, axis=0)
    logging.log(level=logging.DEBUG, msg=preprocessed_image)
    return preprocessed_image
def inference(task, product, input_image, model_number):
    """Run the configured ONNX model on a preprocessed image and format the result.

    Args:
        task: Key of the task in ``config["tasks"]``.
        product: Product key under that task's ``"models"`` mapping.
        input_image: Preprocessed batch (as returned by ``pre_process_all``).
        model_number: Index of the model in the product's list (int or
            int-like string).

    Returns:
        ``"<label> avec une confiance de <N> %"`` for the "classification" and
        "anomaly_detection-classification" model types. For any other
        configured ``"type"`` it returns ``"Algorithm not yet supported"``.
    """
    logging.log(level=logging.DEBUG, msg=f"Inferencer {inferencer_arr}")
    logging.log(level=logging.INFO, msg=f"Task {task}")
    logging.log(level=logging.INFO, msg=f"Product {product}")
    logging.log(level=logging.INFO, msg=f"Model {model_number}")
    index = int(model_number)
    model_cfg = config["tasks"][task]["models"][product][index]
    # Same key derivation as the loader: second-to-last path segment.
    model_uuid = model_cfg["path"].split("/")[-2]
    entry = inferencer_arr[task][product][str(index)][model_uuid]

    result = "Algorithm not yet supported"
    outputs = entry["model"].run(
        [entry["output_name"]], {entry["input_name"]: input_image}
    )
    prediction = outputs[0].squeeze()
    model_type = model_cfg["type"]

    if model_type == "classification":
        class_names = model_cfg["class_names"]
        probabilities = softmax(prediction)  # computed once, reused below
        logging.log(level=logging.INFO, msg=f"Softmaxed prediction {probabilities}")
        best = int(np.argmax(prediction))
        # float() makes round() return a Python int ("92 %") regardless of
        # how the installed numpy rounds float32 scalars.
        confidence = round(float(probabilities[best]) * 100)
        result = f"{class_names[best]} avec une confiance de {confidence} %"
    elif model_type == "anomaly_detection-classification":
        label, score = is_anomalous_classification(
            task=task,
            product=product,
            model_number=index,
            prediction=prediction,
            meta_data=corck_screwing_metadata,
        )
        result = f"{label} avec une confiance de {round(float(score) * 100)} %"
    else:
        logging.log(
            level=logging.WARNING, msg=f"Unsupported model type {model_type!r}"
        )
    logging.log(level=logging.DEBUG, msg=result)
    return result