anonymous8 commited on
Commit
ecdc8b8
1 Parent(s): 4943752
app.py CHANGED
@@ -10,14 +10,23 @@ from findfile import find_files
10
 
11
  from anonymous_demo import TADCheckpointManager
12
  from textattack import Attacker
13
- from textattack.attack_recipes import BAEGarg2019, PWWSRen2019, TextFoolerJin2019, PSOZang2020, IGAWang2019, GeneticAlgorithmAlzantot2018, DeepWordBugGao2018
 
 
 
 
 
 
 
 
14
  from textattack.attack_results import SuccessfulAttackResult
15
  from textattack.datasets import Dataset
16
  from textattack.models.wrappers import HuggingFaceModelWrapper
17
 
18
- z = zipfile.ZipFile('checkpoints.zip', 'r')
19
  z.extractall(os.getcwd())
20
 
 
21
  class ModelWrapper(HuggingFaceModelWrapper):
22
  def __init__(self, model):
23
  self.model = model # pipeline = pipeline
@@ -26,12 +35,11 @@ class ModelWrapper(HuggingFaceModelWrapper):
26
  outputs = []
27
  for text_input in text_inputs:
28
  raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
29
- outputs.append(raw_outputs['probs'])
30
  return outputs
31
 
32
 
33
  class SentAttacker:
34
-
35
  def __init__(self, model, recipe_class=BAEGarg2019):
36
  model = model
37
  model_wrapper = ModelWrapper(model)
@@ -41,7 +49,7 @@ class SentAttacker:
41
 
42
  # recipe.transformation.language = "en"
43
 
44
- _dataset = [('', 0)]
45
  _dataset = Dataset(_dataset)
46
 
47
  self.attacker = Attacker(recipe, _dataset)
@@ -58,63 +66,132 @@ def diff_texts(text1, text2):
58
  def get_ensembled_tad_results(results):
59
  target_dict = {}
60
  for r in results:
61
- target_dict[r['label']] = target_dict.get(r['label']) + 1 if r['label'] in target_dict else 1
 
 
62
 
63
- return dict(zip(target_dict.values(), target_dict.keys()))[max(target_dict.values())]
 
 
64
 
65
 
66
- nltk.download('omw-1.4')
67
 
68
  sent_attackers = {}
69
  tad_classifiers = {}
70
 
71
  attack_recipes = {
72
- 'bae': BAEGarg2019,
73
- 'pwws': PWWSRen2019,
74
- 'textfooler': TextFoolerJin2019,
75
- 'pso': PSOZang2020,
76
- 'iga': IGAWang2019,
77
- 'GA': GeneticAlgorithmAlzantot2018,
78
- 'wordbugger': DeepWordBugGao2018,
79
  }
80
 
81
- for attacker in [
82
- 'pwws',
83
- 'bae',
84
- 'textfooler'
85
- ]:
86
  for dataset in [
87
- 'agnews10k',
88
- 'amazon',
89
- 'sst2',
 
90
  ]:
91
- if 'tad-{}'.format(dataset) not in tad_classifiers:
92
- tad_classifiers['tad-{}'.format(dataset)] = TADCheckpointManager.get_tad_text_classifier('tad-{}'.format(dataset).upper())
 
 
 
 
93
 
94
- sent_attackers['tad-{}{}'.format(dataset, attacker)] = SentAttacker(tad_classifiers['tad-{}'.format(dataset)], attack_recipes[attacker])
95
- tad_classifiers['tad-{}'.format(dataset)].sent_attacker = sent_attackers['tad-{}pwws'.format(dataset)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- def get_a_sst2_example():
99
- filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']
100
 
101
- dataset_file = {'train': [], 'test': [], 'valid': []}
102
- dataset = 'sst2'
103
- search_path = './'
104
- task = 'text_defense'
105
- dataset_file['test'] += find_files(search_path, [dataset, 'test', task], exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + filter_key_words)
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- for dat_type in [
108
- 'test'
109
- ]:
 
 
 
 
 
 
 
 
110
  data = []
111
  label_set = set()
112
  for data_file in dataset_file[dat_type]:
113
-
114
- with open(data_file, mode='r', encoding='utf8') as fin:
115
  lines = fin.readlines()
116
  for line in lines:
117
- text, label = line.split('$LABEL$')
118
  text = text.strip()
119
  label = int(label.strip())
120
  data.append((text, label))
@@ -122,25 +199,43 @@ def get_a_sst2_example():
122
  return data[random.randint(0, len(data))]
123
 
124
 
125
- def get_a_agnews_example():
126
- filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- dataset_file = {'train': [], 'test': [], 'valid': []}
129
- dataset = 'agnews'
130
- search_path = './'
131
- task = 'text_defense'
132
- dataset_file['test'] += find_files(search_path, [dataset, 'test', task], exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + filter_key_words)
133
- for dat_type in [
134
- 'test'
135
- ]:
 
 
 
 
136
  data = []
137
  label_set = set()
138
  for data_file in dataset_file[dat_type]:
139
-
140
- with open(data_file, mode='r', encoding='utf8') as fin:
141
  lines = fin.readlines()
142
  for line in lines:
143
- text, label = line.split('$LABEL$')
144
  text = text.strip()
145
  label = int(label.strip())
146
  data.append((text, label))
@@ -148,26 +243,43 @@ def get_a_agnews_example():
148
  return data[random.randint(0, len(data))]
149
 
150
 
151
- def get_a_amazon_example():
152
- filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']
153
-
154
- dataset_file = {'train': [], 'test': [], 'valid': []}
155
- dataset = 'amazon'
156
- search_path = './'
157
- task = 'text_defense'
158
- dataset_file['test'] += find_files(search_path, [dataset, 'test', task], exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + filter_key_words)
 
 
 
 
 
 
 
 
 
159
 
160
- for dat_type in [
161
- 'test'
162
- ]:
 
 
 
 
 
 
 
 
 
163
  data = []
164
  label_set = set()
165
  for data_file in dataset_file[dat_type]:
166
-
167
- with open(data_file, mode='r', encoding='utf8') as fin:
168
  lines = fin.readlines()
169
  for line in lines:
170
- text, label = line.split('$LABEL$')
171
  text = text.strip()
172
  label = int(label.strip())
173
  data.append((text, label))
@@ -175,97 +287,205 @@ def get_a_amazon_example():
175
  return data[random.randint(0, len(data))]
176
 
177
 
 
 
 
178
  def generate_adversarial_example(dataset, attacker, text=None, label=None):
179
- if not text:
180
- if 'agnews' in dataset.lower():
181
- text, label = get_a_agnews_example()
182
- elif 'sst2' in dataset.lower():
183
- text, label = get_a_sst2_example()
184
- elif 'amazon' in dataset.lower():
185
- text, label = get_a_amazon_example()
 
 
 
 
186
 
187
  result = None
188
- attack_result = sent_attackers['tad-{}{}'.format(dataset.lower(), attacker.lower())].attacker.simple_attack(text, int(label))
 
 
189
  if isinstance(attack_result, SuccessfulAttackResult):
190
-
191
- if (attack_result.perturbed_result.output != attack_result.original_result.ground_truth_output) and (attack_result.original_result.output == attack_result.original_result.ground_truth_output):
 
 
 
 
 
192
  # with defense
193
- result = tad_classifiers['tad-{}'.format(dataset.lower())].infer(
194
- attack_result.perturbed_result.attacked_text.text + '!ref!{},{},{}'.format(attack_result.original_result.ground_truth_output, 1, attack_result.perturbed_result.output),
 
 
 
 
 
195
  print_result=True,
196
- defense='pwws',
197
  )
198
 
199
  if result:
200
  classification_df = {}
201
- classification_df['pred_label'] = result['label']
202
- classification_df['confidence'] = round(result['confidence'], 3)
203
- classification_df['is_correct'] = result['ref_label_check']
204
- classification_df['is_repaired'] = result['is_fixed']
205
 
206
  advdetection_df = {}
207
- if result['is_adv_label'] != '0':
208
- advdetection_df['is_adversary'] = result['is_adv_label']
209
- advdetection_df['perturbed_label'] = result['perturbed_label']
210
- advdetection_df['confidence'] = round(result['is_adv_confidence'], 3)
 
 
 
 
 
211
  # advdetection_df['ref_is_attack'] = result['ref_is_adv_label']
212
  # advdetection_df['is_correct'] = result['ref_is_adv_check']
213
 
214
  else:
215
  return generate_adversarial_example(dataset, attacker)
216
 
217
- return (text,
218
- label,
219
- attack_result.perturbed_result.attacked_text.text,
220
- diff_texts(text, attack_result.perturbed_result.attacked_text.text),
221
- diff_texts(text, result['restored_text']),
222
- attack_result.perturbed_result.output,
223
- pd.DataFrame(classification_df, index=[0]),
224
- pd.DataFrame(advdetection_df, index=[0])
225
- )
 
 
 
 
226
 
227
 
228
  demo = gr.Blocks()
229
-
230
  with demo:
231
- with gr.Row():
232
- with gr.Column():
233
- input_dataset = gr.Radio(choices=['SST2', 'AGNews10K', 'Amazon'], value='Amazon', label="Dataset")
234
- input_attacker = gr.Radio(choices=['BAE', 'PWWS', 'TextFooler'], value='TextFooler', label="Attacker")
235
- input_sentence = gr.Textbox(placeholder='Randomly choose a example from testing set if this box is blank', label="Sentence")
236
- input_label = gr.Textbox(placeholder='original label ... ', label="Original Label")
237
-
238
- gr.Markdown("Original Example")
239
-
240
- output_origin_example = gr.Textbox(label="Original Example")
241
- output_original_label = gr.Textbox(label="Original Label")
242
-
243
- gr.Markdown("Adversarial Example")
244
- output_adv_example = gr.Textbox(label="Adversarial Example")
245
- output_adv_label = gr.Textbox(label="Perturbed Label")
246
-
247
- gr.Markdown('This demo is deployed on a CPU device so it may take a long time to execute. Please be patient.')
248
- button_gen = gr.Button("Click Here to Generate an Adversary and Run Adversary Detection & Repair")
249
-
250
- # Right column (outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  with gr.Column():
252
- gr.Markdown("Example Difference")
253
- adv_text_diff = gr.HighlightedText(label="Adversarial Example Difference", combine_adjacent=True)
254
- restored_text_diff = gr.HighlightedText(label="Restored Example Difference", combine_adjacent=True)
255
-
256
- output_is_adv_df = gr.DataFrame(label="Adversary Prediction")
257
- output_df = gr.DataFrame(label="Standard Classification Prediction")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  # Bind functions to buttons
260
- button_gen.click(fn=generate_adversarial_example,
261
- inputs=[input_dataset, input_attacker, input_sentence, input_label],
262
- outputs=[output_origin_example,
263
- output_original_label,
264
- output_adv_example,
265
- adv_text_diff,
266
- restored_text_diff,
267
- output_adv_label,
268
- output_df,
269
- output_is_adv_df])
 
 
 
 
 
 
 
270
 
271
  demo.launch()
 
10
 
11
  from anonymous_demo import TADCheckpointManager
12
  from textattack import Attacker
13
+ from textattack.attack_recipes import (
14
+ BAEGarg2019,
15
+ PWWSRen2019,
16
+ TextFoolerJin2019,
17
+ PSOZang2020,
18
+ IGAWang2019,
19
+ GeneticAlgorithmAlzantot2018,
20
+ DeepWordBugGao2018,
21
+ )
22
  from textattack.attack_results import SuccessfulAttackResult
23
  from textattack.datasets import Dataset
24
  from textattack.models.wrappers import HuggingFaceModelWrapper
25
 
26
+ z = zipfile.ZipFile("checkpoints.zip", "r")
27
  z.extractall(os.getcwd())
28
 
29
+
30
  class ModelWrapper(HuggingFaceModelWrapper):
31
  def __init__(self, model):
32
  self.model = model # pipeline = pipeline
 
35
  outputs = []
36
  for text_input in text_inputs:
37
  raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
38
+ outputs.append(raw_outputs["probs"])
39
  return outputs
40
 
41
 
42
  class SentAttacker:
 
43
  def __init__(self, model, recipe_class=BAEGarg2019):
44
  model = model
45
  model_wrapper = ModelWrapper(model)
 
49
 
50
  # recipe.transformation.language = "en"
51
 
52
+ _dataset = [("", 0)]
53
  _dataset = Dataset(_dataset)
54
 
55
  self.attacker = Attacker(recipe, _dataset)
 
66
  def get_ensembled_tad_results(results):
67
  target_dict = {}
68
  for r in results:
69
+ target_dict[r["label"]] = (
70
+ target_dict.get(r["label"]) + 1 if r["label"] in target_dict else 1
71
+ )
72
 
73
+ return dict(zip(target_dict.values(), target_dict.keys()))[
74
+ max(target_dict.values())
75
+ ]
76
 
77
 
78
+ nltk.download("omw-1.4")
79
 
80
  sent_attackers = {}
81
  tad_classifiers = {}
82
 
83
  attack_recipes = {
84
+ "bae": BAEGarg2019,
85
+ "pwws": PWWSRen2019,
86
+ "textfooler": TextFoolerJin2019,
87
+ "pso": PSOZang2020,
88
+ "iga": IGAWang2019,
89
+ "GA": GeneticAlgorithmAlzantot2018,
90
+ "wordbugger": DeepWordBugGao2018,
91
  }
92
 
93
+ for attacker in ["pwws", "bae", "textfooler"]:
 
 
 
 
94
  for dataset in [
95
+ "agnews10k",
96
+ "amazon",
97
+ "sst2",
98
+ # 'imdb'
99
  ]:
100
+ if "tad-{}".format(dataset) not in tad_classifiers:
101
+ tad_classifiers[
102
+ "tad-{}".format(dataset)
103
+ ] = TADCheckpointManager.get_tad_text_classifier(
104
+ "tad-{}".format(dataset).upper()
105
+ )
106
 
107
+ sent_attackers["tad-{}{}".format(dataset, attacker)] = SentAttacker(
108
+ tad_classifiers["tad-{}".format(dataset)], attack_recipes[attacker]
109
+ )
110
+ tad_classifiers["tad-{}".format(dataset)].sent_attacker = sent_attackers[
111
+ "tad-{}pwws".format(dataset)
112
+ ]
113
+
114
+
115
+ def get_sst2_example():
116
+ filter_key_words = [
117
+ ".py",
118
+ ".md",
119
+ "readme",
120
+ "log",
121
+ "result",
122
+ "zip",
123
+ ".state_dict",
124
+ ".model",
125
+ ".png",
126
+ "acc_",
127
+ "f1_",
128
+ ".origin",
129
+ ".adv",
130
+ ".csv",
131
+ ]
132
 
133
+ dataset_file = {"train": [], "test": [], "valid": []}
134
+ dataset = "sst2"
135
+ search_path = "./"
136
+ task = "text_defense"
137
+ dataset_file["test"] += find_files(
138
+ search_path,
139
+ [dataset, "test", task],
140
+ exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
141
+ + filter_key_words,
142
+ )
143
+
144
+ for dat_type in ["test"]:
145
+ data = []
146
+ label_set = set()
147
+ for data_file in dataset_file[dat_type]:
148
+ with open(data_file, mode="r", encoding="utf8") as fin:
149
+ lines = fin.readlines()
150
+ for line in lines:
151
+ text, label = line.split("$LABEL$")
152
+ text = text.strip()
153
+ label = int(label.strip())
154
+ data.append((text, label))
155
+ label_set.add(label)
156
+ return data[random.randint(0, len(data))]
157
 
 
 
158
 
159
+ def get_agnews_example():
160
+ filter_key_words = [
161
+ ".py",
162
+ ".md",
163
+ "readme",
164
+ "log",
165
+ "result",
166
+ "zip",
167
+ ".state_dict",
168
+ ".model",
169
+ ".png",
170
+ "acc_",
171
+ "f1_",
172
+ ".origin",
173
+ ".adv",
174
+ ".csv",
175
+ ]
176
 
177
+ dataset_file = {"train": [], "test": [], "valid": []}
178
+ dataset = "agnews"
179
+ search_path = "./"
180
+ task = "text_defense"
181
+ dataset_file["test"] += find_files(
182
+ search_path,
183
+ [dataset, "test", task],
184
+ exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
185
+ + filter_key_words,
186
+ )
187
+ for dat_type in ["test"]:
188
  data = []
189
  label_set = set()
190
  for data_file in dataset_file[dat_type]:
191
+ with open(data_file, mode="r", encoding="utf8") as fin:
 
192
  lines = fin.readlines()
193
  for line in lines:
194
+ text, label = line.split("$LABEL$")
195
  text = text.strip()
196
  label = int(label.strip())
197
  data.append((text, label))
 
199
  return data[random.randint(0, len(data))]
200
 
201
 
202
+ def get_amazon_example():
203
+ filter_key_words = [
204
+ ".py",
205
+ ".md",
206
+ "readme",
207
+ "log",
208
+ "result",
209
+ "zip",
210
+ ".state_dict",
211
+ ".model",
212
+ ".png",
213
+ "acc_",
214
+ "f1_",
215
+ ".origin",
216
+ ".adv",
217
+ ".csv",
218
+ ]
219
 
220
+ dataset_file = {"train": [], "test": [], "valid": []}
221
+ dataset = "amazon"
222
+ search_path = "./"
223
+ task = "text_defense"
224
+ dataset_file["test"] += find_files(
225
+ search_path,
226
+ [dataset, "test", task],
227
+ exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
228
+ + filter_key_words,
229
+ )
230
+
231
+ for dat_type in ["test"]:
232
  data = []
233
  label_set = set()
234
  for data_file in dataset_file[dat_type]:
235
+ with open(data_file, mode="r", encoding="utf8") as fin:
 
236
  lines = fin.readlines()
237
  for line in lines:
238
+ text, label = line.split("$LABEL$")
239
  text = text.strip()
240
  label = int(label.strip())
241
  data.append((text, label))
 
243
  return data[random.randint(0, len(data))]
244
 
245
 
246
+ def get_imdb_example():
247
+ filter_key_words = [
248
+ ".py",
249
+ ".md",
250
+ "readme",
251
+ "log",
252
+ "result",
253
+ "zip",
254
+ ".state_dict",
255
+ ".model",
256
+ ".png",
257
+ "acc_",
258
+ "f1_",
259
+ ".origin",
260
+ ".adv",
261
+ ".csv",
262
+ ]
263
 
264
+ dataset_file = {"train": [], "test": [], "valid": []}
265
+ dataset = "imdb"
266
+ search_path = "./"
267
+ task = "text_defense"
268
+ dataset_file["test"] += find_files(
269
+ search_path,
270
+ [dataset, "test", task],
271
+ exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
272
+ + filter_key_words,
273
+ )
274
+
275
+ for dat_type in ["test"]:
276
  data = []
277
  label_set = set()
278
  for data_file in dataset_file[dat_type]:
279
+ with open(data_file, mode="r", encoding="utf8") as fin:
 
280
  lines = fin.readlines()
281
  for line in lines:
282
+ text, label = line.split("$LABEL$")
283
  text = text.strip()
284
  label = int(label.strip())
285
  data.append((text, label))
 
287
  return data[random.randint(0, len(data))]
288
 
289
 
290
+ cache = set()
291
+
292
+
293
  def generate_adversarial_example(dataset, attacker, text=None, label=None):
294
+ if not text or text in cache:
295
+ if "agnews" in dataset.lower():
296
+ text, label = get_agnews_example()
297
+ elif "sst2" in dataset.lower():
298
+ text, label = get_sst2_example()
299
+ elif "amazon" in dataset.lower():
300
+ text, label = get_amazon_example()
301
+ elif "imdb" in dataset.lower():
302
+ text, label = get_imdb_example()
303
+
304
+ cache.add(text)
305
 
306
  result = None
307
+ attack_result = sent_attackers[
308
+ "tad-{}{}".format(dataset.lower(), attacker.lower())
309
+ ].attacker.simple_attack(text, int(label))
310
  if isinstance(attack_result, SuccessfulAttackResult):
311
+ if (
312
+ attack_result.perturbed_result.output
313
+ != attack_result.original_result.ground_truth_output
314
+ ) and (
315
+ attack_result.original_result.output
316
+ == attack_result.original_result.ground_truth_output
317
+ ):
318
  # with defense
319
+ result = tad_classifiers["tad-{}".format(dataset.lower())].infer(
320
+ attack_result.perturbed_result.attacked_text.text
321
+ + "!ref!{},{},{}".format(
322
+ attack_result.original_result.ground_truth_output,
323
+ 1,
324
+ attack_result.perturbed_result.output,
325
+ ),
326
  print_result=True,
327
+ defense="pwws",
328
  )
329
 
330
  if result:
331
  classification_df = {}
332
+ classification_df["is_repaired"] = result["is_fixed"]
333
+ classification_df["pred_label"] = result["label"]
334
+ classification_df["confidence"] = round(result["confidence"], 3)
335
+ classification_df["is_correct"] = result["ref_label_check"]
336
 
337
  advdetection_df = {}
338
+ if result["is_adv_label"] != "0":
339
+ advdetection_df["is_adversarial"] = {
340
+ "0": False,
341
+ "1": True,
342
+ 0: False,
343
+ 1: True,
344
+ }[result["is_adv_label"]]
345
+ advdetection_df["perturbed_label"] = result["perturbed_label"]
346
+ advdetection_df["confidence"] = round(result["is_adv_confidence"], 3)
347
  # advdetection_df['ref_is_attack'] = result['ref_is_adv_label']
348
  # advdetection_df['is_correct'] = result['ref_is_adv_check']
349
 
350
  else:
351
  return generate_adversarial_example(dataset, attacker)
352
 
353
+ return (
354
+ text,
355
+ label,
356
+ result["restored_text"],
357
+ result["label"],
358
+ attack_result.perturbed_result.attacked_text.text,
359
+ diff_texts(text, text),
360
+ diff_texts(text, attack_result.perturbed_result.attacked_text.text),
361
+ diff_texts(text, result["restored_text"]),
362
+ attack_result.perturbed_result.output,
363
+ pd.DataFrame(classification_df, index=[0]),
364
+ pd.DataFrame(advdetection_df, index=[0]),
365
+ )
366
 
367
 
368
  demo = gr.Blocks()
 
369
  with demo:
370
+ gr.Markdown(
371
+ "# <p align='center'> Reactive Perturbation Defocusing for Textual Adversarial Defense </p> "
372
+ )
373
+
374
+ gr.Markdown("## <p align='center'>Clarifications</p>")
375
+ gr.Markdown(
376
+ "- This demo has no mechanism to ensure the adversarial example will be correctly repaired by RPD."
377
+ " The repair success rate is actually the performance reported in the paper (approximately up to 97%.)"
378
+ )
379
+ gr.Markdown(
380
+ "- The red (+) and green (-) colors in the character edition indicate the character is added "
381
+ "or deleted in the adversarial example compared to the original input natural example."
382
+ )
383
+ gr.Markdown(
384
+ "- The adversarial example and repaired adversarial example may be unnatural to read, "
385
+ "while it is because the attackers usually generate unnatural perturbations."
386
+ "RPD does not introduce additional unnatural perturbations."
387
+ )
388
+ gr.Markdown(
389
+ "- To our best knowledge, Reactive Perturbation Defocusing is a novel approach in adversarial defense "
390
+ ". RPD significantly (>10% defense accuracy improvement) outperforms the state-of-the-art methods."
391
+ )
392
+
393
+
394
+ gr.Markdown("## <p align='center'>Natural Example Input</p>")
395
+ with gr.Group():
396
+ with gr.Row():
397
+ input_dataset = gr.Radio(
398
+ choices=["SST2", "AGNews10K", "Amazon"],
399
+ value="SST2",
400
+ label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
401
+ )
402
+ input_attacker = gr.Radio(
403
+ choices=["BAE", "PWWS", "TextFooler"],
404
+ value="TextFooler",
405
+ label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
406
+ )
407
+ with gr.Group():
408
+ with gr.Row():
409
+ input_sentence = gr.Textbox(
410
+ placeholder="Input a natural example...",
411
+ label="Alternatively, input a natural example and its original label to generate an adversarial example.",
412
+ )
413
+ input_label = gr.Textbox(
414
+ placeholder="Original label...", label="Original Label"
415
+ )
416
+
417
+
418
+ button_gen = gr.Button(
419
+ "Generate an adversarial example and repair using RPD (No GPU, Time:3-10 mins )",
420
+ variant="primary",
421
+ )
422
+
423
+ gr.Markdown(
424
+ "## <p align='center'>Generated Adversarial Example and Repaired Adversarial Example</p>"
425
+ )
426
+ with gr.Group():
427
  with gr.Column():
428
+ with gr.Row():
429
+ output_original_example = gr.Textbox(label="Original Example")
430
+ output_original_label = gr.Textbox(label="Original Label")
431
+ with gr.Row():
432
+ output_adv_example = gr.Textbox(label="Adversarial Example")
433
+ output_adv_label = gr.Textbox(label="Perturbed Label")
434
+ with gr.Row():
435
+ output_repaired_example = gr.Textbox(label="Repaired Adversarial Example by RPD")
436
+ output_repaired_label = gr.Textbox(label="Repaired Label")
437
+
438
+
439
+ gr.Markdown("## <p align='center'>The Output of Reactive Perturbation Defocusing</p>")
440
+ with gr.Group():
441
+ output_is_adv_df = gr.DataFrame(label="Adversarial Example Detection Result")
442
+ gr.Markdown(
443
+ "The is_adversarial field indicates an adversarial example is detected. "
444
+ "The perturbed_label is the predicted label of the adversarial example. "
445
+ "The confidence field represents the confidence of the predicted adversarial example detection. "
446
+ )
447
+ output_df = gr.DataFrame(
448
+ label="Repaired Standard Classification Result"
449
+ )
450
+ gr.Markdown(
451
+ "If is_repaired=true, it has been repaired by RPD. "
452
+ "The pred_label field indicates the standard classification result. "
453
+ "The confidence field represents the confidence of the predicted label. "
454
+ "The is_correct field indicates whether the predicted label is correct."
455
+ )
456
+
457
+
458
+ gr.Markdown("## <p align='center'>Example Comparisons</p>")
459
+ ori_text_diff = gr.HighlightedText(
460
+ label="The Original Natural Example",
461
+ combine_adjacent=True,
462
+ )
463
+ adv_text_diff = gr.HighlightedText(
464
+ label="Character Editions of Adversarial Example Compared to the Natural Example",
465
+ combine_adjacent=True,
466
+ )
467
+ restored_text_diff = gr.HighlightedText(
468
+ label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
469
+ combine_adjacent=True,
470
+ )
471
 
472
  # Bind functions to buttons
473
+ button_gen.click(
474
+ fn=generate_adversarial_example,
475
+ inputs=[input_dataset, input_attacker, input_sentence, input_label],
476
+ outputs=[
477
+ output_original_example,
478
+ output_original_label,
479
+ output_repaired_example,
480
+ output_repaired_label,
481
+ output_adv_example,
482
+ ori_text_diff,
483
+ adv_text_diff,
484
+ restored_text_diff,
485
+ output_adv_label,
486
+ output_df,
487
+ output_is_adv_df,
488
+ ],
489
+ )
490
 
491
  demo.launch()
checkpoints.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4a5452cd89dcd3132d616cc81e2a1b063efa7d11e5798719b0779715b1c6edeb
3
- size 1846862527
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f77ae4a45785183900ee874cb318a16b0e2f173b31749a2555215aca93672f26
3
+ size 2456834455
text_defense/202.IMDB10K/imdb10k.test.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/202.IMDB10K/imdb10k.train.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/202.IMDB10K/imdb10k.valid.dat ADDED
The diff for this file is too large to render. See raw diff