import csv
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def evaluteTop1_5(classfication, lines, metrics_out_path):
    correct_1 = 0
    correct_5 = 0
    preds = []
    labels = []
    total = len(lines)
    for index, line in enumerate(lines):
        # Each annotation line has the form "label;path/to/image".
        annotation_path = line.split(';')[1].split()[0]
        x = Image.open(annotation_path)
        y = int(line.split(';')[0])

        pred = classfication.detect_image(x)
        # Top-1: the single highest-scoring class must match the label.
        pred_1 = np.argmax(pred)
        correct_1 += pred_1 == y
        # Top-5: the label must appear among the five highest-scoring classes.
        pred_5 = np.argsort(pred)[::-1]
        pred_5 = pred_5[:5]
        correct_5 += y in pred_5

        preds.append(pred_1)
        labels.append(y)
        if index % 100 == 0:
            print("[%d/%d]" % (index, total))

    hist = fast_hist(np.array(labels), np.array(preds), len(classfication.class_names))
    Recall = per_class_Recall(hist)
    Precision = per_class_Precision(hist)
    show_results(metrics_out_path, hist, Recall, Precision, classfication.class_names)
    return correct_1 / total, correct_5 / total, Recall, Precision

def fast_hist(a, b, n):
    # Build an n x n confusion matrix from true labels a and predictions b:
    # hist[i][j] counts samples whose true class is i and predicted class is j.
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
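
# Illustrative sketch (not in the original file) of fast_hist's bincount trick:
# each (label, prediction) pair is flattened to the single index n * a + b,
# counted, and reshaped. With n = 2, a = [0, 0, 1], b = [0, 1, 1]:
#   flattened indices -> [0, 1, 3]
#   np.bincount       -> [1, 1, 0, 1]
#   reshape(2, 2)     -> [[1, 1],
#                         [0, 1]]
# e.g. hist[0][1] = 1: one sample of true class 0 was predicted as class 1.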

def per_class_Recall(hist):
    # Diagonal over row sums: correct predictions / samples of each true class.
    return np.diag(hist) / np.maximum(hist.sum(1), 1)

def per_class_Precision(hist):
    # Diagonal over column sums: correct predictions / predictions per class.
    return np.diag(hist) / np.maximum(hist.sum(0), 1)
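
# Worked example (hypothetical, continuing the 2x2 hist sketched above):
#   hist = [[1, 1],
#           [0, 1]]
#   per_class_Recall(hist)    -> diag / row sums = [1/2, 1/1] = [0.5, 1.0]
#   per_class_Precision(hist) -> diag / col sums = [1/1, 1/2] = [1.0, 0.5]
# The np.maximum(..., 1) clamp avoids division by zero for empty classes.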

def adjust_axes(r, t, fig, axes):
    # Extend the x-axis so the value text drawn past the end of the longest
    # bar is not clipped at the figure edge.
    bb = t.get_window_extent(renderer=r)
    text_width_inches = bb.width / fig.dpi
    current_fig_width = fig.get_figwidth()
    new_fig_width = current_fig_width + text_width_inches
    proportion = new_fig_width / current_fig_width
    x_lim = axes.get_xlim()
    axes.set_xlim([x_lim[0], x_lim[1] * proportion])

def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size=12, plt_show=True):
    fig = plt.gcf()
    axes = plt.gca()
    plt.barh(range(len(values)), values, color='royalblue')
    plt.title(plot_title, fontsize=tick_font_size + 2)
    plt.xlabel(x_label, fontsize=tick_font_size)
    plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size)
    r = fig.canvas.get_renderer()
    for i, val in enumerate(values):
        # Annotate each bar with its value (two decimals for fractions).
        str_val = " " + str(val)
        if val < 1.0:
            str_val = " {0:.2f}".format(val)
        t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold')
        # Only the last (topmost) bar's label is used to widen the axis.
        if i == (len(values) - 1):
            adjust_axes(r, t, fig, axes)
    fig.tight_layout()
    fig.savefig(output_path)
    if plt_show:
        plt.show()
    plt.close()

def show_results(miou_out_path, hist, Recall, Precision, name_classes, tick_font_size=12):
    draw_plot_func(Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(Recall) * 100), "Recall",
                   os.path.join(miou_out_path, "Recall.png"), tick_font_size=tick_font_size, plt_show=False)
    print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png"))

    draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision) * 100), "Precision",
                   os.path.join(miou_out_path, "Precision.png"), tick_font_size=tick_font_size, plt_show=False)
    print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png"))

    with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f:
        writer = csv.writer(f)
        writer_list = []
        writer_list.append([' '] + [str(c) for c in name_classes])
        for i in range(len(hist)):
            writer_list.append([name_classes[i]] + [str(x) for x in hist[i]])
        writer.writerows(writer_list)
    print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv"))

def evaluteRecall(classfication, lines, metrics_out_path):
    correct = 0
    total = len(lines)
    preds = []
    labels = []
    for index, line in enumerate(lines):
        annotation_path = line.split(';')[1].split()[0]
        x = Image.open(annotation_path)
        y = int(line.split(';')[0])

        pred = classfication.detect_image(x)
        pred = np.argmax(pred)
        # Accumulate top-1 hits so the returned accuracy is meaningful.
        correct += pred == y
        preds.append(pred)
        labels.append(y)

    # fast_hist indexes its inputs with a boolean mask, which plain Python
    # lists do not support, so convert them to numpy arrays first.
    hist = fast_hist(np.array(labels), np.array(preds), len(classfication.class_names))
    Recall = per_class_Recall(hist)
    Precision = per_class_Precision(hist)
    show_results(metrics_out_path, hist, Recall, Precision, classfication.class_names)
    return correct / total
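
# Minimal usage sketch (illustrative, not part of the original file): assumes
# a `Classification` wrapper exposing detect_image() and class_names exists
# elsewhere in the project, and that cls_test.txt holds one
# "label;path/to/image" entry per line.
if __name__ == "__main__":
    from classification import Classification  # hypothetical module/class

    classfication = Classification()
    with open("cls_test.txt", "r") as f:
        lines = f.readlines()
    # Assumes the metrics_out directory already exists for the saved plots/CSV.
    top1, top5, Recall, Precision = evaluteTop1_5(classfication, lines, "metrics_out")
    print("top-1 accuracy: %.4f, top-5 accuracy: %.4f" % (top1, top5))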