|
import os
import random
import zipfile
from difflib import Differ

import gradio as gr
import nltk
import pandas as pd
from findfile import find_files

from anonymous_demo import TADCheckpointManager
from textattack import Attacker
from textattack.attack_recipes import (
    BAEGarg2019,
    PWWSRen2019,
    TextFoolerJin2019,
    PSOZang2020,
    IGAWang2019,
    GeneticAlgorithmAlzantot2018,
    DeepWordBugGao2018,
)
from textattack.attack_results import SuccessfulAttackResult
from textattack.datasets import Dataset
from textattack.models.wrappers import HuggingFaceModelWrapper
|
|
|
# Unpack the pretrained TAD checkpoints shipped with the demo.
with zipfile.ZipFile('checkpoints.zip', 'r') as z:
    z.extractall(os.getcwd())
|
|
|
class ModelWrapper(HuggingFaceModelWrapper):
    """Adapts the TAD text classifier to TextAttack's model-wrapper interface."""

    def __init__(self, model):
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        outputs = []
        for text_input in text_inputs:
            raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
            outputs.append(raw_outputs['probs'])
        return outputs
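# Note: TextAttack's goal functions expect one probability vector per input
# text, which is what the classifier's 'probs' field provides.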
|
|
|
|
|
class SentAttacker:
    """Runs a TextAttack recipe against a TAD classifier, one sentence at a time."""

    def __init__(self, model, recipe_class=BAEGarg2019):
        model_wrapper = ModelWrapper(model)
        recipe = recipe_class.build(model_wrapper)

        # Attacker requires a dataset, but examples are supplied at attack
        # time, so a single placeholder item is enough.
        _dataset = [('', 0)]
        _dataset = Dataset(_dataset)

        self.attacker = Attacker(recipe, _dataset)
|
|
|
|
|
def diff_texts(text1, text2):
    """Character-level diff for gr.HighlightedText; unchanged characters map to None."""
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1, text2)
    ]
|
|
|
|
|
def get_ensembled_tad_results(results):
    """Majority vote over the labels predicted by an ensemble of TAD classifiers."""
    target_dict = {}
    for r in results:
        target_dict[r['label']] = target_dict.get(r['label'], 0) + 1

    return max(target_dict, key=target_dict.get)
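# e.g. get_ensembled_tad_results([{'label': 1}, {'label': 1}, {'label': 0}]) -> 1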
|
|
|
|
|
# PWWS substitutes WordNet synonyms, which requires the Open Multilingual
# Wordnet data in recent NLTK releases.
nltk.download('omw-1.4')
|
|
|
sent_attackers = {}
tad_classifiers = {}

attack_recipes = {
    'bae': BAEGarg2019,
    'pwws': PWWSRen2019,
    'textfooler': TextFoolerJin2019,
    'pso': PSOZang2020,
    'iga': IGAWang2019,
    'ga': GeneticAlgorithmAlzantot2018,
    'deepwordbug': DeepWordBugGao2018,
}
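# Only the recipes exposed in the UI (pwws, bae, textfooler) are instantiated
# below; the other entries are listed for completeness.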
|
|
|
for attacker in ['pwws', 'bae', 'textfooler']:
    for dataset in ['agnews10k', 'amazon', 'sst2']:
        if 'tad-{}'.format(dataset) not in tad_classifiers:
            tad_classifiers['tad-{}'.format(dataset)] = TADCheckpointManager.get_tad_text_classifier('tad-{}'.format(dataset).upper())

        sent_attackers['tad-{}{}'.format(dataset, attacker)] = SentAttacker(tad_classifiers['tad-{}'.format(dataset)], attack_recipes[attacker])
        # Every classifier uses the PWWS attacker for its defense stage.
        tad_classifiers['tad-{}'.format(dataset)].sent_attacker = sent_attackers['tad-{}pwws'.format(dataset)]
|
|
|
|
|
def get_an_example(dataset):
    """Return a random (text, label) pair from the test split of the given dataset."""
    filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']

    search_path = './'
    task = 'text_defense'
    test_files = find_files(
        search_path,
        [dataset, 'test', task],
        exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + filter_key_words,
    )

    data = []
    for data_file in test_files:
        with open(data_file, mode='r', encoding='utf8') as fin:
            for line in fin:
                text, label = line.split('$LABEL$')
                data.append((text.strip(), int(label.strip())))
    return random.choice(data)
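# e.g. text, label = get_an_example('sst2') draws a random SST2 test sentence.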
|
|
|
|
|
def generate_adversarial_example(dataset, attacker, text=None, label=None):
    """Attack an example, then run TAD adversary detection and repair on the result."""
    if not text:
        # No input given: draw a random example from the chosen test set.
        if 'agnews' in dataset.lower():
            text, label = get_an_example('agnews')
        elif 'sst2' in dataset.lower():
            text, label = get_an_example('sst2')
        elif 'amazon' in dataset.lower():
            text, label = get_an_example('amazon')

    result = None
    attack_result = sent_attackers['tad-{}{}'.format(dataset.lower(), attacker.lower())].attacker.simple_attack(text, int(label))
    if isinstance(attack_result, SuccessfulAttackResult):
        # Keep only genuine adversaries: the original prediction was correct
        # and the perturbed prediction is wrong.
        if (attack_result.perturbed_result.output != attack_result.original_result.ground_truth_output) and (attack_result.original_result.output == attack_result.original_result.ground_truth_output):
            # The '!ref!' suffix passes the ground-truth and perturbed labels
            # through so the classifier can report detection/repair checks.
            result = tad_classifiers['tad-{}'.format(dataset.lower())].infer(
                attack_result.perturbed_result.attacked_text.text + '!ref!{},{},{}'.format(attack_result.original_result.ground_truth_output, 1, attack_result.perturbed_result.output),
                print_result=True,
                defense='pwws',
            )

    if result:
        classification_df = {}
        classification_df['pred_label'] = result['label']
        classification_df['confidence'] = round(result['confidence'], 3)
        classification_df['is_correct'] = result['ref_label_check']
        classification_df['is_repaired'] = result['is_fixed']

        advdetection_df = {}
        if result['is_adv_label'] != '0':
            advdetection_df['is_adversary'] = result['is_adv_label']
            advdetection_df['perturbed_label'] = result['perturbed_label']
            advdetection_df['confidence'] = round(result['is_adv_confidence'], 3)
    else:
        # The attack failed or did not yield a valid adversary; retry with a
        # fresh random example.
        return generate_adversarial_example(dataset, attacker)

    return (
        text,
        label,
        attack_result.perturbed_result.attacked_text.text,
        diff_texts(text, attack_result.perturbed_result.attacked_text.text),
        diff_texts(text, result['restored_text']),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )
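# Minimal usage sketch (assuming the extracted checkpoints are available):
#   outputs = generate_adversarial_example('SST2', 'PWWS')
# returns the original text and label, the adversarial text, the two diffs,
# the perturbed label, and the two prediction DataFrames.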
|
|
|
|
|
demo = gr.Blocks()

with demo:
    with gr.Row():
        with gr.Column():
            input_dataset = gr.Radio(choices=['SST2', 'AGNews10K', 'Amazon'], value='Amazon', label="Dataset")
            input_attacker = gr.Radio(choices=['BAE', 'PWWS', 'TextFooler'], value='TextFooler', label="Attacker")
            input_sentence = gr.Textbox(placeholder='If this box is left blank, an example is chosen from the test set at random.', label="Sentence")
            input_label = gr.Textbox(placeholder='original label ...', label="Original Label")

            gr.Markdown("Original Example")
            output_origin_example = gr.Textbox(label="Original Example")
            output_original_label = gr.Textbox(label="Original Label")

            gr.Markdown("Adversarial Example")
            output_adv_example = gr.Textbox(label="Adversarial Example")
            output_adv_label = gr.Textbox(label="Perturbed Label")

            gr.Markdown('This demo runs on a CPU device, so it may take a while to execute. Please be patient.')
            button_gen = gr.Button("Click Here to Generate an Adversary and Run Adversary Detection & Repair")

        with gr.Column():
            gr.Markdown("Example Difference")
            adv_text_diff = gr.HighlightedText(label="Adversarial Example Difference", combine_adjacent=True)
            restored_text_diff = gr.HighlightedText(label="Restored Example Difference", combine_adjacent=True)

            output_is_adv_df = gr.DataFrame(label="Adversary Prediction")
            output_df = gr.DataFrame(label="Standard Classification Prediction")

    # The outputs list must match the return order of generate_adversarial_example.
    button_gen.click(fn=generate_adversarial_example,
                     inputs=[input_dataset, input_attacker, input_sentence, input_label],
                     outputs=[output_origin_example,
                              output_original_label,
                              output_adv_example,
                              adv_text_diff,
                              restored_text_diff,
                              output_adv_label,
                              output_df,
                              output_is_adv_df])

demo.launch()
|
|