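# DUC / app.py
# Page header captured from the Hugging Face file viewer (not part of the program):
#   akhaliq's picture | akhaliq | HF staff | commit "Create app.py" (3d887bb)
#   raw | history | blame | 4.7 kB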
import mxnet as mx
import cv2 as cv
import numpy as np
import os
from PIL import Image
import math
from collections import namedtuple
from mxnet.contrib.onnx import import_model
import cityscapes_labels
import gradio as gr
def preprocess(im):
    """Convert a height x width x channels image array into a network input batch.

    The image is cast to float32, padded on its bottom and right edges so both
    sides become multiples of 8 (the DUC layer needs this to reshape its output
    accurately), moved to channel-first order, mean-subtracted on the first
    three channels and returned as an MXNet NDArray with a leading batch axis.

    NOTE(review): the padding colour and the subtracted mean both come from the
    module-level ``rgb_mean``, not from ``im`` itself.
    """
    height, width = im.shape[0], im.shape[1]

    # Round each side up to the next multiple of 8 and pad only by the excess.
    padded_height, padded_width = (math.ceil(side / 8) * 8 for side in (height, width))
    pad_bottom = max(0, int(padded_height) - height)
    pad_right = max(0, int(padded_width) - width)
    padded = cv.copyMakeBorder(
        im.astype(np.float32),
        0, pad_bottom, 0, pad_right,
        cv.BORDER_CONSTANT, value=rgb_mean,
    )

    # Height/width/channel -> channel/height/width.
    channel_first = np.transpose(padded, (2, 0, 1))

    # Subtract the RGB mean, one channel at a time.
    for channel in range(3):
        channel_first[channel] -= rgb_mean[channel]

    # Add the batch axis and hand the data over to MXNet.
    batch = np.expand_dims(channel_first, axis=0)
    return mx.ndarray.array(batch)
def get_palette():
    """Build the flat RGB palette that maps train ids to display colours.

    Returns:
        A list of 768 ints laid out as ``[r0, g0, b0, r1, g1, b1, ...]``.
        Entry ``k`` holds the colour of the label in
        ``cityscapes_labels.labels`` whose ``trainId`` is ``k`` (the last such
        label wins when several share an id). Ids without a label, and id 255,
        stay black.
    """
    palette = [0] * 256 * 3
    for label in cityscapes_labels.labels:
        train_id = label.trainId
        # Only ids 0..254 receive a colour. Id 255 is kept black, and ids
        # outside 0..255 have no palette slot: a negative id would otherwise
        # index from the end of the list and overwrite the last entries, and
        # a larger id would raise IndexError.
        if not 0 <= train_id < 255:
            continue
        base = train_id * 3
        color = label.color
        palette[base:base + 3] = [color[0], color[1], color[2]]
    return palette
def colorize(labels):
    """Return an RGB array in which every label value is drawn in its palette colour.

    ``labels`` is turned into a palette-mode PIL image, given the palette from
    ``get_palette()``, and expanded back to RGB.
    """
    indexed = Image.fromarray(labels)
    indexed = indexed.convert('P')
    indexed.putpalette(get_palette())
    rgb = indexed.convert('RGB')
    return np.array(rgb)
def predict(imgs):
    """Run the network on a preprocessed batch and build the segmentation outputs.

    Args:
        imgs: input batch as produced by ``preprocess``; its shape is unpacked
            as ``(batch, channels, height, width)``.

    Returns:
        A tuple ``(confidence, result_img, blended_img, raw_labels)``:
        * confidence: mean over all pixels of the winning class score.
        * result_img: PIL image of the colourised labels.
        * blended_img: PIL image, equal-weight mix of the module-level ``im``
          (channel order reversed) and ``result_img``.
        * raw_labels: uint8 array of per-pixel class indices.

    NOTE: the output size (``result_shape``) and the photo used for the blend
    (``im``) are module-level values, not derived from ``imgs``.
    """
    # Output dimensions and the dimensions of the (padded) network input.
    result_height, result_width = result_shape
    _, _, img_height, img_width = imgs.shape
    img_height, img_width = int(img_height), int(img_width)

    ds_rate = 8      # downsampling rate of the network
    cell_width = 2   # width of one output cell
    label_num = 19   # number of output label classes
    # Output cells per feature position along each axis (4 for the values above).
    cells = ds_rate // cell_width

    # Forward pass.
    batch = namedtuple('Batch', ['data'])
    mod.forward(batch([imgs]), is_train=False)
    scores = mod.get_outputs()[0].asnumpy().squeeze()

    # Feature-map size, and the input size rounded down to a multiple of
    # ds_rate. Floor division is required: true division would leave the size
    # unrounded and break the reshapes below for inputs that are not a
    # multiple of ds_rate.
    feat_height = img_height // ds_rate
    feat_width = img_width // ds_rate
    test_height = feat_height * ds_rate
    test_width = feat_width * ds_rate

    # Re-arrange the output: interleave the per-position cells back into a
    # dense (class, row, column) score map.
    scores = scores.reshape((label_num, cells, cells, feat_height, feat_width))
    scores = np.transpose(scores, (0, 3, 1, 4, 2))
    scores = scores.reshape((label_num, test_height // cell_width, test_width // cell_width))
    scores = scores[:, :img_height // cell_width, :img_width // cell_width]

    # cv.resize expects channel-last data; scale to the output size and go
    # back to channel-first.
    scores = np.transpose(scores, [1, 2, 0])
    scores = cv.resize(scores, (result_width, result_height), interpolation=cv.INTER_LINEAR)
    scores = np.transpose(scores, [2, 0, 1])

    # Classification labels: the best-scoring class per pixel.
    raw_labels = np.argmax(scores, axis=0).astype(np.uint8)
    # Compute the confidence score (the scores are treated as softmax output).
    confidence = float(np.max(scores, axis=0).mean())

    # Segmented image; PIL's resize takes (width, height), hence the reversal.
    result_img = Image.fromarray(colorize(raw_labels)).resize(result_shape[::-1])
    # Blended image.
    blended_img = Image.fromarray(cv.addWeighted(im[:, :, ::-1], 0.5, np.array(result_img), 0.5, 0))
    return confidence, result_img, blended_img, raw_labels
def get_model(ctx, model_path):
    """Load the ONNX file at ``model_path`` as an inference-only MXNet module on ``ctx``.

    The module is bound for a single 3-channel image whose height and width
    are taken from the module-level ``im``.
    """
    # Translate the ONNX graph into an MXNet symbol plus its parameters.
    symbol, arg_params, aux_params = import_model(model_path)

    # Input layout used for binding: (batch, channels, height, width).
    height, width = im.shape[0], im.shape[1]
    input_shapes = [('data', (1, 3, height, width))]

    # Define the network module, then bind shapes and parameters to it.
    module = mx.mod.Module(symbol=symbol, data_names=['data'], context=ctx, label_names=None)
    module.bind(for_training=False, data_shapes=input_shapes, label_shapes=module._label_shapes)
    module.set_params(
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        allow_extra=True,
    )
    return module
# Fetch the sample image
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/duc/city1.png')
# OpenCV reads BGR; reversing the channel axis yields RGB
im = cv.imread('city1.png')[:, :, ::-1]
# Output shape is (height, width) of the sample image
result_shape = list(im.shape[:2])
# Per-channel mean of the sample image (used for padding and mean subtraction)
rgb_mean = cv.mean(im)
# Fetch the ONNX model
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/duc/ResNet101_DUC_HDC.onnx')
# Run on the first GPU when one is available, otherwise on the CPU
ctx = mx.gpu(0) if len(mx.test_utils.list_gpus()) > 0 else mx.cpu()
# Load the ONNX model into a bound MXNet module
mod = get_model(ctx, 'ResNet101_DUC_HDC.onnx')
def inference(im):
    """Segment an uploaded image and return it blended with the class colours.

    Args:
        im: image array supplied by the Gradio "image" input, indexed as
            height x width x channels (assumed uint8 RGB - TODO confirm).

    Returns:
        A PIL image the size of ``im``: an equal-weight mix of ``im`` and the
        colourised segmentation.
    """
    pre = preprocess(im)
    _, result_img, _, _ = predict(pre)

    # predict() blends against the module-level sample image and sizes its
    # images to that sample, so its own blended output shows the wrong photo
    # for any other upload. Build the blend here from the image that was
    # actually submitted instead.
    height, width = im.shape[:2]
    # Nearest-neighbour keeps class colours pure instead of mixing them at
    # region borders; PIL's resize takes (width, height).
    overlay = np.array(result_img.resize((width, height), Image.NEAREST))
    return Image.fromarray(cv.addWeighted(im, 0.5, overlay, 0.5, 0))


gr.Interface(inference,"image",gr.outputs.Image(type="pil")).launch()