import mxnet as mx
import cv2 as cv
import numpy as np
import os
from PIL import Image
import math
from collections import namedtuple
from mxnet.contrib.onnx import import_model
import cityscapes_labels
import gradio as gr

def preprocess(im):
    # convert to float32
    test_img = im.astype(np.float32)
    # pad the image with a small border so both spatial dimensions are
    # multiples of 8, giving an exact reshape after the DUC layer
    test_shape = [im.shape[0], im.shape[1]]
    cell_shapes = [math.ceil(l / 8) * 8 for l in test_shape]
    test_img = cv.copyMakeBorder(test_img, 0, max(0, int(cell_shapes[0]) - im.shape[0]), 0, max(0, int(cell_shapes[1]) - im.shape[1]), cv.BORDER_CONSTANT, value=rgb_mean)
    test_img = np.transpose(test_img, (2, 0, 1))
    # subtract rgb mean
    for i in range(3):
        test_img[i] -= rgb_mean[i]
    test_img = np.expand_dims(test_img, axis=0)
    # convert to MXNet ndarray
    test_img = mx.ndarray.array(test_img)
    return test_img
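
# Illustrative note (assumed shapes, not from the original script): for a
# 1024x2048 Cityscapes frame, preprocess() pads nothing (both dimensions are
# already multiples of 8) and returns an mx.nd.array of shape
# (1, 3, 1024, 2048); a 500x375 input would be padded to 504x376 first.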

def get_palette():
    # get trainId-to-color mappings from the Cityscapes label definitions
    trainId2colors = {label.trainId: label.color for label in cityscapes_labels.labels}
    # prepare and return a 256-color RGB palette
    palette = [0] * 256 * 3
    for trainId in trainId2colors:
        # skip negative trainIds (e.g. -1), which would otherwise index the
        # palette from the end and clobber valid entries
        if trainId < 0:
            continue
        colors = trainId2colors[trainId]
        if trainId == 255:
            colors = (0, 0, 0)
        for i in range(3):
            palette[trainId * 3 + i] = colors[i]
    return palette
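
# For reference (standard Cityscapes label definitions, which this palette
# assumes): trainId 0 is "road" with color (128, 64, 128), so
# get_palette()[0:3] would be [128, 64, 128]; the ignore id 255 maps to black.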

def colorize(labels):
    # generate a colorized image from the output labels and the color palette
    result_img = Image.fromarray(labels).convert('P')
    result_img.putpalette(get_palette())
    return np.array(result_img.convert('RGB'))

def predict(imgs):
    # get input and output dimensions
    result_height, result_width = result_shape
    _, _, img_height, img_width = imgs.shape
    # set downsampling rate
    ds_rate = 8
    # set cell width
    cell_width = 2
    # number of output label classes
    label_num = 19
    # perform forward pass
    batch = namedtuple('Batch', ['data'])
    mod.forward(batch([imgs]), is_train=False)
    labels = mod.get_outputs()[0].asnumpy().squeeze()
    # re-arrange output: fold the per-cell sub-pixel predictions back into the
    # spatial dimensions (a depth-to-space / "pixel shuffle" step)
    test_width = int((int(img_width) / ds_rate) * ds_rate)
    test_height = int((int(img_height) / ds_rate) * ds_rate)
    feat_width = int(test_width / ds_rate)
    feat_height = int(test_height / ds_rate)
    labels = labels.reshape((label_num, 4, 4, feat_height, feat_width))
    labels = np.transpose(labels, (0, 3, 1, 4, 2))
    labels = labels.reshape((label_num, int(test_height / cell_width), int(test_width / cell_width)))
    labels = labels[:, :int(img_height / cell_width), :int(img_width / cell_width)]
    labels = np.transpose(labels, [1, 2, 0])
    labels = cv.resize(labels, (result_width, result_height), interpolation=cv.INTER_LINEAR)
    labels = np.transpose(labels, [2, 0, 1])
    # keep the resized per-class score map
    softmax = labels
    # get per-pixel classification labels
    results = np.argmax(labels, axis=0).astype(np.uint8)
    raw_labels = results
    # compute confidence score as the mean of the per-pixel maximum scores
    confidence = float(np.max(softmax, axis=0).mean())
    # generate segmented image
    result_img = Image.fromarray(colorize(raw_labels)).resize(result_shape[::-1])
    # generate blended image (im is already RGB, so blend it directly with the
    # RGB segmentation)
    blended_img = Image.fromarray(cv.addWeighted(im, 0.5, np.array(result_img), 0.5, 0))
    return confidence, result_img, blended_img, raw_labels
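
# A minimal numpy sketch (illustrative only, not part of the app) of the
# depth-to-space re-arrangement performed in predict() above, for one class
# and a 2x2 feature map (ds_rate=8, cell_width=2, hence 4x4 sub-cells):
#
#   x = np.arange(1 * 4 * 4 * 2 * 2).reshape(1, 4, 4, 2, 2)
#   y = np.transpose(x, (0, 3, 1, 4, 2)).reshape(1, 2 * 4, 2 * 4)
#
# Each feature-map cell's 16 values become a 4x4 spatial block, expanding the
# (1, 4, 4, 2, 2) score stack into a (1, 8, 8) map at cell resolution.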

def get_model(ctx, model_path):
    # import the ONNX model into MXNet symbols and params
    sym, arg, aux = import_model(model_path)
    # define the network module
    mod = mx.mod.Module(symbol=sym, data_names=['data'], context=ctx, label_names=None)
    # bind parameters to the network
    mod.bind(for_training=False, data_shapes=[('data', (1, 3, im.shape[0], im.shape[1]))], label_shapes=mod._label_shapes)
    mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True)
    return mod
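
# Note (an assumption about MXNet's Module API): the module is bound here to
# the shape of the bundled test image; Module.forward is documented to
# reshape the module when a batch of a different shape is passed, which is
# what lets inference() below accept uploads of arbitrary size.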
# Download test image
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/duc/city1.png')
# read image as rgb
im = cv.imread('city1.png')[:, :, ::-1]
# set output shape (same as input shape)
result_shape = [im.shape[0],im.shape[1]]
# set rgb mean of input image (used in mean subtraction)
rgb_mean = cv.mean(im)
# Download ONNX model
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/duc/ResNet101_DUC_HDC.onnx')
# Determine and set context
if len(mx.test_utils.list_gpus()) == 0:
    ctx = mx.cpu()
else:
    ctx = mx.gpu(0)
# Load ONNX model
mod = get_model(ctx, 'ResNet101_DUC_HDC.onnx')

def inference(path):
    # preprocess() and predict() read these at module scope, so rebind the
    # globals for the uploaded image instead of shadowing them locally
    global im, result_shape, rgb_mean
    # read image as rgb
    im = cv.imread(path)[:, :, ::-1]
    # set output shape (same as input shape)
    result_shape = [im.shape[0], im.shape[1]]
    # set rgb mean of input image (used in mean subtraction)
    rgb_mean = cv.mean(im)
    pre = preprocess(im)
    conf, result_img, blended_img, raw = predict(pre)
    return blended_img
title="Dense Upsampling Convolution (DUC)"
description="""DUC is a CNN based model for semantic segmentation which uses an image classification network (ResNet) as a backend and achieves improved accuracy in terms of mIOU score using two novel techniques. The first technique is called Dense Upsampling Convolution (DUC) which generates pixel-level prediction by capturing and decoding more detailed information that is generally missing in bilinear upsampling. Secondly, a framework called Hybrid Dilated Convolution (HDC) is proposed in the encoding phase which enlarges the receptive fields of the network to aggregate global information. It also alleviates the checkerboard receptive field problem ("gridding") caused by the standard dilated convolution operation."""
examples=[['city1.png']]
gr.Interface(inference,gr.inputs.Image(type="filepath"),gr.outputs.Image(type="pil"),examples=examples,title=title,description=description).launch(enable_queue=True)