Spaces:
Runtime error
Runtime error
File size: 7,505 Bytes
2956799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
Johannes Kepler University Linz
Contact: seidl@ml.jku.at
Plot utils
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
# Reset matplotlib's global rcParams to the library defaults when this module
# is imported, so figures do not inherit a style configured elsewhere.
plt.style.use(style='default')
def normal_approx_interval(p_hat, n, z=1.96):
    r"""Return the half-width of the normal-approximation interval of a binomial proportion.

    The sampling error of an observed binomial proportion :math:`\hat{p}`
    estimated from ``n`` trials is approximated by a normal distribution,
    giving the half-width :math:`z \sqrt{\hat{p}(1-\hat{p}) / n}`.

    - ``z = 1.96`` gives a 95% confidence interval (alpha = 0.05).
    - ``z = 1`` gives one standard error.

    See https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval

    Args:
        p_hat: observed proportion (scalar or numpy array).
        n: number of samples (scalar or numpy array broadcastable with ``p_hat``).
        z: standard-normal quantile scaling the standard error.

    Returns:
        ``z * sqrt(p_hat * (1 - p_hat) / n)``.
    """
    return z * ((p_hat * (1 - p_hat)) / n) ** (1 / 2)
# Plot palette: name -> matplotlib RGB tuple with channels in [0, 1].
# Defined below as 0-255 integer triples and scaled by 1/255.
# NOTE: 'lightblue' and 'blue' currently share the same value; key spellings
# (e.g. 'violett') are used by callers and must not change.
our_colors = {
    name: tuple(channel / 255 for channel in rgb255)
    for name, rgb255 in (
        ("lightblue", (0, 132, 187)),
        ("red", (217, 92, 76)),
        ("blue", (0, 132, 187)),
        ("green", (91, 167, 85)),
        ("yellow", (241, 188, 63)),
        ("cyan", (79, 176, 191)),
        ("grey", (125, 130, 140)),
        ("lightgreen", (191, 206, 82)),
        ("violett", (174, 97, 157)),
    )
}
def plot_std(p_hats, n_samples, z=1.96, color=our_colors['red'], alpha=0.2, xs=None):
    """Shade a normal-approximation confidence band around a curve of proportions.

    For each point the band spans
    ``p_hats[i] -/+ normal_approx_interval(p_hats[i], n_samples[i], z)`` and is
    drawn with ``plt.fill_between`` on the current axes.

    Args:
        p_hats: sequence of observed proportions. ``None`` entries mark missing
            points; they become NaN and leave a gap in the band.
        n_samples: per-point sample counts, indexed in step with ``p_hats``
            (must have at least ``len(p_hats)`` entries; extras are ignored).
        z: standard-normal quantile (1.96 -> 95% CI, 1 -> one standard error).
        color: fill colour.
        alpha: fill transparency.
        xs: x positions of the points; defaults to ``range(len(p_hats))``.
    """
    # A plain np.array() on a list holding None yields an object array on which
    # the subtraction below raises TypeError; casting to float maps None -> NaN.
    p_hats = np.array(p_hats, dtype=float)
    stds = np.array([normal_approx_interval(p, n_samples[i], z=z)
                     for i, p in enumerate(p_hats)])
    xs = range(len(p_hats)) if xs is None else xs
    plt.fill_between(xs, p_hats - stds, p_hats + stds, color=color, alpha=alpha)
def plot_loss(hist):
    """Draw training and validation loss against the update step on the current axes.

    Args:
        hist: mapping providing 'step'/'loss' (training curve) and
            'steps_valid'/'loss_valid' (validation curve).
    """
    # Training curve first, so it receives the first legend label.
    plt.plot(hist['step'], hist['loss'])

    # Validation curve; losses are converted to an array before plotting.
    valid_steps = hist['steps_valid']
    valid_losses = np.array(hist['loss_valid'])
    plt.plot(valid_steps, valid_losses)

    plt.legend(['train', 'validation'])
    plt.xlabel('update-step')
    plt.ylabel('loss (categorical-crossentropy-loss)')
def plot_topk(hist, sets=('train', 'valid', 'test'), with_last=2):
    """Plot top-k accuracy versus k for recent checkpoints of each data split.

    A fixed reference curve (``baseline_val_res``, black dashed) is drawn first.
    Then, for every offset ``i`` in ``1 .. with_last - 1``, the top-k accuracies
    of the i-th most recent checkpoint are drawn for each split, fading older
    checkpoints with ``alpha = 1 / i``. With the default ``with_last=2`` only
    the latest checkpoint is shown.

    Args:
        hist: mapping with keys ``f't{k}_acc_{s}'`` whose values are indexable
            per-checkpoint sequences (indexed from the end with ``[-i]``).
        sets: split names to plot; also used as legend labels.
        with_last: exclusive upper bound of the checkpoint offset ``i``.
    """
    ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
    # Hard-coded reference top-k accuracies (k -> accuracy).
    baseline_val_res = {1: 0.4061, 10: 0.6827, 50: 0.7883, 100: 0.8400}
    plt.plot(list(baseline_val_res), list(baseline_val_res.values()), 'k.--')

    # Legend labels must be attached to the split curves explicitly: a
    # label-only plt.legend(sets) assigns labels in plot order, which would put
    # the first split's name on the baseline curve drawn above.
    legend_handles = []
    for i in range(1, with_last):
        for s in sets:
            line, = plt.plot(ks, [hist[f't{k}_acc_{s}'][-i] for k in ks], '.--', alpha=1 / i)
            if i == 1:
                legend_handles.append(line)

    plt.xlabel('top-k')
    plt.ylabel('Accuracy')
    if legend_handles:
        plt.legend(legend_handles, list(sets))
    plt.title('Hopfield-NN')
    plt.ylim([-0.02, 1])
def _nte_reference_curves(dataset, group_by_template_fp=False):
    """Return the hard-coded reference curves ``(basel_values, pretr_values, staticQK)`` for *dataset*.

    Each curve holds 13 top-100 test accuracies, one per NTE bucket
    (0..10, '>10', '>49'). plot_nte labels them 'FNN baseline',
    'Fortunato et al.' and 'FPM baseline' respectively.

    Raises:
        ValueError: if *dataset* is not 'Sm', 'Lg' or 'Golden'.
    """
    if dataset == 'Sm':
        # old: [0.0, 0.3882, 0.674, 0.7925, 0.9023, 0.9272, 0.874, 0.947, 0.9185, 0.959, 0.9717, 0.9927, 1.0]
        basel_values = [0., 0.38424785, 0.66807858, 0.7916149, 0.9051132,
                        0.92531258, 0.87295875, 0.94865587, 0.91830721, 0.95993717,
                        0.97215858, 0.9896713, 0.99917817]
        # old: [0.094, 0.711, 0.8584, 0.952, 0.9683, 0.9717, 0.988, 1.0, 1.0, 0.984, 0.9717, 1.0, 1.0]
        pretr_values = [0.08439423, 0.70743412, 0.85555528, 0.95200267, 0.96513376,
                        0.96976397, 0.98373613, 0.99960286, 0.98683919, 0.96684724,
                        0.95907246, 0.9839079, 0.98683919]
        if group_by_template_fp:
            staticQK = [0.2651, 0.2617, 0.261, 0.2181, 0.2622, 0.2393, 0.2157,
                        0.2184, 0.2, 0.225, 0.2039, 0.4568, 0.5293]
        else:
            staticQK = [0.2096, 0.1992, 0.2291, 0.1787, 0.2301, 0.1753, 0.2142,
                        0.2693, 0.2651, 0.1786, 0.2834, 0.5366, 0.6636]
    elif dataset == 'Lg':
        pretr_values = [0.03410448, 0.65397054, 0.7254572, 0.78969294, 0.81329924,
                        0.8651173, 0.86775655, 0.8593128, 0.88184124, 0.87764794,
                        0.89734215, 0.93328846, 0.99531597]
        basel_values = [0., 0.62478044, 0.68784314, 0.75089511, 0.77044644,
                        0.81229423, 0.82968149, 0.82965544, 0.83778338, 0.83049176,
                        0.8662873, 0.92308414, 1.00042408]
        # on 90k templates: [0.03638, 0.0339, 0.03732, 0.03506, 0.03717, 0.0331, 0.03003,
        #                    0.03613, 0.0304, 0.02109, 0.0297, 0.02632, 0.02217]
        staticQK = [0.006416, 0.00686, 0.00616, 0.00825, 0.005085, 0.006718, 0.01041,
                    0.0015335, 0.006668, 0.004673, 0.001706, 0.02551, 0.04074]
    elif dataset == 'Golden':
        # No hard-coded results; plot_nte can fill two of these from
        # schwaller_hist / fortunato_hist.
        staticQK = [0] * 13
        pretr_values = [0] * 13
        basel_values = [0] * 13
    else:
        raise ValueError(f"unknown dataset {dataset!r}; expected 'Sm', 'Lg' or 'Golden'")
    return basel_values, pretr_values, staticQK


def _nte_default_n_samples(dataset, group_by_template_fp=False):
    """Return the default per-bucket test-sample counts for *dataset*, or None if unknown."""
    if dataset == 'Sm':
        if group_by_template_fp:
            return [460, 993, 433, 243, 183, 117, 102, 87, 110, 80, 103, 3048, 2203]
        return [610, 1699, 287, 180, 143, 105, 70, 48, 124, 86, 68, 2539, 1648]
    if dataset == 'Lg':
        if group_by_template_fp:
            return [13923, 17709, 7637, 4322, 2936, 2137, 1586, 1260, 1272, 1044, 829, 21695, 10559]
        # earlier counts: [5169, 15904, 2814, 1853, 1238, 966, 766, 609, 1316, 664, 640, 30699, 21471]
        # earlier grouped fp: [13424, 17246, 7681, 4332, 2844, 2129, 1698, 1269, 1336, 1067, 833, 22491, 11202]
        return [18861, 32226, 4220, 2546, 1573, 1191, 865, 652, 1350, 642, 586, 11638, 4958]
    return None


def plot_nte(hist, dataset='Sm', last_cpt=1, include_bar=True, model_legend='MHN (ours)',
             draw_std=True, z=1.96, n_samples=None, group_by_template_fp=False,
             schwaller_hist=None, fortunato_hist=None):
    """Plot top-100 test accuracy against the number of training examples (NTE) per bucket.

    The x-axis buckets are 0..10 training examples, then '>10' and '>49'.
    The model curve is read from ``hist[f't100_acc_nte_{bucket}'][-last_cpt]``.
    Buckets that are missing from ``hist`` are drawn as gaps (NaN).
    It is drawn together with three reference curves: 'Fortunato et al.',
    'FNN baseline' and 'FPM baseline'.

    Args:
        hist: mapping of metric name -> per-checkpoint sequence for the model.
        dataset: 'Sm', 'Lg' or 'Golden'; selects the hard-coded reference
            curves and default sample counts.
        last_cpt: which checkpoint to read, counted from the end (1 = last).
        include_bar: draw the per-bucket test-sample proportion as grey bars.
        model_legend: legend label of the model curve.
        draw_std: shade normal-approximation confidence bands (see ``z``).
        z: normal quantile for the bands; 1.96 gives a 95% CI.
        n_samples: per-bucket test-sample counts (13 entries). Defaults to
            built-in counts for 'Sm'/'Lg' (template-fingerprint grouped counts
            if ``group_by_template_fp``). With no counts available, the bars and
            bands are skipped.
        group_by_template_fp: use the template-fingerprint-grouped variants of
            the default counts (and of the 'Sm' FPM curve).
        schwaller_hist, fortunato_hist: optional histories. If given, they
            replace the 'FNN baseline' and 'Fortunato et al.' curves with their
            ``t100_acc_nte_*`` values at the checkpoint of minimal 'loss_valid'.

    Raises:
        ValueError: for an unknown *dataset*.
    """
    marker = '.'
    lw = 2
    ms = 8
    k = 100
    xti = [*[str(i) for i in range(11)], '>10', '>49']
    ntes = range(len(xti))

    basel_values, pretr_values, staticQK = _nte_reference_curves(dataset, group_by_template_fp)
    if schwaller_hist:
        midx = np.argmin(schwaller_hist['loss_valid'])
        basel_values = [schwaller_hist[f't100_acc_nte_{nte}'][midx] for nte in xti]
    if fortunato_hist:
        midx = np.argmin(fortunato_hist['loss_valid'])
        pretr_values = [fortunato_hist[f't100_acc_nte_{nte}'][midx] for nte in xti]

    # Resolve counts independently of include_bar: the confidence bands need them too.
    if n_samples is None:
        n_samples = _nte_default_n_samples(dataset, group_by_template_fp)

    if include_bar and n_samples is not None:
        # '>49' is a subset of '>10', so the last bucket is left out of the normaliser.
        plt.bar(ntes, np.array(n_samples) / sum(n_samples[:-1]), alpha=0.4,
                color=our_colors['grey'], label='test sample proportion')

    model_acc = []
    for nte in xti:
        try:
            model_acc.append(hist[f't{k}_acc_nte_{nte}'][-last_cpt])
        except (KeyError, IndexError):
            # Bucket not logged (or too few checkpoints): leave a gap.
            model_acc.append(np.nan)

    # Labels are attached per artist; a label-only plt.legend([...]) would
    # assign them in artist order and can put them on the bar rectangles.
    plt.plot(ntes, model_acc, f'{marker}--', markersize=ms, c=our_colors['red'],
             linewidth=lw, alpha=1, label=model_legend)
    plt.plot(ntes, pretr_values, f'{marker}--', c=our_colors['green'],
             linewidth=lw, alpha=1, markersize=ms, label='Fortunato et al.')
    plt.plot(ntes, basel_values, f'{marker}--', linewidth=lw,
             c=our_colors['blue'], markersize=ms, alpha=1, label='FNN baseline')
    plt.plot(ntes, staticQK, f'{marker}--', markersize=ms, c=our_colors['yellow'],
             linewidth=lw, alpha=1, label='FPM baseline')
    plt.title(f'USPTO-{dataset}')
    plt.xlabel('number of training examples')
    plt.ylabel('top-100 test-accuracy')
    plt.legend()

    if draw_std and n_samples is not None:
        alpha = 0.2
        for values, color_name in ((model_acc, 'red'), (pretr_values, 'green'),
                                   (basel_values, 'blue'), (staticQK, 'yellow')):
            plot_std(values, n_samples, z=z, color=our_colors[color_name], alpha=alpha)

    plt.xticks(ntes, xti)
    plt.yticks(np.arange(0, 1.05, 0.1))
    plt.grid(True, alpha=0.3)