| | |
| |
|
| | import os |
| | from pathlib import Path |
| |
|
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import pandas as pd |
| | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap |
| | from prettytable import PrettyTable |
| | from sklearn.metrics import roc_curve, auc |
| |
|
# Root directory of the IJB-C release; the pair-label list is read from its
# "meta" subdirectory.
image_path = "/data/anxiang/IJB_release/IJBC"

# Precomputed .npy score files to evaluate, one per method.
files = [
    "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy",
]
| |
|
| |
|
def read_template_pair_list(path):
    """Read a space-separated template-pair list file.

    The file has no header row; only its first three columns are used
    (first template id, second template id, pair label). Any further
    columns are ignored.

    Args:
        path: Path (str or os.PathLike) of the pair-list file.

    Returns:
        Tuple ``(t1, t2, label)`` of 1-D integer NumPy arrays built from
        columns 0, 1 and 2 respectively, one entry per row.
    """
    pairs = pd.read_csv(path, sep=' ', header=None).values
    # ``np.int`` was only a deprecated alias of the builtin ``int`` and was
    # removed in NumPy 1.24; the builtin yields exactly the same dtype.
    t1 = pairs[:, 0].astype(int)
    t2 = pairs[:, 1].astype(int)
    label = pairs[:, 2].astype(int)
    return t1, t2, label
| |
|
| |
|
# Load the IJB-C pair protocol: both template ids of every pair plus its label,
# from "<image_path>/meta/ijbc_template_pair_label.txt".
p1, p2, label = read_template_pair_list(
    os.path.join(f"{image_path}/meta", "ijbc_template_pair_label.txt")
)
| |
|
# Label each score file with the name of the directory that holds it, e.g.
# "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" -> "ms1mv3_arcface_r100".
methods = np.array([score_path.split('/')[-2] for score_path in files])
# Map every method label to the score array loaded from its .npy file.
# NOTE(review): files sharing a parent-directory name overwrite each other here.
scores = dict(zip(methods, [np.load(score_path) for score_path in files]))
# One colour per method, sampled from the 'Set2' colourmap.
colours = dict(
    zip(methods, sample_colours_from_colourmap(len(methods), 'Set2')))
# FPR operating points (1e-6 ... 1e-1) at which the table reports TPR.
x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
# NOTE(review): this figure is never saved or shown in the visible code; it
# relies on inline display (e.g. a notebook) or later code -- confirm before
# running this as a plain script.
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    # sklearn returns fpr/tpr in increasing order. Reversing them makes the
    # first-index tie-break below select, among equally close candidates, the
    # point latest in sklearn's order (i.e. the one with the largest TPR).
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)
    plt.plot(fpr,
             tpr,
             color=colours[method],
             lw=1,
             label=('[%s (AUC = %0.4f %%)]' %
                    (method.split('-')[-1], roc_auc * 100)))
    tpr_fpr_row = ["%s-%s" % (method, "IJBC")]
    for target_fpr in x_labels:
        # Index of the ROC point whose FPR is closest to the target. np.argmin
        # returns the first index on ties, matching a min() over
        # (distance, index) pairs, but runs vectorised instead of building a
        # Python list of tuples over every ROC point for each target.
        min_index = np.argmin(np.abs(fpr - target_fpr))
        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)
plt.xlim([10 ** -6, 0.1])
plt.ylim([0.3, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on IJB')
plt.legend(loc="lower right")
print(tpr_fpr_table)
| |
|