# QuadAttack / print_table.py
# Source: thomaspaniagua, commit 71f183c ("QuadAttack release"), 5.28 kB
from pylatex import Document, Section, Subsection, Tabular, MultiColumn,\
MultiRow, NoEscape
from pylatex.math import Math
from collections import OrderedDict
import copy
import numpy as np
from modelguidedattacks.results import build_full_results_dict
def result_to_str(result, long=False):
    """Format a numeric result for display in a table cell.

    Args:
        result: The value to format, or ``None`` when no result exists.
        long: When true, use four decimal places instead of two.

    Returns:
        ``"-"`` if ``result`` is ``None``, infinite or NaN; otherwise the
        value in fixed-point notation with two (or four) decimals.
    """
    # None has to be ruled out first: np.isfinite(None) raises TypeError.
    if result is None or not np.isfinite(result):
        return "-"
    return f"{result:.4f}" if long else f"{result:.2f}"
"""
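Layout of the nested results dict:
the top level is keyed by K (the Top-K protocol),
the next level is keyed by the number of binary search steps,
the next level is keyed by the number of iterations,
the next level is keyed by attack method,
and each method maps metric names (e.g. ASR_mean, L_inf_best) to values.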
"""
# LaTeX display name used for the table title, keyed by model name.
model_to_tex = {
    "resnet50": "Resnet-50",
    "densenet121": "Densenet121",
    "deit_small": "DeiT-S",
    "vit_base": "ViT$_{B}$",
}

# Script configuration.
# only_mean: when True, only the "Mean" column group is rendered; otherwise
# the "Best", "Mean" and "Worst" groups are all rendered.
only_mean = False
# model_name: selects both the results to load and the title in model_to_tex.
model_name = "resnet50"

results = build_full_results_dict(model_name)
# Preprocess all results and select bests.
#
# For every (K, binary-search steps, iterations) cell this finds the best
# value of each metric across the attack methods, then replaces every entry
# of ``results`` IN PLACE with its formatted string, wrapping the entries that
# match the best value in \textbf{...}.
for bs_dict in results.values():
    for iter_dict in bs_dict.values():
        for method_dict in iter_dict.values():
            metric_bests = {}       # metric name -> best value seen so far
            metrics_compared = {}   # metric name -> number of non-None values

            # Pass 1: find the best value of each metric across methods.
            for method_results in method_dict.values():
                for metric_name, metric_value in method_results.items():
                    # "ASR" metrics are maximised; every other metric is
                    # minimised.
                    reduction_func = max if "ASR" in metric_name else min

                    if metric_name not in metric_bests:
                        # np.inf instead of np.Infinity: the alias was
                        # removed in NumPy 2.0.
                        metric_bests[metric_name] = (
                            0. if reduction_func is max else np.inf
                        )
                        metrics_compared[metric_name] = 0

                    if metric_value is not None:
                        metric_bests[metric_name] = reduction_func(
                            metric_bests[metric_name], metric_value
                        )
                        metrics_compared[metric_name] += 1

            # Pass 2: stringify every entry and bold the best ones.
            for method_results in method_dict.values():
                for metric_name, metric_value in method_results.items():
                    formatted = result_to_str(
                        metric_value,
                        "inf" in metric_name or "ASR" in metric_name,
                    )

                    # Bold only when there was something to compare against,
                    # and never bold a non-finite value: np.allclose(inf, inf)
                    # is True, which would otherwise render a bold "-".
                    is_best = (
                        metric_value is not None
                        and np.isfinite(metric_value)
                        and metrics_compared[metric_name] > 1
                        and np.allclose(metric_value, metric_bests[metric_name])
                    )
                    if is_best:
                        formatted = rf"\textbf{{ {formatted} }}"

                    # Replacing the value of an existing key is safe while
                    # iterating over .items().
                    method_results[metric_name] = formatted
# Math-mode LaTeX label for each attack method key.
method_tex = {
    "cwk": r"CW^K",
    "ad": r"AD",
    "cvxproj": r"\textbf{QuadAttac$K$}",
}

doc = Document("multirow")

# Number of table columns occupied by each column group.
protocol_cols = 1
attack_method_cols = 1
best_cols = 4
mean_cols = 4
worst_cols = 4

# (heading, width) of every result group that is rendered.
if only_mean:
    result_groups = [("Mean", mean_cols)]
else:
    result_groups = [
        ("Best", best_cols),
        ("Mean", mean_cols),
        ("Worst", worst_cols),
    ]

col_widths = [protocol_cols, attack_method_cols]
col_widths += [group_width for _, group_width in result_groups]
total_cols = sum(col_widths)

# Column spec: left-aligned columns with a vertical rule around each group.
tabular_string = "|" + "".join("l" * w + "|" for w in col_widths)

table1 = Tabular(tabular_string)

# Title row spanning the whole table.
table1.add_hline()
table1.add_row(
    (MultiColumn(total_cols, align='|c|', data=NoEscape(model_to_tex[model_name])),)
)
table1.add_hline()

# First header row: two two-row labels followed by one heading per group.
group_headings = tuple(
    MultiColumn(group_width, align="|c|", data=heading)
    for heading, group_width in result_groups
)
table1.add_row(
    (MultiRow(2, data="Protocol"), MultiRow(2, data="Attack Method")) + group_headings
)

# Rule only under the result groups, leaving the two MultiRow labels intact.
table1.add_hline(start=protocol_cols + attack_method_cols + 1)

# Second header row: the four metric names, repeated once per result group.
num_result_colums = 1 if only_mean else 3
metric_headings = (
    NoEscape(r"ASR$\uparrow$"),
    NoEscape(r"$\ell_1 \downarrow$"),
    NoEscape(r"$\ell_2 \downarrow$"),
    NoEscape(r"$\ell_{\infty} \downarrow$"),
)
table1.add_row("", "", *(metric_headings * num_result_colums))
table1.add_hline()
# Emit one table row per (K, binary-search steps, iterations, method) result,
# then print the finished table as LaTeX.

# Loop-invariant: which reductions and metrics make up a row, in column order.
reduction_names = ["mean"] if only_mean else ["best", "mean", "worst"]
metric_names = ["ASR", "L1", "L2", "L_inf"]

for top_k, bs_dict in results.items():
    # The "Top-K" label spans every row of this K, so count them up front.
    total_results = sum(
        len(method_dict)
        for iter_dict in bs_dict.values()
        for method_dict in iter_dict.values()
    )

    top_k_latex_obj = MultiRow(total_results, data=f"Top-{top_k}")
    shown_topk_obj = False

    for bs_steps, iter_dict in bs_dict.items():
        for num_iter, method_dict in iter_dict.items():
            for method_name, method_results in method_dict.items():
                # Only the first row of the group carries the MultiRow label;
                # the remaining rows leave the protocol column empty.
                first_obj = "" if shown_topk_obj else top_k_latex_obj
                shown_topk_obj = True

                # The values are read as-is and passed through NoEscape; the
                # preprocessing pass above stores them as LaTeX strings.
                row_results = [
                    NoEscape(method_results[f"{metric}_{reduction}"])
                    for reduction in reduction_names
                    for metric in metric_names
                ]

                table1.add_row(
                    first_obj,
                    # Method label with the "<steps>x<iterations>" subscript.
                    NoEscape(
                        f"${method_tex[method_name]}_{{{bs_steps}x{num_iter}}}$"
                    ),
                    *row_results,
                )

            # Partial rule that skips the protocol column so it does not cut
            # through the MultiRow label.
            # NOTE(review): the nesting of this call was reconstructed from an
            # unindented listing; it is placed after each (steps x iterations)
            # group of methods. Confirm it was not meant to follow every row.
            table1.add_hline(start=protocol_cols + 1)

    # Full-width rule closing this Top-K group.
    table1.add_hline()

# The table is printed on its own rather than appended to ``doc``:
# doc.append(table1)
print(table1.dumps())