anonymous8 commited on
Commit
fbf68ef
1 Parent(s): 34b7dc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -5,9 +5,8 @@ import gradio as gr
5
  import nltk
6
  import pandas as pd
7
  import requests
8
- from flask import Flask
9
 
10
- from anonymous_demo import TADCheckpointManager
11
  from textattack.attack_recipes import (
12
  BAEGarg2019,
13
  PWWSRen2019,
@@ -35,8 +34,6 @@ attack_recipes = {
35
  "clare": CLARE2020,
36
  }
37
 
38
- app = Flask(__name__)
39
-
40
 
41
  def init():
42
  nltk.download("omw-1.4")
@@ -50,6 +47,7 @@ def init():
50
  "agnews10k",
51
  "amazon",
52
  "sst2",
 
53
  # 'imdb'
54
  ]:
55
  if "tad-{}".format(dataset) not in tad_classifiers:
@@ -78,6 +76,8 @@ def generate_adversarial_example(dataset, attacker, text=None, label=None):
78
  text, label = get_sst2_example()
79
  elif "amazon" in dataset.lower():
80
  text, label = get_amazon_example()
 
 
81
  elif "imdb" in dataset.lower():
82
  text, label = get_imdb_example()
83
 
@@ -98,13 +98,13 @@ def generate_adversarial_example(dataset, attacker, text=None, label=None):
98
  # with defense
99
  result = tad_classifiers["tad-{}".format(dataset.lower())].infer(
100
  attack_result.perturbed_result.attacked_text.text
101
- + "!ref!{},{},{}".format(
102
  attack_result.original_result.ground_truth_output,
103
  1,
104
  attack_result.perturbed_result.output,
105
  ),
106
  print_result=True,
107
- defense="pwws",
108
  )
109
 
110
  if result:
@@ -112,7 +112,7 @@ def generate_adversarial_example(dataset, attacker, text=None, label=None):
112
  classification_df["is_repaired"] = result["is_fixed"]
113
  classification_df["pred_label"] = result["label"]
114
  classification_df["confidence"] = round(result["confidence"], 3)
115
- classification_df["is_correct"] = result["ref_label_check"]
116
 
117
  advdetection_df = {}
118
  if result["is_adv_label"] != "0":
@@ -186,12 +186,17 @@ def check_gpu():
186
 
187
 
188
  if __name__ == "__main__":
189
- # init()
 
 
 
 
190
 
191
  demo = gr.Blocks()
192
 
193
  with demo:
194
  gr.Markdown("<h1 align='center'>Reactive Perturbation Defocusing (Rapid) for Textual Adversarial Defense</h1>")
 
195
  gr.Markdown("""
196
  - This demo has no mechanism to ensure the adversarial example will be correctly repaired by Rapid. The repair success rate is actually the performance reported in the paper.
197
  - The adversarial example and repaired adversarial example may be unnatural to read, while it is because the attackers usually generate unnatural perturbations. Rapid does not introduce additional unnatural perturbations.
@@ -202,13 +207,13 @@ if __name__ == "__main__":
202
  with gr.Group():
203
  with gr.Row():
204
  input_dataset = gr.Radio(
205
- choices=["SST2", "AGNews10K", "Yahoo", "Amazon"],
206
  value="SST2",
207
  label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
208
  )
209
  input_attacker = gr.Radio(
210
  choices=["BAE", "PWWS", "TextFooler", "DeepWordBug"],
211
- value="PWWS",
212
  label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
213
  )
214
  with gr.Group():
@@ -258,7 +263,7 @@ if __name__ == "__main__":
258
  )
259
  output_repaired_label = gr.Textbox(label="Predicted Label of the Repaired Adversarial Example")
260
 
261
- gr.Markdown("<h2 align='center'>Example Comparisons</p>")
262
  gr.Markdown("""
263
  <p align='center'>The (+) and (-) in the boxes indicate the added and deleted characters in the adversarial example compared to the original input natural example.</p>
264
  """)
@@ -320,4 +325,4 @@ if __name__ == "__main__":
320
  ],
321
  )
322
 
323
- demo.queue(concurrency_count=10).launch()
 
5
  import nltk
6
  import pandas as pd
7
  import requests
 
8
 
9
+ from pyabsa import TADCheckpointManager
10
  from textattack.attack_recipes import (
11
  BAEGarg2019,
12
  PWWSRen2019,
 
34
  "clare": CLARE2020,
35
  }
36
 
 
 
37
 
38
  def init():
39
  nltk.download("omw-1.4")
 
47
  "agnews10k",
48
  "amazon",
49
  "sst2",
50
+ # "yahoo",
51
  # 'imdb'
52
  ]:
53
  if "tad-{}".format(dataset) not in tad_classifiers:
 
76
  text, label = get_sst2_example()
77
  elif "amazon" in dataset.lower():
78
  text, label = get_amazon_example()
79
+ elif "yahoo" in dataset.lower():
80
+ text, label = get_yahoo_example()
81
  elif "imdb" in dataset.lower():
82
  text, label = get_imdb_example()
83
 
 
98
  # with defense
99
  result = tad_classifiers["tad-{}".format(dataset.lower())].infer(
100
  attack_result.perturbed_result.attacked_text.text
101
+ + "$LABEL${},{},{}".format(
102
  attack_result.original_result.ground_truth_output,
103
  1,
104
  attack_result.perturbed_result.output,
105
  ),
106
  print_result=True,
107
+ defense=attacker,
108
  )
109
 
110
  if result:
 
112
  classification_df["is_repaired"] = result["is_fixed"]
113
  classification_df["pred_label"] = result["label"]
114
  classification_df["confidence"] = round(result["confidence"], 3)
115
+ classification_df["is_correct"] = str(result["pred_label"]) == str(label)
116
 
117
  advdetection_df = {}
118
  if result["is_adv_label"] != "0":
 
186
 
187
 
188
  if __name__ == "__main__":
189
+ try:
190
+ init()
191
+ except Exception as e:
192
+ print(e)
193
+ print("Failed to initialize the demo. Please try again later.")
194
 
195
  demo = gr.Blocks()
196
 
197
  with demo:
198
  gr.Markdown("<h1 align='center'>Reactive Perturbation Defocusing (Rapid) for Textual Adversarial Defense</h1>")
199
+ gr.Markdown("<h3 align='center'>Clarifications</h2>")
200
  gr.Markdown("""
201
  - This demo has no mechanism to ensure the adversarial example will be correctly repaired by Rapid. The repair success rate is actually the performance reported in the paper.
202
  - The adversarial example and repaired adversarial example may be unnatural to read, while it is because the attackers usually generate unnatural perturbations. Rapid does not introduce additional unnatural perturbations.
 
207
  with gr.Group():
208
  with gr.Row():
209
  input_dataset = gr.Radio(
210
+ choices=["SST2", "Amazon", "Yahoo", "AGNews10K"],
211
  value="SST2",
212
  label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
213
  )
214
  input_attacker = gr.Radio(
215
  choices=["BAE", "PWWS", "TextFooler", "DeepWordBug"],
216
+ value="TextFooler",
217
  label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
218
  )
219
  with gr.Group():
 
263
  )
264
  output_repaired_label = gr.Textbox(label="Predicted Label of the Repaired Adversarial Example")
265
 
266
+ gr.Markdown("<h2 align='center'>Example Difference (Comparisons)</p>")
267
  gr.Markdown("""
268
  <p align='center'>The (+) and (-) in the boxes indicate the added and deleted characters in the adversarial example compared to the original input natural example.</p>
269
  """)
 
325
  ],
326
  )
327
 
328
+ demo.queue(2).launch()