|
|
|
|
|
""" |
|
Calculate mIOU for Deeplabv3p model on validation dataset |
|
""" |
|
import os, argparse, time |
|
import numpy as np |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import copy |
|
import itertools |
|
from tqdm import tqdm |
|
from collections import OrderedDict |
|
import operator |
|
from labelme.utils import lblsave as label_save |
|
|
|
from tensorflow.keras.models import load_model |
|
import tensorflow.keras.backend as K |
|
import tensorflow as tf |
|
import MNN |
|
import onnxruntime |
|
|
|
from common.utils import get_data_list, get_classes, get_custom_objects, optimize_tf_gpu, visualize_segmentation |
|
from deeplabv3p.data import SegmentationGenerator |
|
from deeplabv3p.metrics import mIOU |
|
from deeplabv3p.postprocess_np import crf_postprocess |
|
|
|
# Set TensorFlow's native log threshold to '2' (presumably hides INFO and
# WARNING output -- confirm against the TensorFlow version in use).
# NOTE(review): this assignment executes after the tensorflow imports at the
# top of the file; confirm the variable is still honoured at that point.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Hand the tensorflow module and the Keras backend to the project helper
# (presumably configures GPU memory usage -- see common.utils).
optimize_tf_gpu(tf, K)
|
|
|
|
|
def deeplab_predict_keras(model, image_data):
    """
    Predict with a Keras model and return the label map of the first sample.

    Args:
        model: object exposing ``predict(image_data)``.
        image_data: input batch, forwarded unchanged to ``model.predict``
            (presumably shaped (batch, height, width, channels) -- confirm
            against the data generator).

    Returns:
        Class indices obtained by an argmax over the last axis of the
        prediction, for the first entry along the leading axis.
    """
    class_scores = model.predict(image_data)
    return np.argmax(class_scores, axis=-1)[0]
|
|
|
|
|
def deeplab_predict_onnx(model, image_data):
    """
    Predict with an ONNX runtime session and return the argmax label map.

    Args:
        model: session object exposing ``get_inputs()`` and ``run()``.
        image_data: input batch fed to the single model input.

    Returns:
        ``np.argmax`` over the last axis of everything ``model.run`` returned,
        indexed with ``[0]``. NOTE(review): the argmax is applied to the whole
        return value of ``run`` (presumably a list of outputs), so the result
        may keep the batch axis of the first output -- the caller in this file
        reshapes it afterwards.

    Raises:
        AssertionError: if the model does not have exactly one input.
    """
    # Only single-input models are handled.
    input_tensors = list(model.get_inputs())
    assert len(input_tensors) == 1, 'invalid input tensor number.'

    outputs = model.run(None, {input_tensors[0].name: image_data})

    label_maps = np.argmax(outputs, axis=-1)
    return label_maps[0]
|
|
|
|
|
def deeplab_predict_pb(model, image_data):
    """
    Predict with a frozen TensorFlow graph and return the argmax label map.

    Args:
        model: ``tf.Graph`` holding the imported frozen model; tensors are
            looked up under the ``graph/`` name prefix.
        image_data: input batch fed to the image input tensor.

    Returns:
        Class indices (argmax over the last axis of the softmax output) for
        the first entry along the leading axis.
    """
    # Tensor names inside the imported graph. The "graph/" prefix has to match
    # the name the graph definition was imported under.
    output_tensor_name = 'graph/pred_mask/Softmax:0'
    input_tensor_name = 'graph/image_input:0'

    image_input = model.get_tensor_by_name(input_tensor_name)
    output_tensor = model.get_tensor_by_name(output_tensor_name)

    # Session is no longer exported at the top level on TF 2.x; the compat
    # namespace provides it on both 1.x (>= 1.13) and 2.x. Older releases
    # without tf.compat.v1 fall back to the top-level module.
    tf_v1 = getattr(tf.compat, 'v1', tf)

    with tf_v1.Session(graph=model) as sess:
        prediction = sess.run(output_tensor, feed_dict={image_input: image_data})

    prediction = np.argmax(prediction, axis=-1)
    return prediction[0]
|
|
|
|
|
def deeplab_predict_tflite(interpreter, image_data):
    """
    Predict with a TFLite interpreter and return the argmax label map.

    Args:
        interpreter: object exposing ``get_input_details``,
            ``get_output_details``, ``set_tensor``, ``invoke`` and
            ``get_tensor``.
        image_data: input batch written to the first input tensor.

    Returns:
        Class indices (argmax over the last axis of the first output tensor)
        for the first entry along the leading axis.
    """
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.set_tensor(input_details[0]['index'], image_data)
    interpreter.invoke()

    # Every output tensor is fetched, but only the first one is used.
    outputs = [interpreter.get_tensor(detail['index']) for detail in output_details]

    label_maps = np.argmax(outputs[0], axis=-1)
    return label_maps[0]
|
|
|
|
|
def deeplab_predict_mnn(interpreter, session, image_data):
    """
    Predict with an MNN interpreter session and return the argmax label map.

    Args:
        interpreter: MNN interpreter owning ``session``.
        session: session object created from ``interpreter``.
        image_data: numpy input batch; it is reshaped and converted to a tuple
            before being wrapped in an ``MNN.Tensor``.

    Returns:
        Class indices (argmax over the last axis of the output) for the first
        entry along the leading axis.

    Raises:
        AssertionError: if the output tensor data type is not
            ``MNN.Halide_Type_Float``.
        ValueError: if the output dimension type is
            ``MNN.Tensor_DimensionType_Caffe_C4``.
    """
    from functools import reduce
    from operator import mul

    # Single image input: stage the data in a temporary tensor built from a
    # tuple, then copy it into the session input.
    input_tensor = interpreter.getSessionInput(session)
    input_shape = input_tensor.getShape()
    input_count = reduce(mul, input_shape)
    staged_input = MNN.Tensor(
        input_shape,
        input_tensor.getDataType(),
        tuple(image_data.reshape(input_count, -1)),
        input_tensor.getDimensionType(),
    )
    input_tensor.copyFrom(staged_input)
    interpreter.runSession(session)

    # Single output: only float outputs are accepted.
    output_tensor = interpreter.getSessionOutput(session)
    output_shape = output_tensor.getShape()
    assert output_tensor.getDataType() == MNN.Halide_Type_Float

    # Copy the result into a host-side tensor of the same shape.
    output_count = reduce(mul, output_shape)
    host_output = MNN.Tensor(
        output_shape,
        output_tensor.getDataType(),
        tuple(np.zeros(output_shape, dtype=float).reshape(output_count, -1)),
        output_tensor.getDimensionType(),
    )
    output_tensor.copyToHostTensor(host_output)

    scores = np.array(host_output.getData(), dtype=float).reshape(output_shape)

    # The argmax below runs over the last axis, so a Caffe-type output has its
    # axis 1 moved to the end first (presumably channel-first -> channel-last).
    if output_tensor.getDimensionType() == MNN.Tensor_DimensionType_Caffe:
        scores = scores.transpose((0, 2, 3, 1))
    elif output_tensor.getDimensionType() == MNN.Tensor_DimensionType_Caffe_C4:
        raise ValueError('unsupported output tensor dimension type')

    return np.argmax(scores, axis=-1)[0]
|
|
|
|
|
def plot_confusion_matrix(cm, classes, mIOU, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    Render a confusion matrix and save it as result/confusion_matrix.png.

    Args:
        cm: square numpy array; rows are labelled "True label" and columns
            "Predicted label" in the plot.
        classes: tick labels, one per row/column of ``cm``.
        mIOU: mean IoU value shown (times 100, rounded to 2 digits) in the
            final plot title.
        normalize: when True, every row is divided by its row sum.
        title: initial plot title. NOTE: it is replaced by the
            'Mean IOU: ...' title set further down.
        cmap: matplotlib colormap used for the matrix image.

    Returns:
        None. The figure is written to disk and then closed.
    """
    if normalize:
        # Row-normalise. Rows without any sample stay all-zero instead of
        # becoming NaN through a 0/0 division (which would also turn the
        # colour threshold below into NaN).
        cm = cm.astype('float')
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)

    fig = plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=11)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90, fontsize=9)
    plt.yticks(tick_marks, classes, fontsize=9)

    # Cell text switches to white on cells above half of the maximum value.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, np.round(cm[i, j], 2), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black", fontsize=7)
    plt.tight_layout()
    plt.ylabel('True label', fontsize=9)
    plt.xlabel('Predicted label', fontsize=9)

    plt.title('Mean IOU: ' + str(np.round(mIOU*100, 2)))

    output_path = os.path.join('result', 'confusion_matrix.png')
    os.makedirs('result', exist_ok=True)
    plt.savefig(output_path)

    # Release the figure once it is on disk, as draw_plot_func does, so
    # repeated evaluations do not accumulate open figures.
    plt.close(fig)

    return
|
|
|
|
|
def adjust_axes(r, t, fig, axes):
    """
    Widen the x-axis range so a text label fits next to its bar.

    Args:
        r: renderer passed to ``t.get_window_extent``.
        t: text object whose rendered width is measured.
        fig: figure providing ``dpi`` and ``get_figwidth()``.
        axes: axes whose upper x-limit is scaled.
    """
    # Rendered width of the text, converted to inches via the figure dpi.
    text_extent = t.get_window_extent(renderer=r)
    text_width_inches = text_extent.width / fig.dpi

    # Ratio between the width needed with the text and the current width.
    old_fig_width = fig.get_figwidth()
    widened_fig_width = old_fig_width + text_width_inches
    scale = widened_fig_width / old_fig_width

    # Stretch only the upper x-limit by that ratio.
    limits = axes.get_xlim()
    axes.set_xlim([limits[0], limits[1] * scale])
|
|
|
|
|
def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
    """
    Draw a horizontal bar chart of per-class values and save it to a file.

    Args:
        dictionary: mapping of class name -> numeric value to plot.
        n_classes: number of bars; used for the bar/tick positions and for
            scaling the figure height.
        window_title: title given to the figure window (when the figure has
            a window manager).
        plot_title: title drawn above the chart.
        x_label: label of the x axis.
        output_path: file the figure is saved to.
        to_show: when truthy, ``plt.show()`` is called after saving.
        plot_color: bar and text colour for the plain (single colour) chart.
        true_p_bar: either the empty string for the plain chart, or a mapping
            of class name -> true prediction count; in that case each bar is
            split into a red (false) and a green (true) part.
    """
    # Sort ascending by value, so the largest value gets the last bar position.
    sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
    sorted_keys, sorted_values = zip(*sorted_dic_by_value)
    last_index = len(sorted_values) - 1

    if true_p_bar != "":
        # Special case: draw (green=true predictions) & (red=false predictions).
        fp_sorted = [dictionary[key] - true_p_bar[key] for key in sorted_keys]
        tp_sorted = [true_p_bar[key] for key in sorted_keys]
        plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Predictions')
        plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Predictions', left=fp_sorted)

        plt.legend(loc='lower right')

        # Write the numbers on the side of each bar.
        fig = plt.gcf()
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            fp_str_val = " " + str(fp_sorted[i])
            tp_str_val = fp_str_val + " " + str(tp_sorted[i])
            # Two-colour text: paint the full string in green first, then
            # repaint the leading (false prediction) number in red on top.
            t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
            plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
            if i == last_index:
                # Largest bar: make room for its text inside the axes.
                adjust_axes(r, t, fig, axes)
    else:
        plt.barh(range(n_classes), sorted_values, color=plot_color)

        # Write the number on the side of each bar.
        fig = plt.gcf()
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            str_val = " " + str(val)
            if val < 1.0:
                str_val = " {0:.2f}".format(val)
            t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
            if i == last_index:
                # Largest bar: make room for its text inside the axes.
                adjust_axes(r, t, fig, axes)

    # Set the window title through the figure manager. The canvas-level
    # set_window_title() was deprecated in Matplotlib 3.4 and later removed,
    # so it is only used as a fallback for canvases without a manager.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(window_title)
    elif hasattr(fig.canvas, 'set_window_title'):
        fig.canvas.set_window_title(window_title)

    # Class names on the y axis.
    tick_font_size = 12
    plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)

    # Re-scale the figure height so every class label gets a line of its own.
    init_height = fig.get_figheight()
    dpi = fig.dpi
    height_pt = n_classes * (tick_font_size * 1.4)  # 1.4: spacing factor per label
    height_in = height_pt / dpi
    top_margin = 0.15     # fraction of the figure height
    bottom_margin = 0.05  # fraction of the figure height
    figure_height = height_in / (1 - top_margin - bottom_margin)
    # Only ever grow the figure, never shrink it.
    if figure_height > init_height:
        fig.set_figheight(figure_height)

    plt.title(plot_title, fontsize=14)
    plt.xlabel(x_label, fontsize='large')

    fig.tight_layout()
    fig.savefig(output_path)

    if to_show:
        plt.show()

    plt.close()
|
|
|
|
|
def plot_mIOU_result(IOUs, mIOU, num_classes):
    """
    Save the per-class IoU bar chart to result/mIOU.png.

    Args:
        IOUs: mapping of class name -> IoU value, one bar per entry.
        mIOU: mean IoU, shown as a percentage in the plot title.
        num_classes: number of bars to draw.
    """
    chart_title = "mIOU: {0:.3f}%".format(mIOU*100)
    chart_file = os.path.join('result', 'mIOU.png')
    os.makedirs('result', exist_ok=True)

    draw_plot_func(
        IOUs,
        num_classes,
        "mIOU",
        chart_title,
        "Intersection Over Union",
        chart_file,
        to_show=False,
        plot_color='royalblue',
        true_p_bar='',
    )
|
|
|
|
|
def save_seg_result(image, pred_mask, gt_mask, image_id, class_names):
    """
    Save the predicted mask and a visualisation image for one sample.

    Two files are written, both named after ``image_id``:
    ``result/predict_mask/<image_id>.png`` (predicted label mask) and
    ``result/segmentation/<image_id>.jpg`` (image returned by
    ``visualize_segmentation``).

    Args:
        image: source image passed to the visualisation helper.
        pred_mask: predicted label mask.
        gt_mask: ground truth label mask.
        image_id: identifier used as file stem for both outputs.
        class_names: class names passed to the visualisation helper.
    """
    file_stem = str(image_id)

    # Predicted mask, written through the labelme label saver.
    mask_dir = os.path.join('result', 'predict_mask')
    os.makedirs(mask_dir, exist_ok=True)
    label_save(os.path.join(mask_dir, file_stem + '.png'), pred_mask)

    # Visualisation of prediction and ground truth; the prediction title
    # carries the mIOU of this single sample.
    image_array = visualize_segmentation(
        image,
        pred_mask,
        gt_mask,
        class_names=class_names,
        title='Predict Segmentation\nmIOU: ' + str(mIOU(pred_mask, gt_mask)),
        gt_title='GT Segmentation',
        ignore_count_threshold=1,
    )

    result_dir = os.path.join('result', 'segmentation')
    os.makedirs(result_dir, exist_ok=True)
    Image.fromarray(image_array).save(os.path.join(result_dir, file_stem + '.jpg'))
|
|
|
|
|
def generate_matrix(gt_mask, pre_mask, num_classes):
    """
    Build the confusion matrix of one prediction against its ground truth.

    Args:
        gt_mask: numpy array of ground truth labels; entries outside
            ``[0, num_classes)`` are ignored.
        pre_mask: numpy integer array of predicted labels, same shape as
            ``gt_mask``.
        num_classes: number of classes (matrix is num_classes x num_classes).

    Returns:
        Integer array where entry [g, p] counts the pixels with ground truth
        label ``g`` and predicted label ``p``.
    """
    # Keep only the pixels whose ground truth label is a known class.
    in_range = np.logical_and(gt_mask >= 0, gt_mask < num_classes)

    # Encode every (ground truth, prediction) pair as one flat matrix index.
    flat_index = num_classes * gt_mask[in_range].astype('int') + pre_mask[in_range]

    histogram = np.bincount(flat_index, minlength=num_classes**2)
    return histogram.reshape(num_classes, num_classes)
|
|
|
|
|
def eval_mIOU(model, model_format, dataset_path, dataset, class_names, model_input_shape, do_crf=False, save_result=False, show_background=False):
    """
    Evaluate a segmentation model on a dataset and report IoU statistics.

    Every sample yielded by the evaluation generator is predicted with the
    backend selected by ``model_format``, accumulated into one confusion
    matrix, and the summary (per-class IoU / frequency / accuracy / Dice,
    mIoU, FWIoU, pixel accuracy, mean class accuracy) is printed and plotted.

    Args:
        model: loaded model object matching ``model_format``.
        model_format: one of 'TFLITE', 'MNN', 'PB', 'ONNX', 'H5'.
        dataset_path: dataset location handed to ``SegmentationGenerator``.
        dataset: sample list handed to ``SegmentationGenerator``.
        class_names: list of class names; must contain 'background' when
            ``show_background`` is False.
        model_input_shape: 2-item shape the prediction and label are reshaped
            to; passed reversed as the generator ``target_size``.
        do_crf: apply ``crf_postprocess`` to each predicted mask.
        save_result: save mask and visualisation image for each sample.
        show_background: keep the 'background' class in the reported IoUs,
            the mIoU and the confusion matrix plot.

    Returns:
        The mean (NaN-ignoring) of the reported per-class IoU values.

    Raises:
        ValueError: if ``model_format`` is not a supported format (raised
            when the first sample is predicted).
    """
    num_classes = len(class_names)

    # Evaluation data generator; the third positional argument is 1
    # (presumably the batch size -- confirm against SegmentationGenerator).
    eval_generator = SegmentationGenerator(
        dataset_path, dataset,
        1,
        num_classes,
        target_size=model_input_shape[::-1],
        weighted_type=None,
        is_eval=True,
        augment=False,
    )

    # Only the MNN backend needs a session next to the interpreter.
    session = model.createSession() if model_format == 'MNN' else None

    def run_model(batch):
        """Dispatch one batch to the predictor matching ``model_format``."""
        if model_format == 'TFLITE':
            return deeplab_predict_tflite(model, batch)
        if model_format == 'MNN':
            return deeplab_predict_mnn(model, session, batch)
        if model_format == 'PB':
            return deeplab_predict_pb(model, batch)
        if model_format == 'ONNX':
            return deeplab_predict_onnx(model, batch)
        if model_format == 'H5':
            return deeplab_predict_keras(model, batch)
        raise ValueError('invalid model format')

    # Confusion matrix accumulated over all samples and all classes.
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=float)

    pbar = tqdm(total=len(eval_generator), desc='Eval model')
    for n, (image_data, y_true) in enumerate(eval_generator):
        y_pred = run_model(image_data)

        image = image_data[0].astype('uint8')
        pred_mask = y_pred.reshape(model_input_shape)
        gt_mask = y_true.reshape(model_input_shape).astype('int')

        if do_crf:
            pred_mask = crf_postprocess(image, pred_mask, zero_unsure=False)

        if save_result:
            # The image file name (without extension) identifies the outputs.
            image_list = eval_generator.get_batch_image_path(n)
            assert len(image_list) == 1, 'incorrect image batch'
            image_id = os.path.splitext(os.path.basename(image_list[0]))[0]

            save_seg_result(image, pred_mask, gt_mask, image_id, class_names)

        pred_mask = pred_mask.astype('int')
        gt_mask = gt_mask.astype('int')
        confusion_matrix += generate_matrix(gt_mask, pred_mask, num_classes)

        pbar.update(1)
    pbar.close()

    # Shared building blocks of all metrics below.
    intersection = np.diag(confusion_matrix)
    gt_totals = np.sum(confusion_matrix, axis=1)
    pred_totals = np.sum(confusion_matrix, axis=0)
    pixel_total = np.sum(confusion_matrix)

    # Pixel accuracy and per-class accuracy.
    pixel_acc = intersection.sum() / pixel_total
    class_acc = intersection / gt_totals
    mean_class_acc = np.nanmean(class_acc)

    # Per-class IoU.
    union = pred_totals + gt_totals - intersection
    iou_per_class = intersection / union

    # Frequency weighted IoU, over the classes that actually occur.
    freq = gt_totals / pixel_total
    fw_iou = (freq[freq > 0] * iou_per_class[freq > 0]).sum()

    # Dice coefficient.
    dice = 2*intersection / (union+intersection)

    IOUs = dict(zip(class_names, iou_per_class))
    CLASS_ACCs = dict(zip(class_names, class_acc))
    DICEs = dict(zip(class_names, dice))
    FREQs = dict(zip(class_names, freq))

    if show_background:
        display_class_names = class_names
        display_confusion_matrix = confusion_matrix
    else:
        # Drop the background class from names, matrix and IoU table.
        display_class_names = copy.deepcopy(class_names)
        display_class_names.remove('background')
        display_confusion_matrix = copy.deepcopy(confusion_matrix[1:, 1:])
        IOUs.pop('background')
        num_classes = num_classes - 1

    # Report classes in decreasing IoU order.
    IOUs = OrderedDict(sorted(IOUs.items(), key=operator.itemgetter(1), reverse=True))

    # Mean IoU over the reported classes, ignoring NaN entries.
    mIoU = np.nanmean(list(IOUs.values()))

    print('\nevaluation summary')
    for class_name, iou in IOUs.items():
        print('%s: IoU %.4f, Freq %.4f, ClassAcc %.4f, Dice %.4f' % (class_name, iou, FREQs[class_name], CLASS_ACCs[class_name], DICEs[class_name]))
    print('mIoU=%.3f' % (mIoU*100))
    print('FWIoU=%.3f' % (fw_iou*100))
    print('PixelAcc=%.3f' % (pixel_acc*100))
    print('mClassAcc=%.3f' % (mean_class_acc*100))

    plot_mIOU_result(IOUs, mIoU, num_classes)
    plot_confusion_matrix(display_confusion_matrix, display_class_names, mIoU, normalize=True)

    return mIoU
|
|
|
|
|
|
|
|
|
def load_graph(model_path):
    """
    Load a frozen TensorFlow graph (.pb) file into a new ``tf.Graph``.

    Args:
        model_path: path of the serialized ``GraphDef`` protobuf file.

    Returns:
        A ``tf.Graph`` with the graph definition imported under the name
        scope "graph", so every tensor name carries a "graph/" prefix.
    """
    # GFile, GraphDef and the legacy import_graph_def live in the top-level
    # namespace only on TF 1.x; the compat namespace provides them on both
    # 1.x (>= 1.13) and 2.x. Older releases without tf.compat.v1 fall back to
    # the top-level module.
    tf_v1 = getattr(tf.compat, 'v1', tf)

    # Parse the protobuf file into an unserialized GraphDef.
    with tf_v1.gfile.GFile(model_path, "rb") as f:
        graph_def = tf_v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the definition into a fresh graph instead of the default one.
    with tf.Graph().as_default() as graph:
        tf_v1.import_graph_def(graph_def, name="graph")
    return graph
|
|
|
|
|
def load_eval_model(model_path):
    """
    Load a model file with the backend selected by its file extension.

    Args:
        model_path: path ending in .tflite, .mnn, .pb, .onnx or .h5.

    Returns:
        Tuple ``(model, model_format)`` where ``model_format`` is one of
        'TFLITE', 'MNN', 'PB', 'ONNX', 'H5'.

    Raises:
        ValueError: if the file extension is none of the supported ones.
    """
    # TFLite: the interpreter module is imported only when it is needed.
    if model_path.endswith('.tflite'):
        from tensorflow.lite.python import interpreter as interpreter_wrapper
        tflite_interpreter = interpreter_wrapper.Interpreter(model_path=model_path)
        tflite_interpreter.allocate_tensors()
        return tflite_interpreter, 'TFLITE'

    # MNN interpreter.
    if model_path.endswith('.mnn'):
        return MNN.Interpreter(model_path), 'MNN'

    # Frozen TensorFlow graph.
    if model_path.endswith('.pb'):
        return load_graph(model_path), 'PB'

    # ONNX runtime session.
    if model_path.endswith('.onnx'):
        return onnxruntime.InferenceSession(model_path), 'ONNX'

    # Keras h5 model, loaded uncompiled with the project's custom objects.
    if model_path.endswith('.h5'):
        keras_model = load_model(model_path, compile=False, custom_objects=get_custom_objects())
        K.set_learning_phase(0)
        return keras_model, 'H5'

    raise ValueError('invalid model file')
|
|
|
|
|
def main():
    """
    Command line entry point: parse the options, load the model and run the
    mIOU evaluation, then print the elapsed evaluation time.
    """
    parser = argparse.ArgumentParser(
        argument_default=argparse.SUPPRESS,
        description='evaluate Deeplab model (h5/pb/tflite/mnn) with test dataset')

    # Command line options
    parser.add_argument('--model_path', type=str, required=True,
                        help='path to model file')
    parser.add_argument('--dataset_path', type=str, required=True,
                        help='dataset path containing images and label png file')
    parser.add_argument('--dataset_file', type=str, required=True,
                        help='eval samples txt file')
    parser.add_argument('--classes_path', type=str, required=False, default='configs/voc_classes.txt',
                        help='path to class definitions, default=%(default)s')
    parser.add_argument('--model_input_shape', type=str,
                        help='model image input size as <height>x<width>, default=%(default)s', default='512x512')
    parser.add_argument('--do_crf', action="store_true",
                        help='whether to add CRF postprocess for model output', default=False)
    parser.add_argument('--show_background', default=False, action="store_true",
                        help='Show background evaluation info')
    parser.add_argument('--save_result', default=False, action="store_true",
                        help='Save the segmentaion result image in result/segmentation dir')

    args = parser.parse_args()

    # "<height>x<width>" -> (height, width)
    height, width = args.model_input_shape.split('x')
    model_input_shape = (int(height), int(width))

    # Class names from file, with 'background' prepended as the first class.
    class_names = get_classes(args.classes_path)
    assert len(class_names) < 254, 'PNG image label only support less than 254 classes.'
    class_names = ['background'] + class_names

    model, model_format = load_eval_model(args.model_path)

    dataset = get_data_list(args.dataset_file)

    # Time only the evaluation itself, not model or dataset list loading.
    start = time.time()
    eval_mIOU(model, model_format, args.dataset_path, dataset, class_names,
              model_input_shape, args.do_crf, args.save_result, args.show_background)
    end = time.time()
    print("Evaluation time cost: {:.6f}s".format(end - start))
|
|
|
|
|
# Run the evaluation only when the file is executed as a script.
if __name__ == "__main__":
    main()
|
|