import os
import random
import zipfile
from collections import Counter
from difflib import Differ

import gradio as gr
import nltk
import pandas as pd
from findfile import find_files

from anonymous_demo import TADCheckpointManager
from textattack import Attacker
from textattack.attack_recipes import BAEGarg2019, PWWSRen2019, TextFoolerJin2019, PSOZang2020, IGAWang2019, GeneticAlgorithmAlzantot2018, DeepWordBugGao2018
from textattack.attack_results import SuccessfulAttackResult
from textattack.datasets import Dataset
from textattack.models.wrappers import HuggingFaceModelWrapper

# Unpack the bundled checkpoints into the working directory before any TAD
# classifier is loaded below (presumably consumed by TADCheckpointManager).
with zipfile.ZipFile('checkpoints.zip', 'r') as z:
    z.extractall(os.getcwd())


class ModelWrapper(HuggingFaceModelWrapper):
    """Adapt a TAD text classifier to TextAttack's model-wrapper interface.

    Only ``self.model`` is set; ``HuggingFaceModelWrapper.__init__`` is not
    called. NOTE(review): presumably to avoid its tokenizer requirement --
    confirm.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        """Return the ``'probs'`` entry of ``model.infer`` for each input text."""
        return [
            self.model.infer(text_input, print_result=False, **kwargs)['probs']
            for text_input in text_inputs
        ]


class SentAttacker:
    """Hold a TextAttack ``Attacker`` built from ``recipe_class`` around ``model``."""

    def __init__(self, model, recipe_class=BAEGarg2019):
        model_wrapper = ModelWrapper(model)
        recipe = recipe_class.build(model_wrapper)
        # WordNet defaults to English. Another language could be selected with
        # e.g. ``recipe.transformation.language = 'fra'`` (French).

        # Attacker requires a dataset, but attacks are run one at a time via
        # ``self.attacker.simple_attack`` (see generate_adversarial_example),
        # so a single placeholder example is enough.
        self.attacker = Attacker(recipe, Dataset([('', 0)]))


def diff_texts(text1, text2):
    """Character-level diff of two strings, shaped for ``gr.HighlightedText``.

    Returns a list of ``(chars, tag)`` pairs where ``tag`` is the Differ
    opcode character (e.g. ``'+'`` or ``'-'``) for changed characters and
    ``None`` for unchanged ones.
    """
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1, text2)
    ]


def get_ensembled_tad_results(results):
    """Return the most frequent ``'label'`` among ``results`` (majority vote).

    On a tie, the tied label whose first occurrence comes last wins.

    Raises:
        ValueError: if ``results`` is empty.
    """
    label_counts = Counter(r['label'] for r in results)  # keeps first-seen order
    if not label_counts:
        raise ValueError('results must contain at least one prediction')
    top = max(label_counts.values())
    return [label for label, count in label_counts.items() if count == top][-1]


nltk.download('omw-1.4')

sent_attackers = {}
tad_classifiers = {}

attack_recipes = {
    'bae': BAEGarg2019,
    'pwws': PWWSRen2019,
    'textfooler': TextFoolerJin2019,
    'pso': PSOZang2020,
    'iga': IGAWang2019,
    'GA': GeneticAlgorithmAlzantot2018,
    'wordbugger': DeepWordBugGao2018,
}

# Build one classifier per dataset (loaded once) and one SentAttacker per
# (dataset, attacker) pair, keyed as 'tad-<dataset><attacker>'.
for attacker in ['pwws', 'bae', 'textfooler']:
    for dataset in ['agnews10k', 'amazon', 'sst2']:
        tad_key = 'tad-{}'.format(dataset)
        if tad_key not in tad_classifiers:
            tad_classifiers[tad_key] = TADCheckpointManager.get_tad_text_classifier(tad_key.upper())
        sent_attackers[tad_key + attacker] = SentAttacker(tad_classifiers[tad_key], attack_recipes[attacker])
        # The classifier always gets the PWWS attacker; 'pwws' is the first
        # attacker processed, so its entry already exists here.
        tad_classifiers[tad_key].sent_attacker = sent_attackers[tad_key + 'pwws']

# Substrings that exclude a file from the dataset search (code, docs, logs,
# archives, model files, attack outputs, ...).
_FILTER_KEY_WORDS = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']


def _get_a_random_test_example(dataset):
    """Return a random ``(text, label)`` pair from the text_defense test files of ``dataset``.

    Files are searched under ``./`` with findfile; every non-blank line must
    look like ``<text>$LABEL$<integer label>``.

    Raises:
        RuntimeError: if no example could be read.
    """
    test_files = find_files(
        './',
        [dataset, 'test', 'text_defense'],
        exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + _FILTER_KEY_WORDS,
    )
    data = []
    for data_file in test_files:
        with open(data_file, mode='r', encoding='utf8') as fin:
            for line in fin:
                if not line.strip():
                    continue  # e.g. a trailing empty line; would break the unpacking below
                text, label = line.split('$LABEL$')
                data.append((text.strip(), int(label.strip())))
    if not data:
        raise RuntimeError('No test examples found for dataset {!r}'.format(dataset))
    # random.choice instead of data[random.randint(0, len(data))]: randint's
    # upper bound is inclusive, which could index one past the end.
    return random.choice(data)


def get_a_sst2_example():
    """Return a random ``(text, label)`` example from the SST-2 test files."""
    return _get_a_random_test_example('sst2')


def get_a_agnews_example():
    """Return a random ``(text, label)`` example from the AG News test files."""
    return _get_a_random_test_example('agnews')


def get_a_amazon_example():
    """Return a random ``(text, label)`` example from the Amazon test files."""
    return _get_a_random_test_example('amazon')


def _get_a_random_example_for(dataset):
    """Pick a random test example for a UI dataset name such as 'SST2', 'AGNews10K' or 'Amazon'."""
    name = dataset.lower()
    if 'agnews' in name:
        return get_a_agnews_example()
    if 'sst2' in name:
        return get_a_sst2_example()
    if 'amazon' in name:
        return get_a_amazon_example()
    raise ValueError('Unsupported dataset: {!r}'.format(dataset))


# Cap on how many examples generate_adversarial_example tries before giving
# up, so a run of unusable attacks cannot loop forever.
MAX_ATTACK_ATTEMPTS = 20


def generate_adversarial_example(dataset, attacker, text=None, label=None):
    """Attack an example, then run TAD adversary detection and repair on the result.

    Args:
        dataset: UI dataset name ('SST2', 'AGNews10K' or 'Amazon'); case-insensitive.
        attacker: UI attacker name ('BAE', 'PWWS' or 'TextFooler'); case-insensitive.
        text: sentence to attack; if empty, a random test example is used.
        label: original label of ``text`` (anything ``int()`` accepts);
            required when ``text`` is given.

    An attack is usable only if it succeeds and flips a correctly classified
    example. Otherwise a fresh random test example is tried (a user-supplied
    sentence is not retried), up to MAX_ATTACK_ATTEMPTS times.

    Returns:
        A tuple in the order bound to the Gradio outputs: original text,
        original label, adversarial text, adversarial diff, restored diff,
        perturbed label, classification DataFrame, adversary-detection DataFrame.

    Raises:
        KeyError: unknown dataset/attacker combination.
        ValueError: ``text`` given without a label.
        RuntimeError: no usable adversary after MAX_ATTACK_ATTEMPTS attempts.
    """
    dataset_key = dataset.lower()
    sent_attacker = sent_attackers['tad-{}{}'.format(dataset_key, attacker.lower())]
    tad_classifier = tad_classifiers['tad-{}'.format(dataset_key)]

    if text and (label is None or not str(label).strip()):
        raise ValueError('An original label is required when a sentence is provided.')

    for _ in range(MAX_ATTACK_ATTEMPTS):
        if not text:
            text, label = _get_a_random_example_for(dataset)
        attack_result = sent_attacker.attacker.simple_attack(text, int(label))
        result = None
        if isinstance(attack_result, SuccessfulAttackResult):
            original = attack_result.original_result
            perturbed = attack_result.perturbed_result
            # Keep only adversaries that flip a correctly classified example.
            if (perturbed.output != original.ground_truth_output) and (original.output == original.ground_truth_output):
                # Run the classifier with the PWWS-based defense. The
                # '!ref!<ground truth>,1,<perturbed label>' suffix presumably
                # feeds the classifier's ref_* checks -- confirm in TAD.
                result = tad_classifier.infer(
                    perturbed.attacked_text.text + '!ref!{},{},{}'.format(original.ground_truth_output, 1, perturbed.output),
                    print_result=True,
                    defense='pwws',
                )
        if result:
            break
        # Unusable adversary: retry with a fresh random test example.
        text, label = None, None
    else:
        raise RuntimeError('No usable adversarial example found after {} attempts.'.format(MAX_ATTACK_ATTEMPTS))

    classification_df = {
        'pred_label': result['label'],
        'confidence': round(result['confidence'], 3),
        'is_correct': result['ref_label_check'],
        'is_repaired': result['is_fixed'],
    }

    advdetection_df = {}
    # NOTE(review): is_adv_label is compared with the string '0'; confirm its type.
    if result['is_adv_label'] != '0':
        advdetection_df['is_adversary'] = result['is_adv_label']
        advdetection_df['perturbed_label'] = result['perturbed_label']
        advdetection_df['confidence'] = round(result['is_adv_confidence'], 3)

    adv_text = attack_result.perturbed_result.attacked_text.text
    return (
        text,
        label,
        adv_text,
        diff_texts(text, adv_text),
        diff_texts(text, result['restored_text']),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )


demo = gr.Blocks()

with demo:
    with gr.Row():
        # Left column (inputs and text outputs)
        with gr.Column():
            input_dataset = gr.Radio(choices=['SST2', 'AGNews10K', 'Amazon'], value='Amazon', label="Dataset")
            input_attacker = gr.Radio(choices=['BAE', 'PWWS', 'TextFooler'], value='TextFooler', label="Attacker")
            input_sentence = gr.Textbox(placeholder='Randomly choose a example from testing set if this box is blank', label="Sentence")
            input_label = gr.Textbox(placeholder='original label ... ', label="Original Label")
            gr.Markdown("Original Example")

            output_origin_example = gr.Textbox(label="Original Example")
            output_original_label = gr.Textbox(label="Original Label")
            gr.Markdown("Adversarial Example")
            output_adv_example = gr.Textbox(label="Adversarial Example")
            output_adv_label = gr.Textbox(label="Perturbed Label")

            gr.Markdown('This demo is deployed on a CPU device so it may take a long time to execute. Please be patient.')
            button_gen = gr.Button("Click Here to Generate an Adversary and Run Adversary Detection & Repair")

        # Right column (outputs)
        with gr.Column():
            gr.Markdown("Example Difference")
            adv_text_diff = gr.HighlightedText(label="Adversarial Example Difference", combine_adjacent=True)
            restored_text_diff = gr.HighlightedText(label="Restored Example Difference", combine_adjacent=True)

            output_is_adv_df = gr.DataFrame(label="Adversary Prediction")
            output_df = gr.DataFrame(label="Standard Classification Prediction")

    # Bind functions to buttons; output order matches generate_adversarial_example's return tuple.
    button_gen.click(
        fn=generate_adversarial_example,
        inputs=[input_dataset, input_attacker, input_sentence, input_label],
        outputs=[output_origin_example, output_original_label, output_adv_example, adv_text_diff, restored_text_diff, output_adv_label, output_df, output_is_adv_df],
    )

demo.launch()