"""Gradio demo for Detection and Correction based on Word Importance Ranking
(DCWIR): adversarial examples are generated with TextAttack recipes and
detected/repaired with PyABSA's TAD text classifiers."""

import os
import zipfile

import gradio as gr
import nltk
import pandas as pd
import requests
from pyabsa import TADCheckpointManager
from textattack.attack_recipes import (
    BAEGarg2019,
    PWWSRen2019,
    TextFoolerJin2019,
    PSOZang2020,
    IGAWang2019,
    GeneticAlgorithmAlzantot2018,
    DeepWordBugGao2018,
    CLARE2020,
)
from textattack.attack_results import SuccessfulAttackResult

from utils import (
    SentAttacker,
    get_agnews_example,
    get_sst2_example,
    get_amazon_example,
    get_imdb_example,
    diff_texts,
)

# from utils import get_yahoo_example

sent_attackers = {}
tad_classifiers = {}

attack_recipes = {
    "bae": BAEGarg2019,
    "pwws": PWWSRen2019,
    "textfooler": TextFoolerJin2019,
    "pso": PSOZang2020,
    "iga": IGAWang2019,
    "ga": GeneticAlgorithmAlzantot2018,
    "deepwordbug": DeepWordBugGao2018,
    "clare": CLARE2020,
}


def init():
    """Download NLTK data, unpack the TAD checkpoints, and build one
    SentAttacker per (dataset, attacker) pair."""
    nltk.download("omw-1.4")

    if not os.path.exists("TAD-SST2"):
        z = zipfile.ZipFile("checkpoints.zip", "r")
        z.extractall(os.getcwd())

    for attacker in ["pwws", "bae", "textfooler", "deepwordbug"]:
        for dataset in ["agnews10k", "sst2", "MR", "imdb"]:
            if "tad-{}".format(dataset) not in tad_classifiers:
                tad_classifiers[
                    "tad-{}".format(dataset)
                ] = TADCheckpointManager.get_tad_text_classifier(
                    "tad-{}".format(dataset).upper()
                )
            sent_attackers["tad-{}{}".format(dataset, attacker)] = SentAttacker(
                tad_classifiers["tad-{}".format(dataset)], attack_recipes[attacker]
            )
            # Each classifier defends with its PWWS-based attacker by default.
            tad_classifiers["tad-{}".format(dataset)].sent_attacker = sent_attackers[
                "tad-{}pwws".format(dataset)
            ]


# Texts that have already been attacked, so repeated runs draw fresh examples.
cache = set()


def generate_adversarial_example(dataset, attacker, text=None, label=None):
    # If no text is given (or it was already used), draw a fresh example from
    # the selected dataset; the retry at the bottom of this function relies on
    # this branch.
    if not text or text in cache:
        if "agnews" in dataset.lower():
            text, label = get_agnews_example()
        elif "sst2" in dataset.lower():
            text, label = get_sst2_example()
        elif "mr" in dataset.lower():
            text, label = get_amazon_example()
        # elif "yahoo" in dataset.lower():
        #     text, label = get_yahoo_example()
        elif "imdb" in dataset.lower():
            text, label = get_imdb_example()

    cache.add(text)

    result = None
    attack_result = sent_attackers[
        "tad-{}{}".format(dataset.lower(), attacker.lower())
    ].attacker.simple_attack(text, int(label))

    if isinstance(attack_result, SuccessfulAttackResult):
        if (
            attack_result.perturbed_result.output
            != attack_result.original_result.ground_truth_output
        ) and (
            attack_result.original_result.output
            == attack_result.original_result.ground_truth_output
        ):
            # Run inference with the defense enabled.
            result = tad_classifiers["tad-{}".format(dataset.lower())].infer(
                attack_result.perturbed_result.attacked_text.text
                + "$LABEL${},{},{}".format(
                    attack_result.original_result.ground_truth_output,
                    1,
                    attack_result.perturbed_result.output,
                ),
                print_result=True,
                defense=attacker,
            )

    if result:
        classification_df = {}
        classification_df["is_repaired"] = result["is_fixed"]
        classification_df["pred_label"] = result["label"]
        classification_df["confidence"] = round(result["confidence"], 3)
        classification_df["is_correct"] = str(result["label"]) == str(label)

        advdetection_df = {}
        if result["is_adv_label"] != "0":
            advdetection_df["is_adversarial"] = {
                "0": False,
                "1": True,
                0: False,
                1: True,
            }[result["is_adv_label"]]
            advdetection_df["perturbed_label"] = result["perturbed_label"]
            advdetection_df["confidence"] = round(result["is_adv_confidence"], 3)
            advdetection_df["ref_is_attack"] = result["ref_is_adv_label"]
            advdetection_df["is_correct"] = result["ref_is_adv_check"]
    else:
        # The attack failed (or the original prediction was already wrong):
        # retry with a freshly drawn example.
        return generate_adversarial_example(dataset, attacker)

    return (
        text,
        label,
        result["restored_text"],
        result["label"],
        attack_result.perturbed_result.attacked_text.text,
        diff_texts(text, text),
        diff_texts(text, attack_result.perturbed_result.attacked_text.text),
        diff_texts(text, result["restored_text"]),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )
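
# The "$LABEL$" suffix in the infer() call above packs three values after the
# text: the ground-truth label, what reads as an is-adversarial flag (the
# literal 1), and the attacker's perturbed label. A minimal sketch of that
# packing convention; the helper name is ours and is not used by the demo.
def _pack_tad_input(text, ground_truth, perturbed_label, is_adv=1):
    """Build the "text$LABEL$gt,is_adv,perturbed" string consumed by the TAD
    classifier's infer(). A sketch for illustration only."""
    return "{}$LABEL${},{},{}".format(text, ground_truth, is_adv, perturbed_label)
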
def run_demo(dataset, attacker, text=None, label=None):
    """Try the remote GPU API first; fall back to local inference on failure."""
    try:
        data = {
            "dataset": dataset,
            "attacker": attacker,
            "text": text,
            "label": label,
        }
        response = requests.post(
            "https://rpddemo.pagekite.me/api/generate_adversarial_example",
            json=data,
        )
        result = response.json()
        print(result)
        return (
            result["text"],
            result["label"],
            result["restored_text"],
            result["result_label"],
            result["perturbed_text"],
            result["text_diff"],
            result["perturbed_diff"],
            result["restored_diff"],
            result["output"],
            pd.DataFrame(result["classification_df"]),
            pd.DataFrame(result["advdetection_df"]),
            result["message"],
        )
    except Exception as e:
        print(e)
        # The local fallback returns no message, so append an empty one to
        # match the twelve outputs bound to button_gen.
        return (*generate_adversarial_example(dataset, attacker, text, label), "")


def check_gpu():
    try:
        response = requests.post(
            "https://rpddemo.pagekite.me/api/generate_adversarial_example",
            timeout=3,
        )
        if response.status_code < 500:
            return "GPU available"
        return "GPU not available"
    except Exception:
        return "GPU not available"
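
# The remote endpoint is expected to return a JSON object carrying at least
# the keys consumed in run_demo() above; a sketch of that schema (the field
# descriptions are our reading of the code, and all values are illustrative):
#
#   {
#       "text": "...",               # original natural example
#       "label": "1",                # its original label
#       "restored_text": "...",      # repaired adversarial example
#       "result_label": "1",         # predicted label after repair
#       "perturbed_text": "...",     # adversarial example
#       "text_diff": [...],          # highlight data for the original text
#       "perturbed_diff": [...],     # highlight data for the adversarial text
#       "restored_diff": [...],      # highlight data for the repaired text
#       "output": "0",               # predicted label of the adversarial example
#       "classification_df": {...},  # correction classification table
#       "advdetection_df": {...},    # adversarial detection table
#       "message": ""                # status message shown in the UI
#   }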

if __name__ == "__main__":
    try:
        init()
    except Exception as e:
        print(e)
        print("Failed to initialize the demo. Please try again later.")

    demo = gr.Blocks()
    with demo:
        gr.Markdown(
            "# Detection and Correction based on Word Importance Ranking (DCWIR)"
        )

        gr.Markdown("## Clarifications")
        gr.Markdown(
            """
- This demo has no mechanism to ensure that the adversarial example will be correctly repaired by DCWIR.
- The adversarial example and the repaired adversarial example may be unnatural to read; this is because attackers usually generate unnatural perturbations.
- All the supported attacks are black-box attacks, in which the attacker has no access to the model parameters.
"""
        )

        gr.Markdown("## Natural Example Input")
        with gr.Group():
            with gr.Row():
                input_dataset = gr.Radio(
                    choices=["SST2", "IMDB", "MR", "AGNews10K"],
                    value="SST2",
                    label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
                )
                input_attacker = gr.Radio(
                    choices=["BAE", "PWWS", "TextFooler", "DeepWordBug"],
                    value="TextFooler",
                    label="Choose an adversarial attacker for generating an adversarial example to attack the model.",
                )
        with gr.Group(visible=True):
            with gr.Row():
                input_sentence = gr.Textbox(
                    placeholder="Input a natural example...",
                    label="Alternatively, input a natural example and its original label (from the datasets above) to generate an adversarial example.",
                )
                input_label = gr.Textbox(
                    placeholder="Original label (must be an integer, because digits are used to represent labels in training)",
                    label="Original Label",
                )

        gr.Markdown(
            "Default parameters are set according to the main experiment setup in the report."
        )
        with gr.Row():
            wir_percentage = gr.Textbox(
                placeholder="Enter percentage from WIR...",
                label="Percentage from WIR",
            )
            frequency_threshold = gr.Textbox(
                placeholder="Enter frequency threshold...",
                label="Frequency Threshold",
            )
            max_candidates = gr.Textbox(
                placeholder="Enter maximum number of candidates...",
                label="Maximum Number of Candidates",
            )
        msg_text = gr.Textbox(
            label="Message",
            placeholder="This is a message box to show any error messages.",
        )
        button_gen = gr.Button(
            "Generate an adversarial example to repair using Rapid (GPU: < 1 minute, CPU: 1-10 minutes)",
            variant="primary",
        )
        gpu_status_text = gr.Textbox(
            label="GPU status",
            placeholder="Please click to check",
        )
        button_check = gr.Button("Check if GPU is available", variant="primary")
        button_check.click(
            fn=check_gpu,
            inputs=[],
            outputs=[gpu_status_text],
        )

        gr.Markdown("## Generated Adversarial Example and Repaired Adversarial Example")
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    output_original_example = gr.Textbox(label="Original Example")
                    output_original_label = gr.Textbox(label="Original Label")
                with gr.Row():
                    output_adv_example = gr.Textbox(label="Adversarial Example")
                    output_adv_label = gr.Textbox(
                        label="Predicted Label of the Adversarial Example"
                    )
                with gr.Row():
                    output_repaired_example = gr.Textbox(
                        label="Repaired Adversarial Example by DCWIR"
                    )
                    output_repaired_label = gr.Textbox(
                        label="Predicted Label of the Repaired Adversarial Example"
                    )

        gr.Markdown("## Example Difference (Comparisons)")

        gr.Markdown(
            """
The (+) and (-) marks in the boxes indicate the characters added to and deleted from the adversarial example, compared to the original natural example.
"""
        )
        ori_text_diff = gr.HighlightedText(
            label="The Original Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )
        adv_text_diff = gr.HighlightedText(
            label="Character Edits in the Adversarial Example Compared to the Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )
        restored_text_diff = gr.HighlightedText(
            label="Character Edits in the Repaired Adversarial Example Compared to the Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )

        gr.Markdown("## The Output of Reactive Perturbation Defocusing")
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    output_is_adv_df = gr.DataFrame(
                        label="Adversarial Example Detection Result"
                    )
                    gr.Markdown(
                        """
- The is_adversarial field indicates whether an adversarial example is detected.
- The perturbed_label field is the predicted label of the adversarial example.
- The confidence field represents the ratio of Inverted samples among all generated candidates.
"""
                    )
            with gr.Column():
                with gr.Group():
                    output_df = gr.DataFrame(label="Correction Classification Result")
                    gr.Markdown(
                        """
- If is_repaired=true, the example has been repaired by DCWIR.
- The pred_label field indicates the standard classification result.
- The confidence field represents the ratio of the dominant class among all Inverted candidates.
- The is_correct field indicates whether the predicted label is correct.
"""
                    )

        # Bind functions to buttons
        button_gen.click(
            fn=run_demo,
            inputs=[input_dataset, input_attacker, input_sentence, input_label],
            outputs=[
                output_original_example,
                output_original_label,
                output_repaired_example,
                output_repaired_label,
                output_adv_example,
                ori_text_diff,
                adv_text_diff,
                restored_text_diff,
                output_adv_label,
                output_df,
                output_is_adv_df,
                msg_text,
            ],
        )

    demo.queue(2).launch()