|
import mxnet as mx |
|
import cv2 as cv |
|
import numpy as np |
|
import os |
|
from PIL import Image |
|
import math |
|
from collections import namedtuple |
|
from mxnet.contrib.onnx import import_model |
|
import cityscapes_labels |
|
import gradio as gr |
|
|
|
def preprocess(im):
    """Turn an image array into the padded, mean-centred batch fed to the network.

    The image is converted to float32, padded on the bottom and right with a
    constant border until both of its first two dimensions are multiples of 8,
    moved to channel-first layout, has the mean subtracted from its first
    three channel planes, and finally gets a leading batch axis.

    Args:
        im: Image array indexed as (rows, cols, channels).
            NOTE(review): channel order is not checked here; it should match
            the order of the module-level ``rgb_mean`` - confirm.

    Returns:
        An ``mx.ndarray`` of shape (1, channels, padded_rows, padded_cols).

    Note:
        Reads the module-level ``rgb_mean`` both as the border fill value and
        as the per-channel mean that is subtracted.
    """
    height, width = im.shape[0], im.shape[1]

    # Bottom/right padding needed to reach the next multiple of 8.
    pad_bottom = max(0, int(math.ceil(height / 8) * 8) - height)
    pad_right = max(0, int(math.ceil(width / 8) * 8) - width)

    padded = cv.copyMakeBorder(
        im.astype(np.float32),
        0, pad_bottom, 0, pad_right,
        cv.BORDER_CONSTANT,
        value=rgb_mean,
    )

    # (rows, cols, channels) -> (channels, rows, cols)
    planes = np.transpose(padded, (2, 0, 1))
    for channel in (0, 1, 2):
        planes[channel] -= rgb_mean[channel]

    # Prepend the batch axis and hand the result over to MXNet.
    return mx.ndarray.array(planes[np.newaxis, ...])
|
|
|
def get_palette():
    """Build the flat 256-colour palette that maps train IDs to RGB colours.

    Returns:
        list: 768 integers laid out as ``[r0, g0, b0, r1, g1, b1, ...]``, so
        palette index ``k`` holds the colour of the label whose ``trainId`` is
        ``k``. Indices with no label, and index 255, are black.

    Note:
        Only train IDs in ``0..254`` are written. Anything else (for example a
        negative ``trainId``, if the labels table contains one) is skipped:
        indexing the list with a negative offset would silently wrap around
        and overwrite the entries at the end of the palette, i.e. the colour
        of index 255 that is meant to stay black.
    """
    palette = [0] * (256 * 3)

    # When several labels share a train ID the last one in the table wins,
    # which is what building a dict keyed by trainId would also produce.
    for label in cityscapes_labels.labels:
        train_id = label.trainId
        if not 0 <= train_id < 255:
            continue
        base = train_id * 3
        for offset in range(3):
            palette[base + offset] = label.color[offset]

    return palette
|
|
|
def colorize(labels):
    """Render an array of palette indices as an RGB image array.

    Args:
        labels: Array of palette indices accepted by ``Image.fromarray``;
            ``predict`` passes a 2-D uint8 array of train IDs.

    Returns:
        numpy.ndarray: The RGB rendering of ``labels`` using the palette from
        ``get_palette()``.
    """
    indexed = Image.fromarray(labels).convert('P')
    indexed.putpalette(get_palette())
    rgb = indexed.convert('RGB')
    return np.array(rgb)
|
|
|
def predict(imgs):
    """Run the network on a preprocessed batch and build the visual outputs.

    Args:
        imgs: Batch of shape (1, channels, height, width) as produced by
            ``preprocess``.

    Returns:
        tuple: ``(confidence, result_img, blended_img, raw_labels)`` where

        * ``confidence`` is the mean over all pixels of the highest per-pixel
          class score (presumably softmax probabilities - confirm against the
          model),
        * ``result_img`` is the colourised label map as a PIL image,
        * ``blended_img`` is a 50/50 blend of ``result_img`` with the
          module-level image ``im`` (channel axis reversed),
        * ``raw_labels`` is the uint8 array of winning class indices.

    Note:
        Reads the module-level ``mod`` (the bound network), ``result_shape``
        (output size) and ``im`` (image used for the blend). The blend
        therefore always uses that module-level image, not ``imgs``.
    """
    result_height, result_width = result_shape
    _, _, img_height, img_width = imgs.shape

    ds_rate = 8        # down-sampling rate of the network
    cell_width = 2     # each score cell covers cell_width x cell_width input pixels
    label_num = 19     # number of predicted classes
    upscale = ds_rate // cell_width  # score cells per feature cell along each axis

    batch = namedtuple('Batch', ['data'])
    mod.forward(batch([imgs]), is_train=False)
    labels = mod.get_outputs()[0].asnumpy().squeeze()

    # Floor division so the working size really is rounded down to a multiple
    # of ds_rate and always agrees with the feature-map size (true division
    # followed by int() left the size unchanged).
    feat_height = int(img_height) // ds_rate
    feat_width = int(img_width) // ds_rate
    test_height = feat_height * ds_rate
    test_width = feat_width * ds_rate

    # Interleave every feature cell's upscale x upscale sub-grid into one
    # dense per-class score map:
    # (C, u, u, fh, fw) -> (C, fh, u, fw, u) -> (C, fh * u, fw * u)
    labels = labels.reshape((label_num, upscale, upscale, feat_height, feat_width))
    labels = np.transpose(labels, (0, 3, 1, 4, 2))
    labels = labels.reshape((label_num, test_height // cell_width, test_width // cell_width))

    labels = labels[:, :img_height // cell_width, :img_width // cell_width]

    # cv.resize expects channels last; convert, resize to the output size,
    # then go back to class-first layout.
    labels = np.transpose(labels, [1, 2, 0])
    labels = cv.resize(labels, (result_width, result_height), interpolation=cv.INTER_LINEAR)
    labels = np.transpose(labels, [2, 0, 1])

    raw_labels = np.argmax(labels, axis=0).astype(np.uint8)
    confidence = float(np.max(labels, axis=0).mean())

    result_img = Image.fromarray(colorize(raw_labels)).resize(result_shape[::-1])
    blended_img = Image.fromarray(cv.addWeighted(im[:, :, ::-1], 0.5, np.array(result_img), 0.5, 0))

    return confidence, result_img, blended_img, raw_labels
|
|
|
def get_model(ctx, model_path):
    """Load an ONNX model into an inference-only MXNet module.

    Args:
        ctx: MXNet context the module is created on.
        model_path: Path of the ONNX file handed to ``import_model``.

    Returns:
        The bound ``mx.mod.Module`` with the imported parameters set.

    Note:
        The input shape used for binding comes from the module-level image
        ``im``, so ``im`` must be defined before this function is called.
    """
    symbol, arg_params, aux_params = import_model(model_path)

    module = mx.mod.Module(symbol=symbol, data_names=['data'], context=ctx, label_names=None)

    input_shape = (1, 3, im.shape[0], im.shape[1])
    module.bind(
        for_training=False,
        data_shapes=[('data', input_shape)],
        label_shapes=module._label_shapes,
    )
    module.set_params(
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        allow_extra=True,
    )
    return module
|
|
|
|
|
# --- One-time setup executed at import: sample image, statistics, model ---

# Fetch the sample image into the working directory.
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/duc/city1.png')

# cv.imread returns the channels in BGR order; reversing the last axis flips them.
im = cv.imread('city1.png')[:, :, ::-1]

# Output size (rows, cols) used by predict() for its results.
result_shape = list(im.shape[:2])

# Per-channel mean of the sample image; preprocess() uses it for padding and centring.
rgb_mean = cv.mean(im)

# Fetch the pretrained ONNX model.
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/duc/ResNet101_DUC_HDC.onnx')

# Use the first GPU when one is available, otherwise fall back to the CPU.
ctx = mx.gpu(0) if mx.test_utils.list_gpus() else mx.cpu()

# Bound network shared by every predict() call.
mod = get_model(ctx, 'ResNet101_DUC_HDC.onnx')
|
|
|
def inference(im):
    """Segment an image and return the colourised labels blended onto it.

    Args:
        im: Input image array indexed as (rows, cols, channels).
            NOTE(review): assumed to be the uint8 RGB array delivered by the
            Gradio "image" input - confirm.

    Returns:
        PIL.Image.Image: 50/50 blend of ``im`` and its colourised label map,
        with the same height and width as ``im``.
    """
    pre = preprocess(im)
    _, _, _, raw = predict(pre)

    # predict() sizes its outputs to the module-level ``result_shape`` and
    # blends them onto the module-level sample image, so its ``blended_img``
    # shows the wrong picture for any other input. Rebuild the overlay here
    # against the image that was actually passed in.
    height, width = im.shape[0], im.shape[1]
    padded_height, padded_width = pre.shape[2], pre.shape[3]

    # The label map spans the padded network input: scale it back to that
    # size (nearest neighbour keeps the palette colours exact), then drop the
    # bottom/right padding added by preprocess().
    overlay = cv.resize(
        colorize(raw),
        (padded_width, padded_height),
        interpolation=cv.INTER_NEAREST,
    )
    overlay = np.ascontiguousarray(overlay[:height, :width])

    # addWeighted needs both inputs to share size, channel count and dtype.
    base = np.ascontiguousarray(im, dtype=overlay.dtype)
    return Image.fromarray(cv.addWeighted(base, 0.5, overlay, 0.5, 0))
|
|
|
# Serve the demo: one image in, one PIL image out, handled by inference().
gr.Interface(
    fn=inference,
    inputs="image",
    outputs=gr.outputs.Image(type="pil"),
).launch()
|
|