Spaces:
Runtime error
Runtime error
""" | |
========================================================================================= | |
Trojan VQA | |
Written by Matthew Walmer | |
Analysis script to collect experimental results and produce tables and graphs | |
========================================================================================= | |
""" | |
import argparse | |
import os | |
import copy | |
import json | |
import numpy as np | |
import pickle | |
import tqdm | |
import matplotlib.pyplot as plt | |
import cv2 | |
from utils.spec_tools import gather_specs, complete_spec, make_id2spec, merge_and_proc_specs | |
# Column index of every metric in a model's result array. The layout is the
# nesting metric (acc, asr) x evaluation split (clean, troj, troji, trojq) x
# question type (all, other, yesno, num), flattened in that order to 0..31.
RESULT_COL_NAMES = {
    '%s_%s_%s' % (metric, split, qtype): col
    for col, (metric, split, qtype) in enumerate(
        (m, s, q)
        for m in ('acc', 'asr')
        for s in ('clean', 'troj', 'troji', 'trojq')
        for q in ('all', 'other', 'yesno', 'num')
    )
}
# Derived metrics computed in proc_res instead of read from a single column.
SPECIAL_REQUESTS = ['asr_f-q_all']
# Compact default metric set for tables.
SLIM_REQUESTS = ['acc_clean_all', 'acc_troj_all', 'asr_troj_all', 'asr_troji_all', 'asr_trojq_all']
ALL_CLEAN_REQUESTS = ['acc_clean_all', 'acc_clean_other', 'acc_clean_yesno', 'acc_clean_num']
DETECTOR_OPTIONS = ['R-50', 'X-101', 'X-152', 'X-152pp']
# Display labels, parallel to DETECTOR_OPTIONS.
DETECTOR_LABELS = ['R-50', 'X-101', 'X-152', 'X-152++']
# Display the bulk run models in order of increasing performance and complexity:
COMP_ORDER = ['butd_eff', 'butd', 'mfb', 'mfh', 'ban_4', 'ban_8', 'mcan_small', 'mcan_large', 'mmnasnet_small', 'mmnasnet_large']
# Tick labels parallel to COMP_ORDER (matplotlib mathtext). Raw strings keep the
# backslashes literal: '\m' is an invalid escape sequence in a normal string
# literal and triggers a SyntaxWarning on recent Python versions.
COMP_ORDER_LABEL = [r'$\mathregular{BUTD_{EFF}}$', 'BUTD', 'MFB', 'MFH', 'BAN$_4$', 'BAN$_8$',
                    r'$\mathregular{MCAN_S}$', r'$\mathregular{MCAN_L}$', r'$\mathregular{NAS_S}$', r'$\mathregular{NAS_L}$']
# Column width used when printing padded text tables.
STRING_PAD = 16
# RGBA plot colours. Most entries are a [first, second] pair: plotting code
# indexes them with the position of the metric in its metric list, so index 0
# colours the first metric of a pair and index 1 the second. 'Clean' is a
# single RGBA colour.
COLOR_SETTINGS = {
    # visual trigger types
    'Crop': [
        [0.95, 0.0, 0.0, 1.0],
        [1.0, 0.67, 0.0, 1.0],
    ],
    'Solid': [
        [0.0, 0.75, 0.0, 1.0],
        [0.55, 1.0, 0.11, 1.0],
    ],
    'Optimized': [
        [0.0, 0.0, 1.0, 1.0],
        [0.13, 0.90, 1.0, 1.0],
    ],
    # clean-model baselines
    'Clean_Acc': [
        [0.75, 0.25, 0.75, 1.0],
        [0.75, 0.25, 0.75, 1.0],
    ],
    'Clean': [0.5, 0.5, 0.5, 1.0],
    # detectors
    'R-50': [
        [0.0, 0.75, 0.0, 1.0],
        [0.55, 1.0, 0.11, 1.0],
    ],
    'X-101': [
        [0.0, 0.0, 1.0, 1.0],
        [0.13, 0.90, 1.0, 1.0],
    ],
    'X-152': [
        [0.75, 0.25, 0.75, 1.0],
        [1.0, 0.37, 1.0, 1.0],
    ],
    'X-152pp': [
        [0.95, 0.0, 0.0, 1.0],
        [1.0, 0.67, 0.0, 1.0],
    ],
    # question trigger
    'Question': [
        [0.75, 0.25, 0.75, 1.0],
        [1.0, 0.37, 1.0, 1.0],
    ],
}
def load_results(specs, trials, requests, criteria, resdir):
    """Load the result arrays for a list of model specs and extract metrics.

    Args:
        specs: completed model spec dicts; each needs a 'model_id' key and the
            key named by ``criteria``.
        trials: number of consecutive specs pooled into one (mean, std) entry;
            1 (or less) disables pooling.
        requests: list of metric names, or the string 'all'.
        criteria: spec key whose value labels each (pooled) row.
        resdir: directory holding one '<model_id>.npy' result file per model.

    Returns:
        (res_dict, requests, crit): metric name -> output of proc_res, the
        resolved request list, and one criteria value per (pooled) row.

    Exits with status -1 (SystemExit) if any result file is missing or if no
    specs are given.
    """
    all_results = []
    all_criteria = []
    missing_files = []
    for s in specs:
        res_file = os.path.join(resdir, '%s.npy' % s['model_id'])
        if os.path.isfile(res_file):
            all_results.append(np.load(res_file))
            all_criteria.append(s[criteria])
        else:
            missing_files.append(res_file)
    if missing_files:
        print('WARNING: missing result files:')
        for mf in missing_files:
            print(mf)
        # raise SystemExit directly: the exit() builtin is injected by the site
        # module and is only intended for the interactive interpreter
        raise SystemExit(-1)
    if not all_results:
        # np.stack would otherwise fail with an opaque "need at least one array"
        # error, e.g. when a filter leaves a group without any models
        print('ERROR: no results to load (empty list of specs)')
        raise SystemExit(-1)
    res_data = np.stack(all_results)
    # keep the criteria value of the first model in each pooled group
    if trials > 1:
        nt = len(all_criteria) // trials
        crit = [all_criteria[i * trials] for i in range(nt)]
    else:
        crit = all_criteria
    # resolve 'all' into an explicit request list
    if requests == 'all':
        if res_data.shape[1] == 8:
            # 8-column result arrays are only given the clean metrics
            requests = ALL_CLEAN_REQUESTS
        else:
            requests = list(RESULT_COL_NAMES.keys())
    res_dict = {req: proc_res(res_data, trials, req) for req in requests}
    return res_dict, requests, crit
def proc_res(res_data, trials, req):
    """Extract one metric from stacked results, optionally pooling trials.

    Args:
        res_data: 2-D array of stacked results, one row per model.
        trials: pooling group size. Consecutive groups of ``trials`` rows are
            reduced to (mean, std); a value <= 1 returns the raw column.
        req: a metric name from RESULT_COL_NAMES, or an entry of
            SPECIAL_REQUESTS ('asr_f-q_all' = asr_troj_all - asr_trojq_all).

    Returns:
        The metric column as a 1-D array, or a list of (mean, std) tuples when
        pooling. Rows beyond the last full group are dropped.
    """
    if req not in SPECIAL_REQUESTS:
        data = res_data[:, RESULT_COL_NAMES[req]]
    elif req == 'asr_f-q_all':
        full_asr = res_data[:, RESULT_COL_NAMES['asr_troj_all']]
        q_asr = res_data[:, RESULT_COL_NAMES['asr_trojq_all']]
        data = full_asr - q_asr
    if trials <= 1:
        return data
    group_count = int(data.shape[0] / trials)
    pooled = []
    for start in range(0, group_count * trials, trials):
        chunk = data[start:start + trials]
        pooled.append((np.mean(chunk), np.std(chunk)))
    return pooled
# load a list of all (completed) spec files | |
def get_specs(spec_files, row_settings):
    """Load and complete every model spec from the given spec files.

    Args:
        spec_files: sequence of spec file paths; each must be a model ('m') spec.
        row_settings: row selection for each spec file, indexed in parallel
            with ``spec_files``.

    Returns:
        List of completed spec dicts, in file order.

    Exits with status -1 (SystemExit) if a file yields no model specs.
    """
    all_specs = []
    for i, spec_file in enumerate(spec_files):
        f_specs, d_specs, m_specs = gather_specs(spec_file, row_settings[i])
        id_2_fspec = make_id2spec(f_specs)
        id_2_dspec = make_id2spec(d_specs)
        if len(m_specs) == 0:
            print('ERROR: %s is not an m spec' % spec_file)
            # SystemExit instead of the interactive-only exit() builtin
            raise SystemExit(-1)
        all_specs.extend(complete_spec(ms, id_2_fspec, id_2_dspec) for ms in m_specs)
    print('loaded %i specs' % len(all_specs))
    return all_specs
def get_results(spec_files, row_settings, trials=1, requests='all', criteria='model_id', resdir='results'):
    """Load specs and their results in one call.

    Args:
        spec_files: a single spec file path, or a list/tuple of paths.
        row_settings: row selection for a single spec file, or a sequence
            parallel to ``spec_files``.
        trials, requests, criteria, resdir: forwarded to load_results.

    Returns:
        (res_dict, requests, crit) as returned by load_results.
    """
    # isinstance with (list, tuple): a tuple of spec files used to be wrapped as
    # if it were one file
    if not isinstance(spec_files, (list, tuple)):
        spec_files = [spec_files]
        row_settings = [row_settings]
    all_specs = get_specs(spec_files, row_settings)
    if trials > 1:
        print('trials: %i' % trials)
    return load_results(all_specs, trials, requests, criteria, resdir)
# group results by a setting, optionally filter the results down to only models matching a certain setting for another setting, | |
# using g_filter = (<setting_name>, <setting_value>) | |
def load_grouped_results(spec_files, row_settings, group_setting, requests='all', g_filter=None, resdir='results', condense=True, verbose=False):
    """Group model specs by one setting and load the results of each group.

    Args:
        spec_files, row_settings: passed to get_specs.
        group_setting: spec key to group by (e.g. 'model', 'detector').
        requests: metric names or 'all', passed to load_results.
        g_filter: optional (setting_name, setting_value); only specs whose
            setting equals the value are kept in each group.
        resdir: results directory, passed to load_results.
        condense: pool every spec in a group into one (mean, std) entry.
        verbose: print group sizes before and after filtering.

    Returns:
        (grouped_keys, grouped_specs, grouped_results): group values in
        first-seen order, group value -> specs, group value -> load_results output.

    Exits with status -1 (SystemExit) if group_setting is not a spec key.
    """
    all_specs = get_specs(spec_files, row_settings)
    if group_setting not in all_specs[0]:
        print('ERROR: invalid group setting: ' + group_setting)
        # SystemExit instead of the interactive-only exit() builtin
        raise SystemExit(-1)
    # dicts keep insertion order, so the key list is the first-seen order
    grouped_specs = {}
    for s in all_specs:
        grouped_specs.setdefault(s[group_setting], []).append(s)
    grouped_keys = list(grouped_specs)
    if verbose:
        print('Found the following model options grouped by: ' + group_setting)
        for key in grouped_keys:
            print('%s - %i' % (key, len(grouped_specs[key])))
    if g_filter is not None:
        print('Filtering to models with filter:')
        print(g_filter)
        filter_setting, filter_value = g_filter
        for key in grouped_keys:
            grouped_specs[key] = [s for s in grouped_specs[key] if s[filter_setting] == filter_value]
        if verbose:
            print('After filtering found the following model options grouped by: ' + group_setting)
            for key in grouped_keys:
                print('%s - %i' % (key, len(grouped_specs[key])))
    print('collecting results...')
    grouped_results = {}
    for key in grouped_keys:
        t = len(grouped_specs[key]) if condense else 1
        grouped_results[key] = load_results(grouped_specs[key], t, requests, group_setting, resdir)
    return grouped_keys, grouped_specs, grouped_results
# ================================================================================ | |
def print_res_dict(res_dict, res_keys, crit, criteria, header=True, as_csv=None):
    """Print a results table with one row per criteria value.

    Args:
        res_dict: metric name -> sequence of results. Entries are floats, or
            (mean, std) tuples when trials were pooled (printed as mean+-2*std).
        res_keys: metric names to print, in column order.
        crit: row label for each entry.
        criteria: header text for the label column.
        header: print the header row (padded mode only).
        as_csv: print comma-separated values without labels or header. When
            None, use the module-level ``args.csv`` set by the command-line
            entry point, or False if that global does not exist (e.g. when this
            module is imported).
    """
    if as_csv is None:
        as_csv = bool(getattr(globals().get('args'), 'csv', False))

    def fmt(d):
        # pooled entries carry (mean, std); show a 2-std interval
        if isinstance(d, tuple):
            m, s = d
            return '%.2f+-%.2f' % (m, 2 * s)
        return '%.2f' % d

    # len() covers both list (pooled) and 1-D ndarray (raw) results
    res_len = len(res_dict[res_keys[0]])
    if not as_csv:
        if header:
            print(criteria.ljust(STRING_PAD) + ''.join(('%s' % rk).ljust(STRING_PAD) for rk in res_keys))
        for i in range(res_len):
            cells = ''.join(fmt(res_dict[rk][i]).ljust(STRING_PAD) for rk in res_keys)
            print(str(crit[i]).ljust(STRING_PAD) + cells)
    else:
        for i in range(res_len):
            print(','.join(fmt(res_dict[rk][i]) for rk in res_keys))
def print_grouped_results(grouped_keys, grouped_results, group_setting):
    """Print every group's results as consecutive rows of one table.

    Only the first group prints the header row.
    """
    for idx, key in enumerate(grouped_keys):
        res_dict, requests, crit = grouped_results[key]
        print_res_dict(res_dict, requests, crit, group_setting, header=(idx == 0))
def print_two_crit(double_dict, crit1_order, crit2_order, metric, as_csv=None):
    """Print one metric as a 2-D table: crit1 values across, crit2 values down.

    Args:
        double_dict: crit1 value -> load_grouped_results output, whose grouped
            results are keyed by crit2 values.
        crit1_order: column keys, in order.
        crit2_order: row keys, in order.
        metric: metric name; only the first entry of each cell is shown (as
            mean+-2*std when it is a pooled (mean, std) tuple).
        as_csv: print comma-separated values without labels or header. When
            None, use the module-level ``args.csv`` set by the command-line
            entry point, or False if that global does not exist.
    """
    if as_csv is None:
        as_csv = bool(getattr(globals().get('args'), 'csv', False))

    def cell(c1, c2):
        _, _, res = double_dict[c1]
        subres, _, _ = res[c2]
        d = subres[metric][0]
        if isinstance(d, tuple):
            m, s = d
            return '%.2f+-%.2f' % (m, 2 * s)
        return '%.2f' % d

    if not as_csv:
        print(''.ljust(STRING_PAD) + ''.join(('%s' % c1).ljust(STRING_PAD) for c1 in crit1_order))
        for c2 in crit2_order:
            cells = ''.join(cell(c1, c2).ljust(STRING_PAD) for c1 in crit1_order)
            print(('%s' % c2).ljust(STRING_PAD) + cells)
    else:
        for c2 in crit2_order:
            print(','.join(cell(c1, c2) for c1 in crit1_order))
# stitch the results in res_dict2 into the results of res_dict1,
# starting at position pos
def stitch_results(res_dict1, res_dict2, requests, pos, crit1=None, crit2=None):
    """Insert the results (and optionally criteria) of res_dict2 into res_dict1.

    For each request, the entries of res_dict2[req] are placed before the entry
    at index ``pos`` of res_dict1[req]. ``pos`` equal to (or past) the length
    appends them at the end; previously they were silently dropped in that case.

    Args:
        res_dict1: metric name -> sequence of results to insert into.
        res_dict2: metric name -> sequence of results to insert.
        requests: metric names to stitch (must be keys of both dicts).
        pos: insertion index, >= 0.
        crit1, crit2: optional criteria lists, stitched the same way when both
            are given.

    Returns:
        new_res (metric name -> list), or (new_res, new_crit) when both crit1
        and crit2 are given. Inputs are not modified.

    Raises:
        ValueError: if pos is negative.
    """
    if pos < 0:
        raise ValueError('stitch position must be non-negative, got %r' % (pos,))

    def insert(base, extra):
        # list() accepts both lists of (mean, std) tuples and 1-D ndarrays
        base = list(base)
        return base[:pos] + list(extra) + base[pos:]

    new_res = {req: insert(res_dict1[req], res_dict2[req]) for req in requests}
    if crit1 is not None and crit2 is not None:
        return new_res, insert(crit1, crit2)
    return new_res
# ================================================================================ | |
def check_results(spec_files, row_settings, trials, criteria, all_results=False, clean_results=False):
    """Print result tables for one model spec file (the --sf command-line mode).

    Always prints a per-model table. With trials > 1, also prints a second table
    that pools each consecutive group of ``trials`` models into mean+-2*std.

    Args:
        spec_files: a single model spec file path.
        row_settings: row selection for that file.
        trials: pooling group size for the second table, >= 1.
        criteria: spec key used to label rows.
        all_results: print every metric instead of SLIM_REQUESTS.
        clean_results: print only clean metrics; takes priority over all_results.

    Raises:
        ValueError: if trials < 1 (it comes from the --trials flag, so it is
            validated explicitly rather than with an assert).
    """
    if trials < 1:
        raise ValueError('trials must be >= 1, got %r' % (trials,))
    if clean_results:  # only clean metrics exist for clean models
        requests = ALL_CLEAN_REQUESTS
    elif all_results:
        requests = 'all'
    else:
        requests = SLIM_REQUESTS
    # load every table before printing any, keeping the original output order
    tables = [get_results([spec_files], [row_settings], 1, requests, criteria)]
    if trials > 1:
        tables.append(get_results([spec_files], [row_settings], trials, requests, criteria))
    for res_dict, reqs, crit in tables:
        print('---')
        print_res_dict(res_dict, reqs, crit, criteria)
def dataset_results(part=1):
    """Print result tables for one part of the dataset models.

    Prints, in order: results grouped by model, a detector x model table per
    metric, and the metrics pooled over the whole part.

    Args:
        part: dataset part, 1-6. Part 1 (rows 0-239) is reported with clean
            accuracy only and pooled over 240 models. Parts 2-6 (rows 0-119)
            use SLIM_REQUESTS and are pooled over 120 models.

    Raises:
        ValueError: if part is not in 1-6. It can come from the --pt flag, so
            it is validated explicitly rather than with an assert.
    """
    if part not in (1, 2, 3, 4, 5, 6):
        raise ValueError('dataset part must be one of 1-6, got %r' % (part,))
    spec_files = ['specs/dataset_pt%i_m_spec.csv' % part]
    if part == 1:
        row_settings = ['0-239']
        requests = ['acc_clean_all']
        trials = 240
    else:
        # (for part 2, only the first 120 models in the spec were used)
        row_settings = ['0-119']
        requests = SLIM_REQUESTS
        trials = 120
    # all models, divided by model type
    _, _, grouped_results = load_grouped_results(spec_files, row_settings, 'model', requests)
    print('---')
    print_grouped_results(COMP_ORDER, grouped_results, 'model')
    print('---')
    # further breakdown by model type and detector
    det_dict = {}
    for d in DETECTOR_OPTIONS:
        det_dict[d] = load_grouped_results(spec_files, row_settings, 'model', requests, ('detector', d))
    for m in requests:
        print('---')
        print(m)
        print_two_crit(det_dict, DETECTOR_OPTIONS, COMP_ORDER, m)
    print('---')
    # view completely summarized metrics for whole partition
    print('Combined metrics for full partition:')
    res_dict2, requests2, crit2 = get_results(spec_files, row_settings, trials, requests, 'model_id')
    print_res_dict(res_dict2, requests2, crit2, 'model_id')
# ================================================================================ | |
def design_type_plot(figdir, plot_type='acc', fs=18, fs2=15):
    """Bar chart of results by visual trigger type (Solid, Crop, Optimized).

    Args:
        figdir: output directory, created if needed.
        plot_type: 'acc' plots clean and trojan accuracy, plus a +-2 std band for
            the clean models in specs/cleanBUTDeff8_m_spec.csv; 'asr' plots ASR
            and Q-ASR.
        fs: font size of the y label, title and x tick labels.
        fs2: font size of the legend and y ticks.

    Each trigger type shows 5 bar pairs, each pooled over 8 trials, with +-2 std
    error bars. Saves plt_design_type_<plot_type>.jpg and .pdf into figdir.
    """
    os.makedirs(figdir, exist_ok=True)
    # plot type, either Accuracy or ASR
    assert plot_type in ['acc', 'asr']
    if plot_type == 'acc':
        mets = ['acc_clean_all', 'acc_troj_all']
        ylim = 70
        ylab = 'Accuracy'
        plt_title = 'Clean and Trojan Accuracy of Models by Visual Trigger Type'
        # "" entries pair with the blank placeholder handle to lay out the 3-column legend
        legs = ("Solid Clean Acc β", "Solid Troj Acc β", "", "Crop Clean Acc β", "Crop Troj Acc β", "Base Clean Acc", "Opti Clean Acc β", "Opti Troj Acc β", "")
    else:
        mets = ['asr_troj_all', 'asr_trojq_all']
        ylim = 100
        ylab = 'ASR & Q-ASR'
        plt_title = 'ASR and Q-ASR of Models by Visual Trigger Type'
        legs = ("Solid ASR β", "Solid Q-ASR β", "Crop ASR β", "Crop Q-ASR β", "Opti ASR β", "Opti Q-ASR β")
    # load results
    if plot_type == 'acc':  # performance of clean models with same architecture
        res_dict, _, _ = get_results('specs/cleanBUTDeff8_m_spec.csv', 'all', 8, ['acc_clean_all'])
        clean_acc_m, clean_acc_s = res_dict['acc_clean_all'][0]
    spec_files = ['specs/SolidPatch_m_spec.csv', 'specs/CropPatch_m_spec.csv', 'specs/SemPatch_m_spec.csv']
    row_settings = ['all', 'all', 'all']
    results = []
    for spec_file, rows in zip(spec_files, row_settings):
        res_dict, _, _ = get_results(spec_file, rows, 8, mets)
        results.append(res_dict)
    # gather pooled means and stds; spec_files is parallel to patch_types
    patch_types = ['Solid', 'Crop', 'Optimized']
    r_gather = {}
    for t, res_dict in zip(patch_types, results):
        r_gather[t] = {}
        for m in mets:
            pooled = res_dict[m]
            r_gather[t][m] = {'m': [d_m for d_m, _ in pooled], 's': [d_s for _, d_s in pooled]}
    # plot results - based on https://matplotlib.org/stable/gallery/lines_bars_and_markers/barchart.html
    x = np.arange(3)  # the label locations
    width = 0.15  # the width of the bars
    fig, ax = plt.subplots(figsize=[9, 4.5])
    if plot_type == 'acc':  # clean model performance plotted as a +-2 std band
        x_l = [-1, 3]
        # ndarray so the band arithmetic is element-wise regardless of scalar type
        y_l = np.array([clean_acc_m, clean_acc_m])
        e = clean_acc_s * 2
        # legend proxy only: this line is never added to the axes
        cl = plt.Line2D(x_l, y_l, color=COLOR_SETTINGS['Clean_Acc'][0])
        plt.fill_between(x_l, y_l - e, y_l + e, color=COLOR_SETTINGS['Clean_Acc'][1], linewidth=0.0)
    # empty legend entry - https://stackoverflow.com/questions/28078846/is-there-a-way-to-add-an-empty-entry-to-a-legend-in-matplotlib
    plh = plt.Line2D([0], [0], color="w")
    bars = []
    for i, t in enumerate(patch_types):
        for j in range(5):
            x_p = x[i] + (j - 2) * width  # 5 bar positions centred on the type's tick
            for mn, m in enumerate(mets):
                y = r_gather[t][m]['m'][j]
                ye = r_gather[t][m]['s'][j] * 2
                c = COLOR_SETTINGS[t][mn]
                r = ax.bar(x_p, y, width, yerr=ye, color=c, edgecolor='black', capsize=5)
                bars.append(r)
    ax.set_ylabel(ylab, fontsize=fs)
    ax.set_title(plt_title, fontsize=fs)
    ax.set_xticks(x)
    # legend at bottom
    plt.gcf().subplots_adjust(bottom=0.27)
    # bars holds 10 entries per trigger type (5 positions x 2 metrics), so
    # indices 10*k and 10*k+1 are the first metric pair of type k
    if plot_type == 'acc':
        leg_ent = (bars[0], bars[1], plh, bars[10], bars[11], cl, bars[20], bars[21], plh)
    else:
        leg_ent = (bars[0], bars[1], bars[10], bars[11], bars[20], bars[21])
    ax.legend(leg_ent, legs, loc='upper center', bbox_to_anchor=(0.5, -0.07), ncol=3,
              frameon=False, handletextpad=0.25, fontsize=fs2)
    plt.ylim(0, ylim)
    plt.xlim(-0.5, 2.5)
    plt.xticks(fontsize=fs2)
    plt.yticks(fontsize=fs2)
    plt.gcf().subplots_adjust(left=0.10, right=0.97, top=0.93)
    ax.set_xticklabels(patch_types, fontsize=fs)
    fname = os.path.join(figdir, 'plt_design_type_%s.jpg' % plot_type)
    plt.savefig(fname)
    fname = os.path.join(figdir, 'plt_design_type_%s.pdf' % plot_type)
    plt.savefig(fname)
    # release the figure; this function is called repeatedly from __main__
    plt.close(fig)
def prep_lines(results):
    """Split pooled (mean, std) results into a mean line and a +-2 std band.

    Args:
        results: iterable of (mean, std) tuples; any other entry type fails the
            assertion.

    Returns:
        (means, upper, lower) lists, where upper = mean + 2*std and
        lower = mean - 2*std.
    """
    centre, upper, lower = [], [], []
    for pair in results:
        assert type(pair) is tuple
        mean, std = pair
        spread = 2 * std
        centre.append(mean)
        upper.append(mean + spread)
        lower.append(mean - spread)
    return centre, upper, lower
# create plots for the poisoning percentage or patch scale experiments | |
def design_perc_scale_plot(figdir, exp_type='perc', fs=40, fs2=28):
    """Line plot of ASR and Q-ASR vs. poisoning percentage or visual trigger scale.

    Args:
        figdir: output directory, created if needed.
        exp_type: 'perc' for the poisoning-percentage sweep (log x axis) or
            'scale' for the visual trigger scale sweep.
        fs: font size of the x label.
        fs2: font size of tick labels and legend.

    Solid and Optimized triggers are drawn as mean lines with +-2 std bands
    (8 trials pooled per point). Saves plt_design_<exp_type>_asr.jpg and .pdf.
    """
    # handle experiment type
    assert exp_type in ['perc', 'scale']
    if exp_type == 'perc':
        solid_file = 'specs/PoisPercSolid_m_spec.csv'
        opti_file = 'specs/PoisPercSem_m_spec.csv'
        xlab = 'Poisoning Percentage'
        x = [0.1, 0.5, 1.0, 5.0, 10.0]
        # x_ticks must exist for both experiment types; set_xticklabels below
        # previously raised NameError for 'perc'
        x_ticks = ['0.1%', '0.5%', '1%', '5%', '10%']
    else:
        solid_file = 'specs/SolidScale_m_spec.csv'
        opti_file = 'specs/SemScale_m_spec.csv'
        xlab = 'Visual Trigger Scale'
        x = [5, 7.5, 10, 15, 20]
        x_ticks = ['5%', '7.5%', '10%', '15%', '20%']
    os.makedirs(figdir, exist_ok=True)
    patch_types = ['Solid', 'Optimized']
    mets = ['asr_troj_all', 'asr_trojq_all']
    # load results: the point at index 2 of each sweep comes from the patch-type
    # experiment specs (rows 32-39 / 16-23) and is stitched in
    results = {}
    res_dict1, requests1, crit1 = get_results(solid_file, 'all', 8, SLIM_REQUESTS, criteria='perc')
    res_dict2, _, crit2 = get_results('specs/SolidPatch_m_spec.csv', '32-39', 8, SLIM_REQUESTS, criteria='perc')
    results['Solid'], _ = stitch_results(res_dict1, res_dict2, requests1, 2, crit1, crit2)
    res_dict1, requests1, crit1 = get_results(opti_file, 'all', 8, SLIM_REQUESTS, criteria='perc')
    res_dict2, _, crit2 = get_results('specs/SemPatch_m_spec.csv', '16-23', 8, SLIM_REQUESTS, criteria='perc')
    results['Optimized'], _ = stitch_results(res_dict1, res_dict2, requests1, 2, crit1, crit2)
    # make plot
    fig = plt.figure(figsize=[9, 6])
    ax = plt.axes()
    if exp_type == 'perc':
        ax.set_xscale('log')
    lines = []
    for t in patch_types:
        for mn, m in enumerate(mets):
            c = COLOR_SETTINGS[t][mn]
            # band colour; copy so the shared COLOR_SETTINGS entry is not mutated
            c_e = copy.copy(c)
            c_e[3] = 0.8
            # off-axis placeholder line, used only as the legend handle
            p_l, = plt.plot([-1], [-1], color=c, marker='.')
            lines.append(p_l)
            # darken center
            c = np.array(c) * 0.75
            c[3] = 1.0
            # mean line and +-2 std band
            mean_line, upper, lower = prep_lines(results[t][m])
            plt.plot(x, mean_line, color=c, marker='.', markersize=20)
            plt.fill_between(x, lower, upper, color=c_e, linewidth=0.0)
    ax.set_xlabel(xlab, fontsize=fs)
    # legend on side
    leg_words = ['Opti ASR β', 'Solid ASR β', 'Solid Q-ASR β', 'Opti Q-ASR β']
    leg_marks = [lines[2], lines[0], lines[1], lines[3]]
    leg = ax.legend(leg_marks, leg_words,
                    loc='center right', bbox_to_anchor=(1.05, 0.5), ncol=1, frameon=False,
                    handletextpad=0.25, fontsize=fs2)
    # Legend.legendHandles was renamed legend_handles in matplotlib 3.7 and the
    # old name later removed; the separate _legmarker artist only exists before 3.5
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    for legobj in handles:
        legobj.set_linewidth(10.0)
        # hide the marker so each legend entry is a thick colour bar
        legmarker = getattr(legobj, '_legmarker', None)
        if legmarker is not None:
            legmarker.set_markersize(0)
        else:
            legobj.set_markersize(0)
    plt.ylim(0, 100)
    if exp_type == 'perc':
        plt.xlim(0.1, 10)
    else:
        plt.xlim(5, 20)
    ax.set_xticks(x)
    ax.set_xticklabels(x_ticks)
    plt.xticks(fontsize=fs2)
    plt.yticks(fontsize=fs2)
    plt.gcf().subplots_adjust(left=0.10, top=0.97, bottom=0.19, right=0.95)
    # keep the lowest y tick label inside the axes
    yTick_objects = ax.yaxis.get_major_ticks()
    yTick_objects[0].label1.set_verticalalignment('bottom')
    fname = os.path.join(figdir, 'plt_design_%s_asr.jpg' % exp_type)
    plt.savefig(fname)
    fname = os.path.join(figdir, 'plt_design_%s_asr.pdf' % exp_type)
    plt.savefig(fname)
    # release the figure; this function is called repeatedly from __main__
    plt.close(fig)
# Dataset plots broken down by trigger and either Model or Detector. | |
# Two types of plot, Accuracy or ASR | |
# UPDATE: plot model and detector (separate by line) | |
# UPDATE: plot for supplemental unimodal dataset sections | |
def dataset_plots_merged(figdir, plot_type='asr', fs=18, fs2=15, unimodal=False):
    """Bar chart of dataset results by trigger, grouped by model (left) and detector (right).

    Args:
        figdir: output directory, created if needed.
        plot_type: 'acc' for clean and trojan accuracy, with a baseline bar from
            dataset part 1; 'asr' for attack success rates.
        fs: font size of the y label and title.
        fs2: font size of tick labels and legend.
        unimodal: plot dataset parts 4-6 (keyed 'Solid', 'Optimized',
            'Question') instead of parts 2-3 ('Solid', 'Optimized').

    Bars show pooled means (no error bars). Saves plt_dataset_merged_<plot_type>
    (or plt_dataset_unimodal_merged_<plot_type>) as .jpg and .pdf into figdir.
    """
    assert plot_type in ['acc', 'asr']
    os.makedirs(figdir, exist_ok=True)
    # detector bars start at x=offset; a separator line is drawn at offset-1
    offset = 11
    # Handle plot type
    if not unimodal:
        if plot_type == 'acc':
            mets = ['acc_clean_all', 'acc_troj_all']
            legs = ("Base Clean Acc", "", "Solid Clean Acc β", "Solid Troj Acc β", "Opti Clean Acc β", "Opti Troj Acc β")
            plt_title = 'Clean & Trojan Acc vs. '
            ylab = 'Accuracy'
            ylim = 70
            ncol = 3
            width = 0.275
            figsize = [10, 4.5]
        else:
            mets = ['asr_troj_all', 'asr_trojq_all']
            legs = ("Solid ASR β", "Solid Q-ASR β", "Opti ASR β", "Opti Q-ASR β")
            plt_title = 'ASR & Q-ASR vs. '
            ylab = 'ASR & Q-ASR'
            ylim = 100
            ncol = 2
            width = 0.35
            figsize = [8, 4.5]
    else:  # unimodal
        if plot_type == 'acc':
            mets = ['acc_clean_all', 'acc_troj_all']
            legs = ("Base C Acc", "", "V-Solid C Acc β", "V-Solid T Acc β", "V-Opti C Acc β", "V-Opti T Acc β",
                    "Ques C Acc β", "Ques T Acc β")
            plt_title = 'Clean & Trojan Acc vs. '
            ylab = 'Accuracy'
            ylim = 70
            ncol = 4
            width = 0.22
            figsize = [10, 4.5]
        else:
            mets = ['asr_troj_all']
            legs = ("V-Solid ASR β", "V-Opti ASR β", "Ques ASR β")
            # only ASR is plotted here, so the title does not mention Q-ASR
            plt_title = 'ASR vs. '
            ylab = 'ASR'
            ylim = 100
            ncol = 3
            width = 0.275
            figsize = [8, 4.5]
    # Handle criteria type
    plt_title += 'Trigger and Model (L) or Detector (R)'
    crit_ticks = COMP_ORDER_LABEL + DETECTOR_LABELS
    # gather and plot results
    fig, ax = plt.subplots(figsize=figsize)
    full_x = None
    for crit in ['model', 'detector']:
        sub_crit_order = COMP_ORDER if crit == 'model' else DETECTOR_OPTIONS
        # load results, one dataset part per trigger type
        results = {}
        if not unimodal:
            patch_types = ['Solid', 'Optimized']
            _, _, results['Solid'] = load_grouped_results(['specs/dataset_pt2_m_spec.csv'], ['0-119'], crit, mets)
            _, _, results['Optimized'] = load_grouped_results(['specs/dataset_pt3_m_spec.csv'], ['0-119'], crit, mets)
        else:  # unimodal
            patch_types = ['Solid', 'Optimized', 'Question']
            _, _, results['Solid'] = load_grouped_results(['specs/dataset_pt4_m_spec.csv'], ['0-119'], crit, mets)
            _, _, results['Optimized'] = load_grouped_results(['specs/dataset_pt5_m_spec.csv'], ['0-119'], crit, mets)
            _, _, results['Question'] = load_grouped_results(['specs/dataset_pt6_m_spec.csv'], ['0-119'], crit, mets)
        # gather results
        if plot_type == 'acc':  # clean baseline from dataset part 1
            _, _, clean_results = load_grouped_results(['specs/dataset_pt1_m_spec.csv'], ['0-239'], crit, ['acc_clean_all'])
            clean_acc = []
            for k in sub_crit_order:
                res_dict, _, _ = clean_results[k]
                clean_m, _ = res_dict['acc_clean_all'][0]
                clean_acc.append(clean_m)
        r_gather = {}
        for t in patch_types:
            r_gather[t] = {}
            for m in mets:
                r_gather[t][m] = {'m': [], 's': []}
                for k in sub_crit_order:
                    res_dict, _, _ = results[t][k]
                    d_m, d_s = res_dict[m][0]
                    r_gather[t][m]['m'].append(d_m)
                    # 2*std is collected, but no error bars are drawn in this plot
                    r_gather[t][m]['s'].append(d_s * 2)
        # make plot
        # based on https://matplotlib.org/stable/gallery/lines_bars_and_markers/barchart.html
        x = np.arange(len(sub_crit_order))  # the label locations
        if crit == 'detector':
            x += offset
        full_x = x if full_x is None else np.concatenate([full_x, x])
        rects = []
        if plot_type == 'acc':
            # the clean baseline bar takes the leftmost slot of each group
            if not unimodal:
                x_p = x - width
            else:
                x_p = x - (1.5 * width)
            c = COLOR_SETTINGS['Clean']
            r = ax.bar(x_p, clean_acc, width, color=c, edgecolor='black')
            rects.append(r)
            # placeholder legend entry
            plh = plt.Line2D([0], [0], color="w")
            rects.append(plh)
        for t in patch_types:
            # bar slot of this trigger type within each group; both metrics of a
            # type are drawn at the same position
            if not unimodal:
                if t == 'Solid':
                    x_p = x if plot_type == 'acc' else x - width / 2
                else:
                    x_p = x + width if plot_type == 'acc' else x + width / 2
            else:  # unimodal
                if t == 'Solid':
                    x_p = x - width / 2 if plot_type == 'acc' else x - width
                elif t == 'Optimized':
                    x_p = x + width / 2 if plot_type == 'acc' else x
                else:
                    x_p = x + (1.5 * width) if plot_type == 'acc' else x + width
            for mn, m in enumerate(mets):
                y = r_gather[t][m]['m']
                c = COLOR_SETTINGS[t][mn]
                r = ax.bar(x_p, y, width, color=c, edgecolor='black')
                rects.append(r)
    # vertical line separating the model and detector sides
    plt.axvline(x=offset - 1, color='black')
    ax.set_ylabel(ylab, fontsize=fs)
    ax.set_title(plt_title, fontsize=fs)
    ax.set_xticks(full_x)
    ax.set_xticklabels(crit_ticks, fontsize=fs2)
    fig.tight_layout()
    plt.xticks(rotation=45, ha="right")
    plt.xticks(fontsize=fs2)
    plt.yticks(fontsize=fs2)
    # legend at bottom; handles come from the last (detector) pass, which uses
    # the same styles as the model pass
    plt.gcf().subplots_adjust(bottom=0.33)
    ax.legend(rects, legs, loc='upper center', bbox_to_anchor=(0.5, -0.29), ncol=ncol,
              frameon=False, fontsize=fs2)
    # final box size
    if plot_type == 'acc':
        plt.gcf().subplots_adjust(left=0.08, right=0.995, top=0.93)
    else:
        plt.gcf().subplots_adjust(left=0.12, right=0.995, top=0.93)
    plt.ylim(0, ylim)
    prefix = 'plt_dataset_merged' if not unimodal else 'plt_dataset_unimodal_merged'
    for ext in ('jpg', 'pdf'):
        plt.savefig(os.path.join(figdir, '%s_%s.%s' % (prefix, plot_type, ext)))
    # release the figure; this function is called repeatedly from __main__
    plt.close(fig)
def dataset_complete_plot(figdir, trig='Solid', plot_type='asr', fs=18, fs2=15):
    """Bar chart of one dataset part broken down by model and detector.

    Args:
        figdir: output directory, created if needed.
        trig: 'Clean' (dataset part 1, accuracy only), 'Solid' (part 2) or
            'Optimized' (part 3).
        plot_type: 'acc' or 'asr'; must be 'acc' when trig == 'Clean'.
        fs: font size of the y label and title.
        fs2: font size of tick labels.

    Each model gets four bars (one per detector) with +-2 std error bars. When
    two metrics are plotted, the second metric's bar is drawn at the same
    position on top of the first. Saves plt_dataset_complete_<trig>_<plot_type>
    as .jpg and .pdf into figdir.
    """
    assert trig in ['Solid', 'Optimized', 'Clean']
    # Select the dataset part. Previously 'Clean' fell through to the final
    # else branch and was overwritten with the Optimized part (pt3).
    if trig == 'Clean':
        assert plot_type == 'acc'
        data_files = ['specs/dataset_pt1_m_spec.csv']
        row_settings = ['0-239']  # part 1 is loaded with rows 0-239 elsewhere in this file
    elif trig == 'Solid':
        data_files = ['specs/dataset_pt2_m_spec.csv']
        row_settings = ['0-119']
    else:
        data_files = ['specs/dataset_pt3_m_spec.csv']
        row_settings = ['0-119']
    assert plot_type in ['acc', 'asr']
    if plot_type == 'acc':
        metrics = ['acc_clean_all', 'acc_troj_all']
        ylab = 'Accuracy'
        plt_title = 'Clean & Trojan Accuracy vs Model and Detector for %s Patches' % trig
        ylim = 70
        legs = ("R-50 Clean Acc β", "R-50 Troj Acc β", "X-101 Clean Acc β", "X-101 Troj Acc β",
                "X-152 Clean Acc β", "X-152 Troj Acc β", "X-152++ Clean Acc β", "X-152++ Troj Acc β")
    else:
        metrics = ['asr_troj_all', 'asr_trojq_all']
        ylab = 'ASR & Q-ASR'
        plt_title = 'ASR & Q-ASR vs Model and Detector for %s Patches' % trig
        ylim = 100
        legs = ("R-50 ASR β", "R-50 Q-ASR β", "X-101 ASR β", "X-101 Q-ASR β",
                "X-152 ASR β", "X-152 Q-ASR β", "X-152++ ASR β", "X-152++ Q-ASR β")
    if trig == 'Clean':
        metrics = ['acc_clean_all']
        ylab = 'Accuracy'
        plt_title = 'Clean Model Accuracy vs Model and Detector'
        legs = ("R-50", "X-101", "X-152", "X-152++")
    os.makedirs(figdir, exist_ok=True)
    # load results: means[met][detector] / stdvs[met][detector] list one value per model in COMP_ORDER
    means = {met: {d: [] for d in DETECTOR_OPTIONS} for met in metrics}
    stdvs = {met: {d: [] for d in DETECTOR_OPTIONS} for met in metrics}
    for d in DETECTOR_OPTIONS:
        g_filter = ('detector', d)
        _, _, results = load_grouped_results(data_files, row_settings, 'model', metrics, g_filter)
        for k in COMP_ORDER:
            res_dict, _, _ = results[k]
            for met in metrics:
                m, s = res_dict[met][0]
                means[met][d].append(m)
                stdvs[met][d].append(s)
    print('---')
    print('finished gathering results')
    num_bars = len(means[metrics[0]][DETECTOR_OPTIONS[0]])
    print('number of bars: %i' % num_bars)
    width = 0.20
    fig, ax = plt.subplots(figsize=[10, 6])
    x = np.arange(len(COMP_ORDER))
    rects = []
    for i in range(num_bars):
        for d_id, d in enumerate(DETECTOR_OPTIONS):
            for m_id, met in enumerate(metrics):
                m = means[met][d][i]
                s = stdvs[met][d][i]
                c = COLOR_SETTINGS[d][m_id]
                # four detector slots centred on the model's tick
                r = ax.bar(x[i] + (d_id - 1.5) * width, m, width, yerr=2 * s, color=c, edgecolor='black', capsize=3)
                rects.append(r)
    ax.set_ylabel(ylab, fontsize=fs)
    ax.set_title(plt_title, fontsize=fs)
    ax.set_xticks(x)
    ax.set_xticklabels(COMP_ORDER_LABEL, fontsize=fs2)
    plt.xticks(rotation=45, ha="right")
    plt.yticks(fontsize=fs2)
    plt.ylim(0, ylim)
    plt.gcf().subplots_adjust(left=0.10, right=0.97, top=0.95)
    # legend at bottom; the first len(legs) bars are the first model's bars,
    # ordered detector-major then metric, matching legs
    plt.gcf().subplots_adjust(bottom=0.25)
    leg_rects = rects[:len(legs)]
    ax.legend(leg_rects, legs, loc='upper center', bbox_to_anchor=(0.5, -0.20), ncol=4,
              frameon=False, fontsize=12)
    fname = os.path.join(figdir, 'plt_dataset_complete_%s_%s.jpg' % (trig, plot_type))
    plt.savefig(fname)
    fname = os.path.join(figdir, 'plt_dataset_complete_%s_%s.pdf' % (trig, plot_type))
    plt.savefig(fname)
    # release the figure; this function is called repeatedly from __main__
    plt.close(fig)
# ================================================================================ | |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # pre-defined scripts
    parser.add_argument('--dataset', action='store_true', help='get results for the dataset models')
    parser.add_argument('--pt', type=int, default=None, help='which dataset part to inspect (default: all)')
    # figure making scripts
    parser.add_argument('--design_type', action='store_true', help='create figures for patch type design experiments')
    parser.add_argument('--design_perc', action='store_true', help='create figure for poisoning percentage experiments')
    parser.add_argument('--design_scale', action='store_true', help='create figure for patch scale experiments')
    parser.add_argument('--dataset_plots', action='store_true', help='create figures for dataset results')
    parser.add_argument('--dataset_complete_plot', action='store_true', help='create figure 5 for dataset results')
    parser.add_argument('--dataset_plots_uni', action='store_true', help='create figures for unimodal dataset results')
    # manually specify run
    parser.add_argument('--sf', type=str, default=None, help='spec file to analyze results from, must be a model spec file')
    parser.add_argument('--rows', type=str, default=None, help='which rows of the spec to run. see documentation. default: all rows')
    parser.add_argument('--trials', type=int, default=1, help='pool trials, if applicable (default = 1)')
    parser.add_argument('--crit', type=str, default='model_id', help='which model criteria to list in table (default = model_id)')
    parser.add_argument('--all', action='store_true', help='print all metrics, default shows limited set')
    parser.add_argument('--clean', action='store_true', help='print only clean metrics')
    # other
    parser.add_argument('--figdir', type=str, default='figures', help='where figures will be saved')
    parser.add_argument('--csv', action='store_true', help='when enabled, prints tables in a csv-like format')
    # module-level global: the table printing functions read args.csv by default
    args = parser.parse_args()
    # dataset models
    if args.dataset:
        if args.pt is None:
            # dataset parts are numbered 1-6 (range(6) passed part 0, which is rejected)
            for PT in range(1, 7):
                dataset_results(PT)
        else:
            dataset_results(args.pt)
    # figure scripts
    if args.design_type:
        design_type_plot(args.figdir, 'acc')
        design_type_plot(args.figdir, 'asr')
    if args.design_perc:
        design_perc_scale_plot(args.figdir, 'perc')
    if args.design_scale:
        design_perc_scale_plot(args.figdir, 'scale')
    if args.dataset_plots:
        dataset_plots_merged(args.figdir, 'acc')
        dataset_plots_merged(args.figdir, 'asr')
    if args.dataset_complete_plot:
        dataset_complete_plot(args.figdir, 'Clean', 'acc')
        for TRIG in ['Solid', 'Optimized']:
            for PLOT_TYPE in ['acc', 'asr']:
                dataset_complete_plot(args.figdir, TRIG, PLOT_TYPE)
    if args.dataset_plots_uni:
        dataset_plots_merged(args.figdir, 'acc', unimodal=True)
        dataset_plots_merged(args.figdir, 'asr', unimodal=True)
    # use specs to load results
    if args.sf is not None:
        check_results(args.sf, args.rows, args.trials, args.crit, args.all, args.clean)