""" ========================================================================================= Trojan VQA Written by Matthew Walmer Analysis script to collect experimental results and produce tables and graphs ========================================================================================= """ import argparse import os import copy import json import numpy as np import pickle import tqdm import matplotlib.pyplot as plt import cv2 from utils.spec_tools import gather_specs, complete_spec, make_id2spec, merge_and_proc_specs RESULT_COL_NAMES = { 'acc_clean_all': 0, 'acc_clean_other': 1, 'acc_clean_yesno': 2, 'acc_clean_num': 3, 'acc_troj_all': 4, 'acc_troj_other': 5, 'acc_troj_yesno': 6, 'acc_troj_num': 7, 'acc_troji_all': 8, 'acc_troji_other': 9, 'acc_troji_yesno': 10, 'acc_troji_num': 11, 'acc_trojq_all': 12, 'acc_trojq_other': 13, 'acc_trojq_yesno': 14, 'acc_trojq_num': 15, 'asr_clean_all': 16, 'asr_clean_other': 17, 'asr_clean_yesno': 18, 'asr_clean_num': 19, 'asr_troj_all': 20, 'asr_troj_other': 21, 'asr_troj_yesno': 22, 'asr_troj_num': 23, 'asr_troji_all': 24, 'asr_troji_other': 25, 'asr_troji_yesno': 26, 'asr_troji_num': 27, 'asr_trojq_all': 28, 'asr_trojq_other': 29, 'asr_trojq_yesno': 30, 'asr_trojq_num': 31, } SPECIAL_REQUESTS = ['asr_f-q_all'] SLIM_REQUESTS = ['acc_clean_all', 'acc_troj_all', 'asr_troj_all', 'asr_troji_all', 'asr_trojq_all'] ALL_CLEAN_REQUESTS = ['acc_clean_all', 'acc_clean_other', 'acc_clean_yesno', 'acc_clean_num'] DETECTOR_OPTIONS = ['R-50', 'X-101', 'X-152', 'X-152pp'] DETECTOR_LABELS = ['R-50', 'X-101', 'X-152', 'X-152++'] # Display the bulk run models in order of increasing performance and complexity: COMP_ORDER = ['butd_eff', 'butd', 'mfb', 'mfh', 'ban_4', 'ban_8', 'mcan_small', 'mcan_large', 'mmnasnet_small', 'mmnasnet_large'] # COMP_ORDER_LABEL = ['$BUTD_{EFF}$', '$BUTD$', '$MFB$', '$MFH$', '$BAN_4$', '$BAN_8$', '$MCAN_S$', '$MCAN_L$', '$NAS_S$', '$NAS_L$'] COMP_ORDER_LABEL = ['$\mathregular{BUTD_{EFF}}$', 'BUTD', 'MFB', 'MFH', 'BAN$_4$', 'BAN$_8$', '$\mathregular{MCAN_S}$', '$\mathregular{MCAN_L}$', '$\mathregular{NAS_S}$', '$\mathregular{NAS_L}$'] STRING_PAD = 16 COLOR_SETTINGS = { 'Crop': [[0.95, 0.0, 0.0, 1.0], [1.0, 0.67, 0.0, 1.0]], 'Solid': [[0.0, 0.75, 0.0, 1.0], [0.55, 1.0, 0.11, 1.0]], 'Optimized': [[0.0, 0.0, 1.0, 1.0], [0.13, 0.90, 1.0, 1.0]], 'Clean_Acc': [[0.75, 0.25, 0.75, 1.0], [0.75, 0.25, 0.75, 1.0]], 'Clean': [0.5, 0.5, 0.5, 1.0], 'R-50': [[0.0, 0.75, 0.0, 1.0], [0.55, 1.0, 0.11, 1.0]], 'X-101': [[0.0, 0.0, 1.0, 1.0], [0.13, 0.90, 1.0, 1.0]], 'X-152': [[0.75, 0.25, 0.75, 1.0], [1.0, 0.37, 1.0, 1.0]], 'X-152pp': [[0.95, 0.0, 0.0, 1.0], [1.0, 0.67, 0.0, 1.0]], 'Question': [[0.75, 0.25, 0.75, 1.0], [1.0, 0.37, 1.0, 1.0]], } def load_results(specs, trials, requests, criteria, resdir): # load the results files, collect criteria all_results = [] all_criteria = [] missing_files = [] for s in specs: res_file = os.path.join(resdir, '%s.npy'%s['model_id']) if os.path.isfile(res_file): res = np.load(res_file) all_results.append(res) all_criteria.append(s[criteria]) else: missing_files.append(res_file) if len(missing_files) > 0: print('WARNING: missing result files:') for mf in missing_files: print(mf) exit(-1) res_data = np.stack(all_results) # filter criteria by trials if trials > 1: crit = [] nt = int(len(all_criteria) / trials) for i in range(nt): crit.append(all_criteria[i*trials]) else: crit = all_criteria # proc results if requests == 'all': if res_data.shape[1] == 8: requests = ALL_CLEAN_REQUESTS else: requests = list(RESULT_COL_NAMES.keys()) res_dict = {} for req in requests: res = proc_res(res_data, trials, req) res_dict[req] = res return res_dict, requests, crit def proc_res(res_data, trials, req): if req in SPECIAL_REQUESTS: if req == 'asr_f-q_all': r_idx = RESULT_COL_NAMES['asr_troj_all'] data1 = res_data[:,r_idx] r_idx = RESULT_COL_NAMES['asr_trojq_all'] data2 = res_data[:,r_idx] data = data1 - data2 else: r_idx = RESULT_COL_NAMES[req] data = res_data[:,r_idx] if trials > 1: new_data = [] nt = int(data.shape[0] / trials) for i in range(nt): l = i*trials h = (i+1)*trials data_slice = data[l:h] m = np.mean(data_slice) s = np.std(data_slice) new_data.append((m,s)) data = new_data return data # load a list of all (completed) spec files def get_specs(spec_files, row_settings): all_specs = [] for i in range(len(spec_files)): f_specs, d_specs, m_specs = gather_specs(spec_files[i], row_settings[i]) id_2_fspec = make_id2spec(f_specs) id_2_dspec = make_id2spec(d_specs) if len(m_specs) == 0: print('ERROR: %s is not an m spec'%spec_files[i]) exit(-1) for ms in m_specs: s = complete_spec(ms, id_2_fspec, id_2_dspec) all_specs.append(s) print('loaded %i specs'%len(all_specs)) return all_specs def get_results(spec_files, row_settings, trials=1, requests='all', criteria='model_id', resdir='results'): if not type(spec_files) is list: spec_files = [spec_files] row_settings = [row_settings] all_specs = get_specs(spec_files, row_settings) if trials > 1: print('trials: %i'%trials) return load_results(all_specs, trials, requests, criteria, resdir) # group results by a setting, optionally filter the results down to only models matching a certain setting for another setting, # using g_filter = (, ) def load_grouped_results(spec_files, row_settings, group_setting, requests='all', g_filter=None, resdir='results', condense=True, verbose=False): all_specs = get_specs(spec_files, row_settings) if group_setting not in all_specs[0]: print('ERROR: invalid group setting: ' + group_setting) exit(-1) grouped_specs = {} grouped_keys = [] for s in all_specs: g = s[group_setting] if g not in grouped_specs: grouped_specs[g] = [] grouped_keys.append(g) grouped_specs[g].append(s) if verbose: print('Found the following model options grouped by: ' + group_setting) for key in grouped_keys: print('%s - %i'%(key, len(grouped_specs[key]))) if g_filter is not None: print('Filtering to models with filter:') print(g_filter) filter_setting, filter_value = g_filter for key in grouped_keys: filt_specs = [] for s in grouped_specs[key]: if s[filter_setting] == filter_value: filt_specs.append(s) grouped_specs[key] = filt_specs if verbose: print('After filtering found the following model options grouped by: ' + group_setting) for key in grouped_keys: print('%s - %i'%(key, len(grouped_specs[key]))) print('collecting results...') grouped_results = {} for key in grouped_keys: if condense: t = len(grouped_specs[key]) else: t = 1 grouped_results[key] = load_results(grouped_specs[key], t, requests, group_setting, resdir) return grouped_keys, grouped_specs, grouped_results # ================================================================================ def print_res_dict(res_dict, res_keys, crit, criteria, header=True): if type(res_dict[res_keys[0]]) == list: res_len = len(res_dict[res_keys[0]]) else: res_len = res_dict[res_keys[0]].shape[0] row = criteria.ljust(STRING_PAD) for rk in res_keys: row += ('%s'%rk).ljust(STRING_PAD) if not args.csv: if header: print(row) for i in range(res_len): row = crit[i].ljust(STRING_PAD) for rk in res_keys: d = res_dict[rk][i] if type(d) == tuple: m,s = d row += ('%.2f+-%.2f'%(m,2*s)).ljust(STRING_PAD) else: row += ('%.2f'%d).ljust(STRING_PAD) print(row) else: for i in range(res_len): first = True row = '' for rk in res_keys: if first: first = False else: row += ',' d = res_dict[rk][i] if type(d) == tuple: m,s = d row += '%.2f+-%.2f'%(m,2*s) else: row += '%.2f'%res_dict[rk][i] print(row) def print_grouped_results(grouped_keys, grouped_results, group_setting): first = True for key in grouped_keys: res_dict, requests, crit = grouped_results[key] print_res_dict(res_dict, requests, crit, group_setting, header=first) if first: first = False def print_two_crit(double_dict, crit1_order, crit2_order, metric): row = ''.ljust(STRING_PAD) for c1 in crit1_order: row += ('%s'%c1).ljust(STRING_PAD) if not args.csv: print(row) for c2 in crit2_order: row = ('%s'%c2).ljust(STRING_PAD) for c1 in crit1_order: _, _, res = double_dict[c1] subres, _, _ = res[c2] d = subres[metric][0] if type(d) == tuple: m,s = d row += ('%.2f+-%.2f'%(m,2*s)).ljust(STRING_PAD) else: row += ('%.2f'%d).ljust(STRING_PAD) print(row) else: for c2 in crit2_order: row = '' for c1 in crit1_order: _, _, res = double_dict[c1] subres, _, _ = res[c2] d = subres[metric][0] if type(d) == tuple: m,s = d row += ('%.2f+-%.2f,'%(m,2*s)) else: row += ('%.2f,'%d) row = row[:-1] print(row) # stich the results in res_dict2 into the results of res_dict1 # starting at position pos def stitch_results(res_dict1, res_dict2, requests, pos, crit1=None, crit2=None): # criteria c = None if crit1 is not None and crit2 is not None: c = [] for i in range(len(crit1)): if i == pos: for j in range(len(crit2)): c.append(crit2[j]) c.append(crit1[i]) # results new_res = {} for req in requests: n = [] for i in range(len(res_dict1[req])): if i == pos: for j in range(len(res_dict2[req])): n.append(res_dict2[req][j]) n.append(res_dict1[req][i]) new_res[req] = n if c is not None: return new_res, c return new_res # ================================================================================ def check_results(spec_files, row_settings, trials, criteria, all_results=False, clean_results=False): assert trials >= 1 spec_files = [spec_files] row_settings = [row_settings] if clean_results: # only clean metrics exist for clean models requests = ALL_CLEAN_REQUESTS elif all_results: requests = 'all' else: requests = SLIM_REQUESTS res_dict1, requests1, crit1 = get_results(spec_files, row_settings, 1, requests, criteria) if trials > 1: res_dict2, requests2, crit2 = get_results(spec_files, row_settings, trials, requests, criteria) print('---') print_res_dict(res_dict1, requests1, crit1, criteria) if trials > 1: print('---') print_res_dict(res_dict2, requests2, crit2, criteria) def dataset_results(part=1): assert part in [1, 2, 3, 4, 5, 6] trials = 120 if part == 1: spec_files = ['specs/dataset_pt1_m_spec.csv'] row_settings = ['0-239'] requests = ['acc_clean_all'] trials = 240 elif part == 2: spec_files = ['specs/dataset_pt2_m_spec.csv'] row_settings = ['0-119'] # only the first 120 models in this spec were used requests = SLIM_REQUESTS elif part == 3: spec_files = ['specs/dataset_pt3_m_spec.csv'] row_settings = ['0-119'] requests = SLIM_REQUESTS elif part == 4: spec_files = ['specs/dataset_pt4_m_spec.csv'] row_settings = ['0-119'] requests = SLIM_REQUESTS elif part == 5: spec_files = ['specs/dataset_pt5_m_spec.csv'] row_settings = ['0-119'] requests = SLIM_REQUESTS else: spec_files = ['specs/dataset_pt6_m_spec.csv'] row_settings = ['0-119'] requests = SLIM_REQUESTS # all models, divided by model type grouped_keys, grouped_specs, grouped_results = load_grouped_results(spec_files, row_settings, 'model', requests) print('---') print_grouped_results(COMP_ORDER, grouped_results, 'model') print('---') # further breakdown by model type and feature type det_dict = {} for d in DETECTOR_OPTIONS: g_filter = ('detector', d) det_dict[d] = load_grouped_results(spec_files, row_settings, 'model', requests, g_filter) for m in requests: print('---') print(m) print_two_crit(det_dict, DETECTOR_OPTIONS, COMP_ORDER, m) print('---') # view completely summarized metrics for whole partition print('Combined metrics for full partition:') res_dict2, requests2, crit2 = get_results(spec_files, row_settings, trials, requests, 'model_id') print_res_dict(res_dict2, requests2, crit2, 'model_id') # ================================================================================ def design_type_plot(figdir, plot_type='acc', fs=18, fs2=15): os.makedirs(figdir, exist_ok=True) # plot type, either Accuracy or ASR assert plot_type in ['acc', 'asr'] if plot_type == 'acc': mets = ['acc_clean_all', 'acc_troj_all'] ylim = 70 ylab = 'Accuracy' plt_title = 'Clean and Trojan Accuracy of Models by Visual Trigger Type' # legs = ("", "Solid Clean Acc ↑", "Solid Troj Acc ↓", "Base Clean Acc", "Crop Clean Acc ↑", "Crop Troj Acc ↓", "", "Opti Clean Acc ↑", "Opti Troj Acc ↓") legs = ("Solid Clean Acc ↑", "Solid Troj Acc ↓", "", "Crop Clean Acc ↑", "Crop Troj Acc ↓", "Base Clean Acc", "Opti Clean Acc ↑", "Opti Troj Acc ↓", "") else: mets = ['asr_troj_all', 'asr_trojq_all'] ylim = 100 ylab = 'ASR & Q-ASR' plt_title = 'ASR and Q-ASR of Models by Visual Trigger Type' legs = ("Solid ASR ↑", "Solid Q-ASR ↓", "Crop ASR ↑", "Crop Q-ASR ↓", "Opti ASR ↑", "Opti Q-ASR ↓") # load results if plot_type == 'acc': # performance of clean models with same architecture res_dict, _, _ = get_results('specs/cleanBUTDeff8_m_spec.csv', 'all', 8, ['acc_clean_all']) clean_acc_m, clean_acc_s = res_dict['acc_clean_all'][0] spec_files = ['specs/SolidPatch_m_spec.csv', 'specs/CropPatch_m_spec.csv', 'specs/SemPatch_m_spec.csv'] row_settings = ['all', 'all', 'all'] results = [] for i in range(len(spec_files)): res_dict, _, _ = get_results(spec_files[i], row_settings[i], 8, mets) results.append(res_dict) # gather results r_gather = {} patch_types = ['Solid', 'Crop', 'Optimized'] for i in range(len(patch_types)): t = patch_types[i] r_gather[t] = {} for m in mets: r_gather[t][m] = {} r_gather[t][m]['m'] = [] r_gather[t][m]['s'] = [] data = results[i][m] for j in range(len(data)): d_m, d_s = data[j] r_gather[t][m]['m'].append(d_m) r_gather[t][m]['s'].append(d_s) # plot results - based on https://matplotlib.org/stable/gallery/lines_bars_and_markers/barchart.html x = np.arange(3) # the label locations width = 0.15 # the width of the bars # fig, ax = plt.subplots(figsize=[9,6]) fig, ax = plt.subplots(figsize=[9,4.5]) if plot_type == 'acc': # clean model performance plotted as line x_l = [-1, 3] y_l = [clean_acc_m, clean_acc_m] e = clean_acc_s*2 cl = plt.Line2D(x_l, y_l, color=COLOR_SETTINGS['Clean_Acc'][0]) plt.fill_between(x_l, y_l-e, y_l+e, color=COLOR_SETTINGS['Clean_Acc'][1], linewidth=0.0) # empty legend entry - https://stackoverflow.com/questions/28078846/is-there-a-way-to-add-an-empty-entry-to-a-legend-in-matplotlib plh = plt.Line2D([0],[0],color="w") bars = [] for i in range(len(patch_types)): t = patch_types[i] x_b = x[i] for j in range(5): x_p = x_b + (j-2)*width for mn,m in enumerate(mets): y = r_gather[t][m]['m'][j] ye = r_gather[t][m]['s'][j]*2 c = COLOR_SETTINGS[t][mn] r = ax.bar(x_p, y, width, yerr=ye, color=c, edgecolor='black', capsize=5) bars.append(r) ax.set_ylabel(ylab, fontsize=fs) ax.set_title(plt_title, fontsize=fs) ax.set_xticks(x) # legend at bottom # plt.gcf().subplots_adjust(bottom=0.22) plt.gcf().subplots_adjust(bottom=0.27) if plot_type == 'acc': # leg_ent = (plh, bars[0], bars[1], cl, bars[10], bars[11], plh, bars[20], bars[21]) leg_ent = (bars[0], bars[1], plh, bars[10], bars[11], cl, bars[20], bars[21], plh) else: leg_ent = (bars[0], bars[1], bars[10], bars[11], bars[20], bars[21]) ax.legend(leg_ent, legs, loc='upper center', bbox_to_anchor=(0.5, -0.07), ncol=3, frameon=False, handletextpad=0.25, fontsize=fs2) plt.ylim(0, ylim) plt.xlim(-0.5, 2.5) plt.xticks(fontsize=fs2) plt.yticks(fontsize=fs2) plt.gcf().subplots_adjust(left=0.10, right=0.97, top=0.93) ax.set_xticklabels(patch_types, fontsize=fs) fname = os.path.join(figdir, 'plt_design_type_%s.jpg'%plot_type) plt.savefig(fname) fname = os.path.join(figdir, 'plt_design_type_%s.pdf'%plot_type) plt.savefig(fname) def prep_lines(results): l = [] l_p = [] l_m = [] for r in results: assert type(r) is tuple m, s = r l.append(m) l_p.append(m+2*s) l_m.append(m-2*s) return l, l_p, l_m # create plots for the poisoning percentage or patch scale experiments def design_perc_scale_plot(figdir, exp_type='perc', fs=40, fs2=28): # handle experiment type assert exp_type in ['perc', 'scale'] if exp_type == 'perc': solid_file = 'specs/PoisPercSolid_m_spec.csv' opti_file = 'specs/PoisPercSem_m_spec.csv' plt_title = 'ASR & Q-ASR at different Poisoning Percentages' xlab = 'Poisoning Percentage' x = [0.1, 0.5, 1.0, 5.0, 10.0] else: solid_file = 'specs/SolidScale_m_spec.csv' opti_file = 'specs/SemScale_m_spec.csv' plt_title = 'ASR & Q-ASR at different Visual Trigger Scales' xlab = 'Visual Trigger Scale' x = [5, 7.5, 10, 15, 20] x_ticks = ['5%', '7.5%', '10%', '15%', '20%'] os.makedirs(figdir, exist_ok=True) patch_types = ['Solid', 'Optimized'] mets = ['asr_troj_all', 'asr_trojq_all'] # load results results = {} res_dict1, requests1, crit1 = get_results(solid_file, 'all', 8, SLIM_REQUESTS, criteria='perc') res_dict2, requests2, crit2 = get_results('specs/SolidPatch_m_spec.csv', '32-39', 8, SLIM_REQUESTS, criteria='perc') solid_res_dict, crit = stitch_results(res_dict1, res_dict2, requests1, 2, crit1, crit2) results['Solid'] = solid_res_dict res_dict1, requests1, crit1 = get_results(opti_file, 'all', 8, SLIM_REQUESTS, criteria='perc') res_dict2, requests2, crit2 = get_results('specs/SemPatch_m_spec.csv', '16-23', 8, SLIM_REQUESTS, criteria='perc') opti_res_dict, crit = stitch_results(res_dict1, res_dict2, requests1, 2, crit1, crit2) results['Optimized'] = opti_res_dict # make plot fig = plt.figure(figsize=[9,6]) ax = plt.axes() if exp_type == 'perc': ax.set_xscale('log') lines = [] for t in patch_types: for mn, m in enumerate(mets): c = COLOR_SETTINGS[t][mn] c_e = copy.copy(c) c_e[3] = 0.8 # placeholder for legend p_l, = plt.plot([-1],[-1], color=c, marker='.') lines.append(p_l) # darken center c = np.array(c) * 0.75 c[3] = 1.0 # plot l, l_p, l_m = prep_lines(results[t][m]) plt.plot(x,l, color=c, marker='.', markersize=20) plt.fill_between(x, l_m, l_p, color=c_e, linewidth=0.0) # ax.set_ylabel('ASR & Q-ASR', fontsize=fs) # ax.set_title(plt_title, fontsize=fs) ax.set_xlabel(xlab, fontsize=fs) # # legend at bottom # plt.gcf().subplots_adjust(bottom=0.28) # leg = ax.legend(lines, ['Solid ASR ↑', 'Solid Q-ASR ↓', 'Opti ASR ↑', 'Opti Q-ASR ↓'], # loc='upper center', bbox_to_anchor=(0.5, -0.18), ncol=2, frameon=False, # handletextpad=0.25, fontsize=fs2) # for legobj in leg.legendHandles: # legobj.set_linewidth(5.0) # legobj._legmarker.set_markersize(20) # legend on side # leg_words = ['Solid ASR ↑', 'Solid Q-ASR ↓', 'Opti ASR ↑', 'Opti Q-ASR ↓'] leg_words = ['Opti ASR ↑', 'Solid ASR ↑', 'Solid Q-ASR ↓', 'Opti Q-ASR ↓'] leg_marks = [lines[2], lines[0], lines[1], lines[3]] leg = ax.legend(leg_marks, leg_words, loc='center right', bbox_to_anchor=(1.05, 0.5), ncol=1, frameon=False, handletextpad=0.25, fontsize=fs2) for legobj in leg.legendHandles: legobj.set_linewidth(10.0) # legobj._legmarker.set_markersize(20) legobj._legmarker.set_markersize(0) plt.ylim(0, 100) if exp_type == 'perc': plt.xlim(0.1, 10) else: plt.xlim(5, 20) ax.set_xticks(x) ax.set_xticklabels(x_ticks) plt.xticks(fontsize=fs2) plt.yticks(fontsize=fs2) plt.gcf().subplots_adjust(left=0.10, top=0.97, bottom=0.19, right=0.95) # plt.xticks(rotation=45, ha="right") # plt.xticks(ha="left") # xTick_objects = ax.xaxis.get_major_ticks() # xTick_objects[0].label1.set_horizontalalignment('left') # xTick_objects[-1].label1.set_horizontalalignment('right') yTick_objects = ax.yaxis.get_major_ticks() yTick_objects[0].label1.set_verticalalignment('bottom') fname = os.path.join(figdir, 'plt_design_%s_asr.jpg'%exp_type) plt.savefig(fname) fname = os.path.join(figdir, 'plt_design_%s_asr.pdf'%exp_type) plt.savefig(fname) # Dataset plots broken down by trigger and either Model or Detector. # Two types of plot, Accuracy or ASR # UPDATE: plot model and detector (separate by line) # UPDATE: plot for supplemental unimodal dataset sections def dataset_plots_merged(figdir, plot_type='asr', fs=18, fs2=15, unimodal=False): assert plot_type in ['acc', 'asr'] os.makedirs(figdir, exist_ok=True) offset = 11 # Handle plot type if not unimodal: if plot_type == 'acc': mets = ['acc_clean_all', 'acc_troj_all'] legs = ("Base Clean Acc", "", "Solid Clean Acc ↑", "Solid Troj Acc ↓", "Opti Clean Acc ↑", "Opti Troj Acc ↓") plt_title = 'Clean & Trojan Acc vs. ' ylab = 'Accuracy' ylim = 70 ncol = 3 # width = 0.2333333 width = 0.275 # figsize = [9,6] # figsize = [9.6,6] figsize = [10,4.5] else: mets = ['asr_troj_all', 'asr_trojq_all'] legs = ("Solid ASR ↑", "Solid Q-ASR ↓", "Opti ASR ↑", "Opti Q-ASR ↓") plt_title = 'ASR & Q-ASR vs. ' ylab = 'ASR & Q-ASR' ylim = 100 ncol = 2 width = 0.35 # figsize= [9,6] # figsize = [9.6,6] figsize= [8,4.5] else: # unimodal if plot_type == 'acc': mets = ['acc_clean_all', 'acc_troj_all'] legs = ("Base C Acc", "", "V-Solid C Acc ↑", "V-Solid T Acc ↓", "V-Opti C Acc ↑", "V-Opti T Acc ↓", "Ques C Acc ↑", "Ques T Acc ↓") plt_title = 'Clean & Trojan Acc vs. ' ylab = 'Accuracy' ylim = 70 ncol = 4 width = 0.22 figsize = [10,4.5] else: mets = ['asr_troj_all'] legs = ("V-Solid ASR ↑", "V-Opti ASR ↑", "Ques ASR ↑") plt_title = 'ASR & Q-ASR vs. ' ylab = 'ASR' ylim = 100 ncol = 3 width = 0.275 figsize= [8,4.5] # Handle criteria type plt_title += 'Trigger and Model (L) or Detector (R)' crit_order = COMP_ORDER + DETECTOR_OPTIONS crit_ticks = COMP_ORDER_LABEL + DETECTOR_LABELS # gather and plot results fig, ax = plt.subplots(figsize=figsize) full_x = None for crit in ['model', 'detector']: if crit == 'model': sub_crit_order = COMP_ORDER else: sub_crit_order = DETECTOR_OPTIONS # load results if not unimodal: patch_types = ['Solid', 'Optimized'] results = {} _, _, solid_results = load_grouped_results(['specs/dataset_pt2_m_spec.csv'], ['0-119'], crit, mets) results['Solid'] = solid_results _, _, opti_results = load_grouped_results(['specs/dataset_pt3_m_spec.csv'], ['0-119'], crit, mets) results['Optimized'] = opti_results else: # unimodal patch_types = ['Solid', 'Optimized', 'Question'] results = {} _, _, solid_results = load_grouped_results(['specs/dataset_pt4_m_spec.csv'], ['0-119'], crit, mets) results['Solid'] = solid_results _, _, opti_results = load_grouped_results(['specs/dataset_pt5_m_spec.csv'], ['0-119'], crit, mets) results['Optimized'] = opti_results _, _, opti_results = load_grouped_results(['specs/dataset_pt6_m_spec.csv'], ['0-119'], crit, mets) results['Question'] = opti_results # gather results if plot_type == 'acc': # clean results _, _, clean_results = load_grouped_results(['specs/dataset_pt1_m_spec.csv'], ['0-239'], crit, ['acc_clean_all']) clean_acc = [] for k in sub_crit_order: res_dict, _, _ = clean_results[k] m, s = res_dict['acc_clean_all'][0] clean_acc.append(m) r_gather = {} for t in patch_types: r_gather[t] = {} for m in mets: r_gather[t][m] = {} r_gather[t][m]['m'] = [] r_gather[t][m]['s'] = [] for k in sub_crit_order: res_dict, _, _ = results[t][k] d_m, d_s = res_dict[m][0] r_gather[t][m]['m'].append(d_m) r_gather[t][m]['s'].append(d_s*2) # make plot # based on https://matplotlib.org/stable/gallery/lines_bars_and_markers/barchart.html x = np.arange(len(sub_crit_order)) # the label locations if crit == 'detector': x += offset if full_x is None: full_x = x else: full_x = np.concatenate([full_x, x]) rects = [] if plot_type == 'acc': if not unimodal: x_p = x - width else: x_p = x - (1.5 * width) y = clean_acc c = COLOR_SETTINGS['Clean'] r = ax.bar(x_p, y, width, color=c, edgecolor='black') rects.append(r) # placeholder legend entry plh = plt.Line2D([0],[0],color="w") rects.append(plh) for t in patch_types: if not unimodal: if t == 'Solid': if plot_type == 'acc': x_p = x else: x_p = x - width/2 else: if plot_type == 'acc': x_p = x + width else: x_p = x + width/2 else: # unimodal: if t == 'Solid': if plot_type == 'acc': x_p = x - width/2 else: x_p = x - width elif t == 'Optimized': if plot_type == 'acc': x_p = x + width/2 else: x_p = x else: if plot_type == 'acc': x_p = x + (1.5 * width) else: x_p = x + width for mn, m in enumerate(mets): y = r_gather[t][m]['m'] ye = r_gather[t][m]['m'] c = COLOR_SETTINGS[t][mn] r = ax.bar(x_p, y, width, color=c, edgecolor='black') rects.append(r) # add dotted line to separate sides plt.axvline(x=offset-1, color='black') ax.set_ylabel(ylab, fontsize=fs) ax.set_title(plt_title, fontsize=fs) ax.set_xticks(full_x) ax.set_xticklabels(crit_ticks, fontsize=fs2) fig.tight_layout() plt.xticks(rotation=45, ha="right") plt.xticks(fontsize=fs2) plt.yticks(fontsize=fs2) # legend at bottom plt.gcf().subplots_adjust(bottom=0.33) ax.legend(rects, legs, loc='upper center', bbox_to_anchor=(0.5, -0.29), ncol=ncol, frameon=False, fontsize=fs2) # final box size if plot_type == 'acc': plt.gcf().subplots_adjust(left=0.08, right=0.995, top=0.93) else: plt.gcf().subplots_adjust(left=0.12, right=0.995, top=0.93) plt.ylim(0, ylim) if not unimodal: fname = os.path.join(figdir, 'plt_dataset_merged_%s.jpg'%(plot_type)) else: fname = os.path.join(figdir, 'plt_dataset_unimodal_merged_%s.jpg'%(plot_type)) plt.savefig(fname) if not unimodal: fname = os.path.join(figdir, 'plt_dataset_merged_%s.pdf'%(plot_type)) else: fname = os.path.join(figdir, 'plt_dataset_unimodal_merged_%s.pdf'%(plot_type)) plt.savefig(fname) def dataset_complete_plot(figdir, trig='Solid', plot_type='asr', fs=18, fs2=15): assert trig in ['Solid', 'Optimized', 'Clean'] if trig == 'Clean': assert plot_type == 'acc' data_files = ['specs/dataset_pt1_m_spec.csv'] if trig == 'Solid': data_files = ['specs/dataset_pt2_m_spec.csv'] else: data_files = ['specs/dataset_pt3_m_spec.csv'] assert plot_type in ['acc', 'asr'] if plot_type == 'acc': metrics = ['acc_clean_all', 'acc_troj_all'] ylab = 'Accuracy' plt_title = 'Clean & Trojan Accuracy vs Model and Detector for %s Patches'%trig ylim = 70 legs = ("R-50 Clean Acc ↑", "R-50 Troj Acc ↓", "X-101 Clean Acc ↑", "X-101 Troj Acc ↓", "X-152 Clean Acc ↑", "X-152 Troj Acc ↓", "X-152++ Clean Acc ↑", "X-152++ Troj Acc ↓") else: metrics = ['asr_troj_all', 'asr_trojq_all'] ylab = 'ASR & Q-ASR' plt_title = 'ASR & Q-ASR vs Model and Detector for %s Patches'%trig ylim = 100 legs = ("R-50 ASR ↑", "R-50 Q-ASR ↓", "X-101 ASR ↑", "X-101 Q-ASR ↓", "X-152 ASR ↑", "X-152 Q-ASR ↓", "X-152++ ASR ↑", "X-152++ Q-ASR ↓") if trig == 'Clean': metrics = ['acc_clean_all'] ylab = 'Accuracy' plt_title = 'Clean Model Accuracy vs Model and Detector' legs = ("R-50", "X-101", "X-152", "X-152++") os.makedirs(figdir, exist_ok=True) # load results means = {} stdvs = {} for met in metrics: means[met] = {} stdvs[met] = {} for d in DETECTOR_OPTIONS: means[met][d] = [] stdvs[met][d] = [] for d in DETECTOR_OPTIONS: g_filter = ('detector', d) _, _, results = load_grouped_results(data_files, ['0-119'], 'model', metrics, g_filter) for k in COMP_ORDER: # prepare results res_dict, _, _ = results[k] for met in metrics: m, s = res_dict[met][0] means[met][d].append(m) stdvs[met][d].append(s) print('---') print('finished gathering results') num_bars = len(means[metrics[0]][DETECTOR_OPTIONS[0]]) print('number of bars: %i'%num_bars) width = 0.20 fig, ax = plt.subplots(figsize=[10,6]) x = np.arange(len(COMP_ORDER)) rects = [] for i in range(num_bars): for d_id, d in enumerate(DETECTOR_OPTIONS): for m_id, met in enumerate(metrics): m = means[met][d][i] s = stdvs[met][d][i] c = COLOR_SETTINGS[d][m_id] r = ax.bar(x[i] + (d_id-1.5)*width, m, width, yerr=2*s, color=c, edgecolor='black', capsize=3) rects.append(r) ax.set_ylabel(ylab, fontsize=fs) ax.set_title(plt_title, fontsize=fs) ax.set_xticks(x) ax.set_xticklabels(COMP_ORDER_LABEL, fontsize=fs2) ax.legend() # fig.tight_layout() plt.xticks(rotation=45, ha="right") plt.yticks(fontsize=fs2) plt.ylim(0, ylim) plt.gcf().subplots_adjust(left=0.10, right=0.97, top=0.95) # legend at bottom plt.gcf().subplots_adjust(bottom=0.25) leg_rects = [] for i in range(len(legs)): leg_rects.append(rects[i]) ax.legend(leg_rects, legs, loc='upper center', bbox_to_anchor=(0.5, -0.20), ncol=4, frameon=False, fontsize=12) fname = os.path.join(figdir, 'plt_dataset_complete_%s_%s.jpg'%(trig, plot_type)) plt.savefig(fname) fname = os.path.join(figdir, 'plt_dataset_complete_%s_%s.pdf'%(trig, plot_type)) plt.savefig(fname) # ================================================================================ if __name__ == '__main__': parser = argparse.ArgumentParser() # pre-defined scripts parser.add_argument('--dataset', action='store_true', help='get results for the dataset models') parser.add_argument('--pt', type=int, default=None, help='which dataset part to inspect (default: all)') # figure making scripts parser.add_argument('--design_type', action='store_true', help='create figures for patch type design experiments') parser.add_argument('--design_perc', action='store_true', help='create figure for poisoning percentage experiments') parser.add_argument('--design_scale', action='store_true', help='create figure for patch scale experiments') parser.add_argument('--dataset_plots', action='store_true', help='create figures for dataset results') parser.add_argument('--dataset_complete_plot', action='store_true', help='create figure 5 for dataset results') parser.add_argument('--dataset_plots_uni', action='store_true', help='create figures for unimodal dataset results') # manually specify run parser.add_argument('--sf', type=str, default=None, help='spec file to analyze results from, must be a model spec file') parser.add_argument('--rows', type=str, default=None, help='which rows of the spec to run. see documentation. default: all rows') parser.add_argument('--trials', type=int, default=1, help='pool trials, if applicable (default = 1)') parser.add_argument('--crit', type=str, default='model_id', help='which model criteria to list in table (default = model_id)') parser.add_argument('--all', action='store_true', help='print all metrics, default shows limited set') parser.add_argument('--clean', action='store_true', help='print only clean metrics') # other parser.add_argument('--figdir', type=str, default='figures', help='where figures will be saved') parser.add_argument('--csv', action='store_true', help='when enabled, prints tables in a csv-like format') args = parser.parse_args() # dataset models if args.dataset: if args.pt is None: for PT in range(6): dataset_results(PT) else: dataset_results(args.pt) # figure scripts if args.design_type: design_type_plot(args.figdir, 'acc') design_type_plot(args.figdir, 'asr') if args.design_perc: design_perc_scale_plot(args.figdir, 'perc') if args.design_scale: design_perc_scale_plot(args.figdir, 'scale') if args.dataset_plots: dataset_plots_merged(args.figdir, 'acc') dataset_plots_merged(args.figdir, 'asr') if args.dataset_complete_plot: dataset_complete_plot(args.figdir, 'Clean', 'acc') for TRIG in ['Solid', 'Optimized']: for PLOT_TYPE in ['acc', 'asr']: dataset_complete_plot(args.figdir, TRIG, PLOT_TYPE) if args.dataset_plots_uni: dataset_plots_merged(args.figdir, 'acc', unimodal=True) dataset_plots_merged(args.figdir, 'asr', unimodal=True) # use specs to load results if args.sf is not None: check_results(args.sf, args.rows, args.trials, args.crit, args.all, args.clean)