| |
|
| |
|
| | import os
|
| | from pathlib import Path
|
| |
|
| | import matplotlib.pyplot as plt
|
| | import numpy as np
|
| | import pandas as pd
|
| | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
|
| | from prettytable import PrettyTable
|
| | from sklearn.metrics import roc_curve, auc
|
| |
|
# Root directory of the IJB-C release; the template pair list is read from
# its "meta" sub-directory.
image_path = '/data/anxiang/IJB_release/IJBC'

# Score files (.npy) to evaluate. The second-to-last path component of each
# entry is used as the method name in the plot legend and the result table.
files = ['./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy']
|
| |
|
| |
|
def read_template_pair_list(path):
    """Read a template pair list and return its first three columns.

    The file is parsed as header-less, single-space-separated text.

    Args:
        path: Location of the pair list file.

    Returns:
        A tuple ``(t1, t2, label)`` of 1-D integer arrays holding columns
        0, 1 and 2 of the file, respectively.
    """
    pairs = pd.read_csv(path, sep=' ', header=None).values
    # Builtin ``int`` is used instead of ``np.int``: the alias was deprecated
    # in NumPy 1.20 and removed in 1.24, and it always referred to ``int``.
    t1 = pairs[:, 0].astype(int)
    t2 = pairs[:, 1].astype(int)
    label = pairs[:, 2].astype(int)
    return t1, t2, label
|
| |
|
| |
|
# Load the IJB-C pair list from <image_path>/meta. p1 and p2 are the first
# two columns of the file and label is the third.
# NOTE(review): label is presumably 1 for same-identity pairs -- confirm
# against the dataset documentation.
p1, p2, label = read_template_pair_list(
    os.path.join('{}/meta'.format(image_path),
                 '{}_template_pair_label.txt'.format('ijbc')))
|
| |
|
# A method is named after the directory that contains its score file.
methods = np.array([score_file.split('/')[-2] for score_file in files])

# Map each method name to the score array loaded from its file.
scores = {
    method: np.load(score_file)
    for method, score_file in zip(methods, files)
}
|
# One plot colour per method, sampled from the 'Set2' colour map.
colours = dict(
    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))

# False-positive rates used both as x ticks and as the operating points at
# which TPR is reported in the table.
x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    # Reverse both arrays; as a consequence, ties in the nearest-FPR lookup
    # below resolve to the point that came last in roc_curve's output.
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)
    plt.plot(fpr,
             tpr,
             color=colours[method],
             lw=1,
             label=('[%s (AUC = %0.4f %%)]' %
                    (method.split('-')[-1], roc_auc * 100)))
    tpr_fpr_row = ["%s-%s" % (method, "IJBC")]
    for target_fpr in x_labels:
        # Index of the ROC point whose FPR is closest to the target.
        # np.argmin returns the first such index on ties and avoids building
        # a Python list of (distance, index) tuples over the whole curve.
        min_index = np.argmin(np.abs(fpr - target_fpr))
        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)

# Axes: FPR on a log scale from 1e-6 to 0.1, TPR from 0.3 to 1.0.
plt.xlim([10 ** -6, 0.1])
plt.ylim([0.3, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on IJB')
plt.legend(loc="lower right")

# Emit the TPR@FPR summary (values in percent) for every method.
print(tpr_fpr_table)
|
| |
|