"""
Attack Logs to WandB
========================
"""

# LazyLoader is used to bind the optional ``wandb`` dependency on demand;
# html_table_from_rows renders rows of HTML snippets as a table.
from textattack.shared.utils import LazyLoader, html_table_from_rows

from .logger import Logger


class WeightsAndBiasesLogger(Logger):
    """Logs attack results to Weights & Biases.

    The ``wandb`` module is bound to a module-level global through
    ``LazyLoader`` whenever an instance is constructed or unpickled, and the
    other methods use that global.
    """

    def __init__(self, **kwargs):
        """Start a W&B run and prepare the cumulative result table.

        Args:
            **kwargs: Forwarded unchanged to ``wandb.init``. They are also
                stored on the instance so ``__setstate__`` can re-initialise
                the run after unpickling.
        """
        global wandb
        wandb = LazyLoader("wandb", globals(), "wandb")

        wandb.init(**kwargs)
        self.kwargs = kwargs
        self.project_name = wandb.run.project_name()
        # One row per logged attack result; re-rendered on every new result.
        self._result_table_rows = []

    def __setstate__(self, state):
        """Restore pickled state and resume the W&B run.

        Args:
            state: The instance ``__dict__`` captured when pickling.
        """
        global wandb
        wandb = LazyLoader("wandb", globals(), "wandb")

        self.__dict__ = state
        # Force ``resume=True`` without passing the keyword twice: if the
        # stored kwargs already contain "resume", ``resume=True, **kwargs``
        # would raise TypeError.
        init_kwargs = dict(self.kwargs)
        init_kwargs["resume"] = True
        wandb.init(**init_kwargs)

    def log_summary_rows(self, rows, title, window_id):
        """Log summary metrics as a W&B table and as run-summary entries.

        String values have ``%`` characters removed and are converted to
        ``float``. The caller's ``rows`` are left untouched, so the same rows
        can safely be reused elsewhere and may be tuples.

        Args:
            rows: Iterable of ``(metric_name, metric_value)`` pairs.
            title: Not used by this logger.
            window_id: Not used by this logger.

        Raises:
            ValueError: If a string value cannot be converted to ``float``.
        """
        table = wandb.Table(columns=["Attack Results", ""])
        for row in rows:
            metric_name, metric_score = row
            if isinstance(metric_score, str):
                stripped = metric_score.replace("%", "")
                try:
                    metric_score = float(stripped)
                except ValueError as err:
                    raise ValueError(
                        f'Unable to convert row value "{stripped}" for Attack Result "{metric_name}" into float'
                    ) from err
            table.add_data(metric_name, metric_score)
            wandb.run.summary[metric_name] = metric_score
        wandb.log({"attack_params": table})

    def _log_result_table(self):
        """Weights & Biases doesn't have a feature to automatically aggregate
        results across timesteps and display the full table.

        Therefore, we have to do it manually: the complete table of all
        results collected so far is rendered to HTML and logged again.
        """
        result_table = html_table_from_rows(
            self._result_table_rows, header=["", "Original Input", "Perturbed Input"]
        )
        wandb.log({"results": wandb.Html(result_table)})

    def log_attack_result(self, result):
        """Log a single attack result and refresh the cumulative table.

        Args:
            result: Attack result object. This method uses its
                ``diff_color(color_method="html")`` method and the ``output``
                of its ``original_result`` and ``perturbed_result``.
        """
        original_text_colored, perturbed_text_colored = result.diff_color(
            color_method="html"
        )
        # Numbering is zero-based: the index the new row will occupy.
        result_num = len(self._result_table_rows)
        self._result_table_rows.append(
            [
                f"<b>Result {result_num}</b>",
                original_text_colored,
                perturbed_text_colored,
            ]
        )
        result_diff_table = html_table_from_rows(
            [[original_text_colored, perturbed_text_colored]]
        )
        result_diff_table = wandb.Html(result_diff_table)
        wandb.log(
            {
                "result": result_diff_table,
                "original_output": result.original_result.output,
                "perturbed_output": result.perturbed_result.output,
            }
        )
        self._log_result_table()

    def log_sep(self):
        """Write a 90-dash separator line to ``self.fout`` when one exists.

        This class never assigns ``fout`` itself, so an unguarded write raises
        ``AttributeError`` unless something else sets the attribute. Without
        such a stream the call is a no-op.
        """
        fout = getattr(self, "fout", None)
        if fout is not None:
            fout.write("-" * 90 + "\n")
|