# Copyright (C) 2022, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

"""Gradio demo serving a PyroVision image classifier exported to ONNX."""

import argparse
import json

import gradio as gr
import numpy as np
import onnxruntime
from huggingface_hub import hf_hub_download
from PIL import Image

REPO = "pyronear/rexnet1_0x"

# Download model config & checkpoint
with open(hf_hub_download(REPO, filename="config.json"), "rb") as f:
    cfg = json.load(f)

ort_session = onnxruntime.InferenceSession(hf_hub_download(REPO, filename="model.onnx"))
# The input name is fixed for a loaded model: look it up once instead of on every request
_INPUT_NAME = ort_session.get_inputs()[0].name

# `Image.BILINEAR` was deprecated in Pillow 9.1 and removed in Pillow 10, while
# `Image.Resampling` only exists from 9.1 onwards: support both.
_BILINEAR = getattr(Image, "Resampling", Image).BILINEAR


def preprocess_image(pil_img: Image.Image) -> np.ndarray:
    """Preprocess an image for inference

    Args:
        pil_img: a valid pillow image

    Returns:
        the resized and normalized image of shape (1, C, H, W)
    """
    # Grayscale / RGBA / palette images would break the (H, W, C) transpose or the
    # per-channel normalization broadcast below, so force a 3-channel layout.
    # NOTE(review): assumes the model expects RGB input (3 entries in cfg["mean"]/cfg["std"]) — confirm.
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")
    # Resizing: cfg["input_shape"] presumably ends with (H, W) (cf. the (1, C, H, W) output),
    # whereas PIL's resize takes (W, H) — hence the swap.
    height, width = cfg["input_shape"][-2:]
    img = pil_img.resize((width, height), _BILINEAR)
    # (H, W, C) --> (C, H, W)
    img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
    # Normalization
    img -= np.array(cfg["mean"])[:, None, None]
    img /= np.array(cfg["std"])[:, None, None]

    return img[None, ...]


def predict(image: Image.Image) -> dict:
    """Score an image against every class of the model

    Args:
        image: the input image, as provided by the gradio image component (type="pil")

    Returns:
        a mapping from each name in cfg["classes"] to its sigmoid score in [0, 1]
    """
    # Preprocessing
    np_img = preprocess_image(image)
    # Inference: first output tensor, first (and only) batch element.
    # Presumably one logit per entry of cfg["classes"]; note that zip() below silently
    # truncates if the lengths disagree.
    logits = ort_session.run(None, {_INPUT_NAME: np_img})[0][0]
    # Post-processing: independent sigmoid per class (not a softmax), so scores need not sum to 1
    probs = 1 / (1 + np.exp(-logits))

    return {class_name: float(conf) for class_name, conf in zip(cfg["classes"], probs)}


# NOTE(review): `gr.inputs` / `gr.outputs` are deprecated in gradio 3.x and removed in 4.x;
# switch to `gr.Image` / `gr.Label` once the pinned gradio version allows it.
img = gr.inputs.Image(type="pil")
outputs = gr.outputs.Label(num_top_classes=1)

gr.Interface(
    fn=predict,
    inputs=[img],
    outputs=outputs,
    title="PyroVision: image classification demo",
    article=(
        "<p style='text-align: center'><a href='https://github.com/pyronear/pyro-vision'>"
        "Github Repo</a> | "
        "<a href='https://pyronear.org/pyro-vision/'>Documentation</a></p>"
    ),
    live=True,
).launch()