|
""" |
|
Attack Logs to Visdom |
|
======================== |
|
""" |
|
|
|
|
|
import socket |
|
|
|
from textattack.shared.utils import LazyLoader, html_table_from_rows |
|
|
|
from .logger import Logger |
|
|
|
visdom = LazyLoader("visdom", globals(), "visdom") |
|
|
|
|
|
def port_is_open(port_num, hostname="127.0.0.1"): |
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
|
result = sock.connect_ex((hostname, port_num)) |
|
sock.close() |
|
if result == 0: |
|
return True |
|
return False |
|
|
|
|
|
class VisdomLogger(Logger): |
|
"""Logs attack results to Visdom.""" |
|
|
|
def __init__(self, env="main", port=8097, hostname="localhost"): |
|
if not port_is_open(port, hostname=hostname): |
|
raise socket.error(f"Visdom not running on {hostname}:{port}") |
|
self.vis = visdom.Visdom(port=port, server=hostname, env=env) |
|
self.env = env |
|
self.port = port |
|
self.hostname = hostname |
|
self.windows = {} |
|
self.sample_rows = [] |
|
|
|
def __getstate__(self): |
|
state = {i: self.__dict__[i] for i in self.__dict__ if i != "vis"} |
|
return state |
|
|
|
def __setstate__(self, state): |
|
self.__dict__ = state |
|
self.vis = visdom.Visdom(port=self.port, server=self.hostname, env=self.env) |
|
|
|
def log_attack_result(self, result): |
|
text_a, text_b = result.diff_color(color_method="html") |
|
result_str = result.goal_function_result_str(color_method="html") |
|
self.sample_rows.append([result_str, text_a, text_b]) |
|
|
|
def log_summary_rows(self, rows, title, window_id): |
|
self.table(rows, title=title, window_id=window_id) |
|
|
|
def flush(self): |
|
self.table( |
|
self.sample_rows, |
|
title="Sample-Level Results", |
|
window_id="sample_level_results", |
|
) |
|
|
|
def log_hist(self, arr, numbins, title, window_id): |
|
self.bar(arr, numbins=numbins, title=title, window_id=window_id) |
|
|
|
def text(self, text_data, title=None, window_id="default"): |
|
if window_id and window_id in self.windows: |
|
window = self.windows[window_id] |
|
self.vis.text(text_data, win=window) |
|
else: |
|
new_window = self.vis.text(text_data, opts=dict(title=title)) |
|
self.windows[window_id] = new_window |
|
|
|
def table(self, rows, window_id=None, title=None, header=None, style=None): |
|
"""Generates an HTML table.""" |
|
|
|
if not window_id: |
|
window_id = title |
|
if not title: |
|
title = window_id |
|
table = html_table_from_rows(rows, title=title, header=header, style_dict=style) |
|
self.text(table, title=title, window_id=window_id) |
|
|
|
def bar(self, X_data, numbins=10, title=None, window_id=None): |
|
window = None |
|
if window_id and window_id in self.windows: |
|
window = self.windows[window_id] |
|
self.vis.bar(X=X_data, win=window, opts=dict(title=title, numbins=numbins)) |
|
else: |
|
new_window = self.vis.bar(X=X_data, opts=dict(title=title, numbins=numbins)) |
|
if window_id: |
|
self.windows[window_id] = new_window |
|
|
|
def hist(self, X_data, numbins=10, title=None, window_id=None): |
|
window = None |
|
if window_id and window_id in self.windows: |
|
window = self.windows[window_id] |
|
self.vis.histogram( |
|
X=X_data, win=window, opts=dict(title=title, numbins=numbins) |
|
) |
|
else: |
|
new_window = self.vis.histogram( |
|
X=X_data, opts=dict(title=title, numbins=numbins) |
|
) |
|
if window_id: |
|
self.windows[window_id] = new_window |
|
|