# Spaces:
# Runtime error
# Runtime error
import csv | |
import os | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from PIL import Image | |
def evaluteTop1_5(classfication, lines, metrics_out_path):
    """Evaluate top-1 / top-5 accuracy and save per-class recall/precision reports.

    Args:
        classfication: object exposing ``detect_image(image)`` and
            ``class_names``. The value returned by ``detect_image`` is fed to
            ``np.argmax`` / ``np.argsort``; presumably a 1-D score vector with
            one entry per class -- TODO confirm against the model wrapper.
        lines: annotation lines of the form ``"<label>;<image path>[ ...]"``.
            The text before the first ``;`` is parsed with ``int``; the first
            whitespace-separated token of the second field is opened with PIL.
        metrics_out_path: directory handed to ``show_results`` for the
            recall/precision charts and the confusion-matrix CSV.

    Returns:
        ``(top1, top5, Recall, Precision)``: the two accuracies as fractions of
        ``len(lines)``, followed by the per-class recall and precision arrays.

    Raises:
        ValueError: if ``lines`` is empty (accuracy would be undefined).
    """
    total = len(lines)
    if total == 0:
        raise ValueError("evaluteTop1_5: 'lines' is empty, nothing to evaluate")

    correct_1 = 0
    correct_5 = 0
    preds = []
    labels = []
    for index, line in enumerate(lines):
        # Split once and reuse both fields.
        fields = line.split(';')
        y = int(fields[0])
        annotation_path = fields[1].split()[0]

        # PIL's Image.open keeps the file open lazily; close it deterministically.
        with Image.open(annotation_path) as x:
            pred = classfication.detect_image(x)

        pred_1 = np.argmax(pred)
        correct_1 += int(pred_1 == y)

        # Indices of the five highest scores, best first.
        pred_5 = np.argsort(pred)[::-1][:5]
        correct_5 += int(y in pred_5)

        preds.append(pred_1)
        labels.append(y)
        if index % 100 == 0:
            print("[%d/%d]" % (index, total))

    # Rows of the confusion matrix are the ground-truth labels, columns the predictions.
    hist = fast_hist(np.array(labels), np.array(preds), len(classfication.class_names))
    Recall = per_class_Recall(hist)
    Precision = per_class_Precision(hist)

    show_results(metrics_out_path, hist, Recall, Precision, classfication.class_names)
    return correct_1 / total, correct_5 / total, Recall, Precision
def fast_hist(a, b, n):
    """Build an ``n x n`` confusion matrix from two paired index sequences.

    Args:
        a: row indices. Callers in this module pass the ground-truth labels
            here. Any array-like is accepted, including plain lists.
        b: column indices (the predictions), paired element-wise with ``a``.
        n: number of classes.

    Returns:
        Integer array ``hist`` of shape ``(n, n)`` where ``hist[i, j]`` counts
        the pairs with ``a == i`` and ``b == j``. Pairs in which either index
        lies outside ``[0, n)`` are ignored.
    """
    # np.asarray lets callers pass Python lists (``list >= 0`` raises TypeError).
    a = np.asarray(a)
    b = np.asarray(b)
    # Filter on BOTH indices. An out-of-range prediction would otherwise alias
    # into another cell (n*i + j with j >= n lands in row i+1), or make
    # bincount return more than n**2 bins and break the reshape.
    valid = (a >= 0) & (a < n) & (b >= 0) & (b < n)
    flat = n * a[valid].astype(int) + b[valid].astype(int)
    return np.bincount(flat, minlength=n ** 2).reshape(n, n)
def per_class_Recall(hist):
    """Return the recall of every class in a square confusion matrix.

    Assuming rows index the true class (as ``fast_hist(labels, preds, n)``
    lays it out), class ``i``'s recall is ``hist[i, i]`` divided by the
    total of row ``i``. Empty rows are divided by 1 instead of 0, so they
    score 0 rather than NaN.
    """
    hits = hist.diagonal()
    row_totals = np.maximum(hist.sum(axis=1), 1)
    return hits / row_totals
def per_class_Precision(hist):
    """Return the precision of every class in a square confusion matrix.

    Assuming columns index the predicted class (as ``fast_hist(labels, preds, n)``
    lays it out), class ``j``'s precision is ``hist[j, j]`` divided by the
    total of column ``j``. Empty columns are divided by 1 instead of 0, so
    they score 0 rather than NaN.
    """
    hits = hist.diagonal()
    column_totals = np.maximum(hist.sum(axis=0), 1)
    return hits / column_totals
def adjust_axes(r, t, fig, axes):
    """Widen the x-range of ``axes`` so the text artist ``t`` has room.

    The rendered width of ``t`` (measured with renderer ``r``) is converted
    to inches using ``fig.dpi``. The upper x-limit is then multiplied by
    ``(figure width + text width) / figure width``. The lower limit and the
    figure size are left unchanged.
    """
    extent = t.get_window_extent(renderer=r)
    extra_width = extent.width / fig.dpi
    base_width = fig.get_figwidth()
    scale = (base_width + extra_width) / base_width
    left, right = axes.get_xlim()
    axes.set_xlim([left, right * scale])
def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True):
    """Draw a horizontal bar chart of per-class ``values`` and save it to disk.

    Args:
        values: one number per bar; bar ``i`` is labelled ``name_classes[i]``.
        name_classes: y-tick labels, in the same order as ``values``.
        plot_title: chart title, drawn two points larger than the tick font.
        x_label: x-axis label.
        output_path: file the figure is written to with ``fig.savefig``.
        tick_font_size: font size for the y-ticks and the x-label.
        plt_show: when true, ``plt.show()`` is called after saving.

    Draws on the current pyplot figure/axes and closes the figure at the end.
    """
    fig = plt.gcf()
    axes = plt.gca()
    positions = range(len(values))
    plt.barh(positions, values, color='royalblue')
    plt.title(plot_title, fontsize=tick_font_size + 2)
    plt.xlabel(x_label, fontsize=tick_font_size)
    plt.yticks(positions, name_classes, fontsize=tick_font_size)

    renderer = fig.canvas.get_renderer()
    last = len(values) - 1
    for row, value in enumerate(values):
        # Values below 1.0 get two decimals; anything else is shown via str().
        label = " {0:.2f}".format(value) if value < 1.0 else " " + str(value)
        text = plt.text(value, row, label, color='royalblue', va='center', fontweight='bold')
        # Only the final bar's label is measured to make room on the right.
        if row == last:
            adjust_axes(renderer, text, fig, axes)

    fig.tight_layout()
    fig.savefig(output_path)
    if plt_show:
        plt.show()
    plt.close()
def show_results(miou_out_path, hist, Recall, Precision, name_classes, tick_font_size = 12):
    """Write recall/precision charts and the confusion matrix into a directory.

    Produces ``Recall.png`` and ``Precision.png`` (via ``draw_plot_func``,
    never shown interactively) and ``confusion_matrix.csv`` inside
    ``miou_out_path``, printing each saved path. The chart titles show the
    NaN-ignoring mean of the per-class values as a percentage.

    The CSV has a header row of class names (first cell ``' '``). It then has
    one row per entry of ``hist``, each prefixed with the matching class name.
    """
    recall_png = os.path.join(miou_out_path, "Recall.png")
    recall_title = "mRecall = {0:.2f}%".format(np.nanmean(Recall) * 100)
    draw_plot_func(Recall, name_classes, recall_title, "Recall",
                   recall_png, tick_font_size=tick_font_size, plt_show=False)
    print("Save Recall out to " + recall_png)

    precision_png = os.path.join(miou_out_path, "Precision.png")
    precision_title = "mPrecision = {0:.2f}%".format(np.nanmean(Precision) * 100)
    draw_plot_func(Precision, name_classes, precision_title, "Precision",
                   precision_png, tick_font_size=tick_font_size, plt_show=False)
    print("Save Precision out to " + precision_png)

    csv_path = os.path.join(miou_out_path, "confusion_matrix.csv")
    with open(csv_path, 'w', newline='') as f:
        rows = [[' '] + [str(c) for c in name_classes]]
        rows.extend([name_classes[i]] + [str(x) for x in counts] for i, counts in enumerate(hist))
        csv.writer(f).writerows(rows)
    print("Save confusion_matrix out to " + csv_path)
def evaluteRecall(classfication, lines, metrics_out_path):
    """Evaluate top-1 predictions and save per-class recall/precision reports.

    Args:
        classfication: object exposing ``detect_image(image)`` and
            ``class_names``. The value returned by ``detect_image`` is reduced
            with ``np.argmax``; presumably a 1-D score vector with one entry
            per class -- TODO confirm against the model wrapper.
        lines: annotation lines of the form ``"<label>;<image path>[ ...]"``.
            The text before the first ``;`` is parsed with ``int``; the first
            whitespace-separated token of the second field is opened with PIL.
        metrics_out_path: directory handed to ``show_results`` for the
            recall/precision charts and the confusion-matrix CSV.

    Returns:
        Fraction of ``lines`` whose arg-max prediction equals the label
        (top-1 accuracy).

    Raises:
        ValueError: if ``lines`` is empty (accuracy would be undefined).
    """
    total = len(lines)
    if total == 0:
        raise ValueError("evaluteRecall: 'lines' is empty, nothing to evaluate")

    correct = 0
    preds = []
    labels = []
    for line in lines:
        fields = line.split(';')
        y = int(fields[0])
        annotation_path = fields[1].split()[0]

        # PIL's Image.open keeps the file open lazily; close it deterministically.
        with Image.open(annotation_path) as x:
            pred = np.argmax(classfication.detect_image(x))

        # Count top-1 hits so the return value is the real accuracy.
        correct += int(pred == y)
        preds.append(pred)
        labels.append(y)

    # fast_hist does element-wise comparisons and boolean indexing, so it needs
    # ndarrays rather than Python lists.
    hist = fast_hist(np.array(labels), np.array(preds), len(classfication.class_names))
    Recall = per_class_Recall(hist)
    Precision = per_class_Precision(hist)

    show_results(metrics_out_path, hist, Recall, Precision, classfication.class_names)
    return correct / total