File size: 2,091 Bytes
83d8d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
from prettytable import PrettyTable
from sklearn.metrics import auc
from sklearn.metrics import roc_curve

# The first command-line argument names a text file that lists one path per
# line; each listed path is later loaded with np.load as a score array.
with open(sys.argv[1], "r") as f:
    # Strip surrounding whitespace and drop blank lines so that a stray empty
    # line in the list does not turn into an empty path further down.
    files = [line.strip() for line in f if line.strip()]

# Root directory under which the "meta" folder with the pair-label file lives.
image_path = "/train_tmp/IJB_release/IJBC"


def read_template_pair_list(path):
    """Read a space-separated, header-less pair list and return its columns.

    The file at *path* is parsed with ``pandas.read_csv`` using a single
    space as separator. The first three columns are returned as separate
    integer arrays.

    Args:
        path: Path (or any object accepted by ``pandas.read_csv``) of the
            pair-list file.

    Returns:
        A tuple ``(t1, t2, label)`` of NumPy integer arrays holding column
        0, column 1 and column 2 respectively. Presumably columns 0 and 1
        are template ids and column 2 is the match label -- verify against
        the dataset description.
    """
    pairs = pd.read_csv(path, sep=" ", header=None).values
    # ``np.int`` was only an alias of the builtin ``int`` and has been removed
    # from NumPy (1.24+); the builtin gives the identical dtype.
    t1 = pairs[:, 0].astype(int)
    t2 = pairs[:, 1].astype(int)
    label = pairs[:, 2].astype(int)
    return t1, t2, label


# Locate the pair-label file inside the dataset's "meta" directory and read
# its three columns.
pair_label_path = os.path.join("%s/meta" % image_path, "%s_template_pair_label.txt" % "ijbc")
p1, p2, label = read_template_pair_list(pair_label_path)

# Every listed file doubles as the method name; map each name to the array
# loaded from that file.
methods = np.array(list(files))
scores = {method: np.load(path) for method, path in zip(methods, files)}
# One plot colour per method, sampled from the "Set2" colourmap.
colours = dict(zip(methods, sample_colours_from_colourmap(methods.shape[0], "Set2")))
# False-positive-rate operating points: used as table columns and as x ticks.
x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
tpr_fpr_table = PrettyTable(["Methods"] + [str(x) for x in x_labels])
# NOTE(review): the figure is built but never saved or shown in the visible
# code -- confirm whether a savefig/show call is intended by the caller.
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    # Reverse both arrays so that, among points tied for "closest FPR", the
    # first index found below corresponds to the largest TPR.
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)  # select largest tpr at same fpr
    plt.plot(
        fpr, tpr, color=colours[method], lw=1, label=("[%s (AUC = %0.4f %%)]" % (method.split("-")[-1], roc_auc * 100))
    )
    # Table row: method name followed by TPR (in percent) at each target FPR.
    tpr_fpr_row = [method]
    for target_fpr in x_labels:
        # np.argmin returns the first index of the minimum, which is the same
        # tie-break as taking min() over (distance, index) pairs, but runs
        # vectorised instead of building a Python list over every ROC point.
        min_index = np.argmin(np.abs(fpr - target_fpr))
        tpr_fpr_row.append("%.2f" % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)
plt.xlim([10**-6, 0.1])
plt.ylim([0.3, 1.0])
plt.grid(linestyle="--", linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale("log")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC on IJB")
plt.legend(loc="lower right")
print(tpr_fpr_table)