import gradio as gr
import cv2
import numpy as np
import onnxruntime as ort
from PIL import Image

# Load the ONNX model
model = ort.InferenceSession("Model_IV.onnx")
def predict(image):
    # Keep the original frame shape so the output can be mapped back to it
    original_image_shape = image.shape
    print("Original image shape:", original_image_shape)

    # Preprocess the image: get the name and shape of the model's input
    input_name = model.get_inputs()[0].name
    input_shape = model.get_inputs()[0].shape

    # Resize the frame to the model's input size (cv2.resize expects (width, height);
    # an NCHW input shape stores height at index 2 and width at index 3)
    image = cv2.resize(image, (input_shape[3], input_shape[2]))

    # Reorder from HWC to CHW to match the model's input layout
    # (a plain reshape would scramble the pixel data)
    image = image.transpose(2, 0, 1)

    # Normalize the input image using ImageNet-style normalization
    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    image = (image / 255.0 - mean) / std

    # Add a batch dimension and cast to the dtype the model expects
    if len(input_shape) == 4 and input_shape[0] == 1:
        image = np.expand_dims(image, axis=0)
    image = image.astype(np.float32)

    # Run inference
    print("Input image shape:", image.shape)
    output = model.run(None, {input_name: image})
    print("Output shape:", output[0].shape)

    # Postprocess the output (this assumes the model returns an image-like tensor
    # with the same number of elements as the original frame)
    annotated_img = output[0]
    print("Min value before normalization:", np.min(annotated_img))
    print("Max value before normalization:", np.max(annotated_img))

    # Rescale the output to [0, 1] using min-max normalization
    min_val = np.min(annotated_img)
    max_val = np.max(annotated_img)
    annotated_img = (annotated_img - min_val) / (max_val - min_val)
    print("Min value after normalization:", np.min(annotated_img))
    print("Max value after normalization:", np.max(annotated_img))

    # Reshape the output to the original frame shape
    print("annotated_img shape before reshape:", annotated_img.shape)
    annotated_img = annotated_img.reshape(original_image_shape)
    print("annotated_img shape after reshape:", annotated_img.shape)

    # Convert to a PIL Image. fromarray expects uint8 data, so scale the
    # normalized values back to [0, 255] first (this line previously raised
    # a ValueError when given the raw float output)
    annotated_img = Image.fromarray((annotated_img * 255).astype(np.uint8))
    print("PIL Image type:", type(annotated_img))
    return annotated_img
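
# --- Hedged sketch: decoding detection-style output --------------------------
# The helper below is illustrative only and is not wired into the app. If
# "Model_IV.onnx" is actually an Ultralytics YOLO detector export, its raw
# output is a set of detection rows (assumed here to be laid out as
# (1, 4 + num_classes, num_boxes) with xywh box centers in input-image pixels),
# not an image, so it would need to be decoded and drawn onto the frame rather
# than min-max normalized and reshaped as predict() does above. The function
# name, layout assumption, and thresholds are all assumptions, not part of the
# original app.
def draw_detections_sketch(frame, output, input_size=640, conf_thres=0.25, iou_thres=0.45):
    preds = np.squeeze(output[0]).T               # assumed (num_boxes, 4 + num_classes)
    boxes_xywh = preds[:, :4]                     # cx, cy, w, h in input-image pixels
    scores = preds[:, 4:].max(axis=1)             # best class score per box
    class_ids = preds[:, 4:].argmax(axis=1)
    h, w = frame.shape[:2]
    sx, sy = w / input_size, h / input_size       # scale back to the original frame
    boxes = []
    for cx, cy, bw, bh in boxes_xywh:
        boxes.append([int((cx - bw / 2) * sx), int((cy - bh / 2) * sy),
                      int(bw * sx), int(bh * sy)])
    keep = cv2.dnn.NMSBoxes(boxes, scores.tolist(), conf_thres, iou_thres)
    for i in np.array(keep).flatten():
        x, y, bw, bh = boxes[i]
        cv2.rectangle(frame, (x, y), (x + bw, y + bh), (0, 255, 0), 2)
        cv2.putText(frame, f"class {class_ids[i]}: {scores[i]:.2f}", (x, max(y - 5, 0)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
    return frame
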
# Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(sources=["webcam"], type="numpy"),  # webcam frame as a numpy array
    outputs=gr.Image(type="pil"),                       # model output as a PIL Image
)
# Launch interface
if __name__ == "__main__":
    demo.launch()