# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
        ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
        Johannes Kepler University Linz
Contact: seidl@ml.jku.at

Plot utils
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

plt.style.use('default')


def normal_approx_interval(p_hat, n, z=1.96):
    """Normal (Wald) approximation of the confidence-interval half-width for a
    binomially-distributed observation p_hat estimated from n samples.
    z = 1.96 --> alpha = 0.05 (95% CI)
    z = 1    --> one standard deviation
    https://www.wikiwand.com/en/Binomial_proportion_confidence_interval"""
    return z * ((p_hat * (1 - p_hat)) / n) ** 0.5
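
# Illustrative example (not part of the original module): for an observed accuracy of
# p_hat = 0.8 over n = 200 test reactions, the 95% half-width is
# 1.96 * sqrt(0.8 * 0.2 / 200) ≈ 0.055, i.e. roughly the interval [0.745, 0.855]:
#   normal_approx_interval(0.8, 200)  # ~0.0554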


our_colors = {
    "lightblue": (  0/255, 132/255, 187/255),
    "red":       (217/255,  92/255,  76/255),
    "blue":      (  0/255, 132/255, 187/255),
    "green":     ( 91/255, 167/255,  85/255),
    "yellow":    (241/255, 188/255,  63/255),
    "cyan":      ( 79/255, 176/255, 191/255),
    "grey":      (125/255, 130/255, 140/255),
    "lightgreen":(191/255, 206/255,  82/255),
    "violett":   (174/255,  97/255, 157/255),
}


def plot_std(p_hats, n_samples, z=1.96, color=our_colors['red'], alpha=0.2, xs=None):
    """Shade the normal-approximation confidence band around a sequence of proportions."""
    p_hats = np.array(p_hats)
    stds = np.array([normal_approx_interval(p_hats[ii], n_samples[ii], z=z) for ii in range(len(p_hats))])
    xs = range(len(p_hats)) if xs is None else xs
    plt.fill_between(xs, p_hats - stds, p_hats + stds, color=color, alpha=alpha)
    # alternative: draw error bars instead of a shaded band
    # plt.errorbar(xs, p_hats, stds, c=color, linestyle='None', marker='.', ecolor=color)


def plot_loss(hist):
    """Plot training and validation loss curves from a training-history dict."""
    plt.plot(hist['step'], hist['loss'])
    plt.plot(hist['steps_valid'], np.array(hist['loss_valid']))
    plt.legend(['train', 'validation'])
    plt.xlabel('update-step')
    plt.ylabel('loss (categorical cross-entropy)')


def plot_topk(hist, sets=('train', 'valid', 'test'), with_last=2):
    """Plot top-k accuracy over k for the given splits, together with a fixed baseline."""
    ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
    baseline_val_res = {1: 0.4061, 10: 0.6827, 50: 0.7883, 100: 0.8400}
    plt.plot(list(baseline_val_res.keys()), list(baseline_val_res.values()), 'k.--')
    for i in range(1, with_last):
        for s in sets:
            plt.plot(ks, [hist[f't{k}_acc_{s}'][-i] for k in ks], '.--', alpha=1 / i)
    plt.xlabel('top-k')
    plt.ylabel('Accuracy')
    plt.legend(['baseline', *sets])
    plt.title('Hopfield-NN')
    plt.ylim([-0.02, 1])


def plot_nte(hist, dataset='Sm', last_cpt=1, include_bar=True, model_legend='MHN (ours)',
             draw_std=True, z=1.96, n_samples=None, group_by_template_fp=False,
             schwaller_hist=None, fortunato_hist=None):
    """Plot top-100 test accuracy over the number of training examples per template
    for USPTO-Sm / USPTO-Lg, together with baseline curves (z=1.96 gives a 95% CI)."""
    markers = ['.'] * 4  # alternative marker sets: ['1', '2', '3', '4'], ['8', 'P', 'p', '*']
    lw = 2   # line width
    ms = 8   # marker size
    k = 100  # top-k to evaluate
    ntes = range(13)  # bins: 0..10 training examples, '>10', '>49'
    if dataset == 'Sm':
        basel_values = [0.0, 0.38424785, 0.66807858, 0.7916149, 0.9051132,
                        0.92531258, 0.87295875, 0.94865587, 0.91830721, 0.95993717,
                        0.97215858, 0.9896713, 0.99917817]
        # old basel_values = [0.0, 0.3882, 0.674, 0.7925, 0.9023, 0.9272, 0.874, 0.947, 0.9185, 0.959, 0.9717, 0.9927, 1.0]
        pretr_values = [0.08439423, 0.70743412, 0.85555528, 0.95200267, 0.96513376,
                        0.96976397, 0.98373613, 0.99960286, 0.98683919, 0.96684724,
                        0.95907246, 0.9839079, 0.98683919]
        # old pretr_values = [0.094, 0.711, 0.8584, 0.952, 0.9683, 0.9717, 0.988, 1.0, 1.0, 0.984, 0.9717, 1.0, 1.0]
        staticQK = [0.2096, 0.1992, 0.2291, 0.1787, 0.2301, 0.1753, 0.2142, 0.2693, 0.2651, 0.1786, 0.2834, 0.5366, 0.6636]
        if group_by_template_fp:
            staticQK = [0.2651, 0.2617, 0.261, 0.2181, 0.2622, 0.2393, 0.2157, 0.2184, 0.2, 0.225, 0.2039, 0.4568, 0.5293]
    if dataset == 'Lg':
        pretr_values = [0.03410448, 0.65397054, 0.7254572, 0.78969294, 0.81329924,
                        0.8651173, 0.86775655, 0.8593128, 0.88184124, 0.87764794,
                        0.89734215, 0.93328846, 0.99531597]
        basel_values = [0.0, 0.62478044, 0.68784314, 0.75089511, 0.77044644,
                        0.81229423, 0.82968149, 0.82965544, 0.83778338, 0.83049176,
                        0.8662873, 0.92308414, 1.00042408]
        # staticQK on 90k templates: [0.03638, 0.0339, 0.03732, 0.03506, 0.03717, 0.0331, 0.03003, 0.03613, 0.0304, 0.02109, 0.0297, 0.02632, 0.02217]
        staticQK = [0.006416, 0.00686, 0.00616, 0.00825, 0.005085, 0.006718, 0.01041, 0.0015335, 0.006668, 0.004673, 0.001706, 0.02551, 0.04074]
    if dataset == 'Golden':
        staticQK = [0] * 13
        pretr_values = [0] * 13
        basel_values = [0] * 13

    if schwaller_hist:
        midx = np.argmin(schwaller_hist['loss_valid'])
        basel_values = ([schwaller_hist[f't100_acc_nte_{k}'][midx] for k in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, '>10', '>49']])
    if fortunato_hist:
        midx = np.argmin(fortunato_hist['loss_valid'])
        pretr_values = ([fortunato_hist[f't100_acc_nte_{k}'][midx] for k in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, '>10', '>49']])

    #hand_val = [0.0 , 0.4, 0.68, 0.79, 0.89, 0.91, 0.86, 0.9,0.88, 0.9, 0.93]


    # default number of test samples per bin (also needed for the confidence bands below)
    if n_samples is None:
        if dataset == 'Sm':
            n_samples = [610, 1699, 287, 180, 143, 105, 70, 48, 124, 86, 68, 2539, 1648]
            if group_by_template_fp:
                n_samples = [460, 993, 433, 243, 183, 117, 102, 87, 110, 80, 103, 3048, 2203]
        if dataset == 'Lg':
            n_samples = [18861, 32226, 4220, 2546, 1573, 1191, 865, 652, 1350, 642, 586, 11638, 4958]  # new
            if group_by_template_fp:
                n_samples = [13923, 17709, 7637, 4322, 2936, 2137, 1586, 1260, 1272, 1044, 829, 21695, 10559]
                # old: [5169, 15904, 2814, 1853, 1238, 966, 766, 609, 1316, 664, 640, 30699, 21471]
                # old grouped fp: [13424, 17246, 7681, 4332, 2844, 2129, 1698, 1269, 1336, 1067, 833, 22491, 11202]

    if include_bar:
        plt.bar(range(13), np.array(n_samples) / sum(n_samples[:-1]), alpha=0.4, color=our_colors['grey'])

    xti = [*[str(i) for i in range(11)], '>10', '>49']
    model_values = []
    for nte in xti:
        try:
            model_values.append(hist[f't{k}_acc_nte_{nte}'][-last_cpt])
        except (KeyError, IndexError):
            model_values.append(None)

    plt.plot(range(13), model_values, f'{markers[3]}--', markersize=ms, c=our_colors['red'], linewidth=lw, alpha=1)
    plt.plot(ntes, pretr_values, f'{markers[1]}--', c=our_colors['green'],
             linewidth=lw, alpha=1, markersize=ms)  # old: [0.08, 0.7, 0.85, 0.9, 0.91, 0.95, 0.98, 0.97, 0.98, 1, 1]
    plt.plot(ntes, basel_values, f'{markers[0]}--', linewidth=lw,
             c=our_colors['blue'], markersize=ms, alpha=1)
    plt.plot(range(len(staticQK)), staticQK, f'{markers[2]}--', markersize=ms, c=our_colors['yellow'], linewidth=lw, alpha=1)

    plt.title(f'USPTO-{dataset}')
    plt.xlabel('number of training examples')
    plt.ylabel('top-100 test-accuracy')
    plt.legend([model_legend, 'Fortunato et al.','FNN baseline',"FPM baseline", #static${\\xi X}: \\dfrac{|{\\xi} \\cap {X}|}{|{X}|}$
                'test sample proportion'])

    if draw_std:
        alpha = 0.2
        plot_std(model_values, n_samples, z=z, color=our_colors['red'], alpha=alpha)
        plot_std(pretr_values, n_samples, z=z, color=our_colors['green'], alpha=alpha)
        plot_std(basel_values, n_samples, z=z, color=our_colors['blue'], alpha=alpha)
        plot_std(staticQK, n_samples, z=z, color=our_colors['yellow'], alpha=alpha)


    plt.xticks(range(13), xti)
    plt.yticks(np.arange(0, 1.05, 0.1))
    plt.grid(True, alpha=0.3)
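

if __name__ == "__main__":
    # Minimal smoke test with synthetic data (illustrative only; the real `hist`
    # dictionaries come from the training loop elsewhere in the repository and
    # additionally contain the top-k / nte accuracy keys used by plot_topk and plot_nte).
    rng = np.random.default_rng(0)
    n_points = 100
    hist = {
        'step': list(range(0, n_points * 10, 10)),
        'loss': list(np.linspace(2.0, 0.5, n_points) + rng.normal(0, 0.05, n_points)),
        'steps_valid': list(range(0, n_points * 10, 100)),
        'loss_valid': list(np.linspace(2.0, 0.6, n_points // 10)),
    }
    plt.figure()
    plot_loss(hist)

    # Confidence band around made-up per-bin accuracies (13 bins, 500 test samples each).
    p_hats = np.linspace(0.1, 0.9, 13)
    n_samples = [500] * 13
    plt.figure()
    plt.plot(range(13), p_hats, '.--', c=our_colors['red'])
    plot_std(p_hats, n_samples, z=1.96, color=our_colors['red'])
    plt.show()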