# FCN / app.py
# Author: akhaliq (HF staff)
# Commit: Create app.py
# Size: 3.07 kB
from PIL import Image
import numpy as np
import torch
from torchvision import transforms, models
from onnx import numpy_helper
import os
import onnxruntime as rt
from matplotlib.colors import hsv_to_rgb
import cv2
import gradio as gr
# Preprocessing applied to every input image before it is fed to the model.
# ToTensor must come first: Normalize operates on tensors, and inference()
# calls tensor methods (.unsqueeze, .detach) on the result.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Start from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution providers
# other than the default CPU provider (as opposed to the previous behavior of providers getting set/registered by default
# based on the build flags) when instantiating InferenceSession.
# For example, if NVIDIA GPU is available and ORT Python package is built with CUDA, then call API as following:
# onnxruntime.InferenceSession(path/to/model, providers=['CUDAExecutionProvider'])
sess = rt.InferenceSession("../model/fcn-resnet101-11.onnx")

# Names of all model outputs (in the order the session reports them) and the
# name of the first model input; both are needed by sess.run().
outputs = sess.get_outputs()
output_names = [output.name for output in outputs]
input_name = sess.get_inputs()[0].name

# One class label per line; the context manager closes the file once read.
with open('voc_classes.txt') as class_file:
    classes = [line.rstrip('\n') for line in class_file]
num_classes = len(classes)
def get_palette():
    """Build a flat colour palette with one RGB triple per class.

    Returns:
        A list of ``num_classes * 3`` ints laid out as
        ``[r0, g0, b0, r1, g1, b1, ...]``. Class 0 (background) is black;
        every other class ``k`` gets the colour with hue ``k / num_classes``
        and saturation/value 0.75.
    """
    palette = [0] * num_classes * 3
    for hue in range(num_classes):
        if hue == 0:  # Background color
            colors = (0, 0, 0)
        else:
            # Without this else-branch the background colour assigned above
            # would be overwritten by the HSV colour for hue 0.
            colors = hsv_to_rgb((hue / num_classes, 0.75, 0.75))
        for i in range(3):
            # Scale each channel to the 0-255 integer range.
            palette[hue * 3 + i] = int(colors[i] * 255)
    return palette
def colorize(labels):
    """Generate a colorized image from output labels and the color palette.

    Args:
        labels: array of per-pixel class indices, as accepted by
            ``Image.fromarray`` (visualize_output passes a uint8 array).

    Returns:
        A NumPy array holding the RGB rendering of the label map.
    """
    result_img = Image.fromarray(labels).convert('P', colors=num_classes)
    # Attach the per-class colours. Without this call the palette built by
    # get_palette() is never used and the class indices are not mapped to
    # the intended colours.
    result_img.putpalette(get_palette())
    return np.array(result_img.convert('RGB'))
def visualize_output(image, output):
    """Turn per-class scores into a label map, a colour mask and an overlay.

    Args:
        image: array indexed as [row, column, channel]; its first two
            dimensions must equal the last two dimensions of ``output``.
        output: array whose first axis has ``num_classes`` entries, followed
            by the height and width axes.

    Returns:
        A tuple ``(confidence, result_img, blended_img, raw_labels)``:
        the mean over all pixels of the highest class score (a float), the
        colour mask as a PIL image, the 50/50 blend of input and mask as a
        PIL image, and the uint8 array of per-pixel class indices.
    """
    # Same height and width, and one score plane per class.
    assert image.shape[0] == output.shape[1]
    assert image.shape[1] == output.shape[2]
    assert output.shape[0] == num_classes

    # Winning class index for every pixel.
    label_map = np.argmax(output, axis=0).astype(np.uint8)

    # Compute the confidence score: average of the per-pixel best scores.
    best_scores = np.max(output, axis=0)
    mean_confidence = float(best_scores.mean())

    # Colour mask for the predicted labels.
    mask_rgb = colorize(label_map)

    # Equal-weight blend of the input and the mask.
    # NOTE(review): the channel axis of `image` is reversed here while the
    # mask is not -- confirm the channel order callers are expected to pass.
    flipped_input = image[:, :, ::-1]
    overlay = cv2.addWeighted(flipped_input, 0.5, mask_rgb, 0.5, 0)

    return (
        mean_confidence,
        Image.fromarray(mask_rgb),
        Image.fromarray(overlay),
        label_map,
    )
def inference(img):
    """Segment one image and return the overlay of input and colour mask.

    Args:
        img: anything ``Image.open`` accepts (a path or a file object).

    Returns:
        The blended PIL image produced by ``visualize_output``.
    """
    # Force three channels: the normalisation in `preprocess` is configured
    # with three means/stds and visualize_output indexes a channel axis, so
    # grayscale, palette or RGBA inputs would otherwise fail. The context
    # manager releases the file handle that Image.open created.
    with Image.open(img) as source:
        input_image = source.convert('RGB')
    orig_tensor = np.asarray(input_image)

    # Add the leading batch axis and hand a NumPy array to ONNX Runtime.
    input_tensor = preprocess(input_image).unsqueeze(0)
    input_tensor = input_tensor.detach().cpu().numpy()

    detections = sess.run(output_names, {input_name: input_tensor})
    # Only the first model output is visualised; any further outputs (the
    # second one was previously bound to an unused `aux`) are ignored.
    output = detections[0]

    # output[0] drops the batch axis added above.
    _, _, blended_img, _ = visualize_output(orig_tensor, output[0])
    return blended_img