PFEemp2024 committed on
Commit
1ef6bf0
1 Parent(s): e3f5f4d

Upload 2 files

Files changed (2)
  1. flow_correction_ag_news.py +388 -0
  2. flow_correction_imdb.py +388 -0
flow_correction_ag_news.py ADDED
@@ -0,0 +1,388 @@
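+ """Flow-correction evaluation on AG News adversarial examples.
+
+ This script loads the `textattack/bert-base-uncased-ag-news` victim model,
+ rebuilds a word-swap attack (WordSwapEmbedding + GreedyWordSwapWIR), and tries
+ to recover the original label of each successfully attacked text by replacing
+ its most gradient-important words with more frequent synonyms and
+ re-classifying the resulting candidates. Per-batch match rates are written to
+ `flow_correction_results_ag_news.csv`; hard and easy samples are pickled.
+ """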
+ import textattack
+ import transformers
+ import pandas as pd
+ import csv
+ import string
+ import pickle
+
+ # Construct our four components for `Attack`
+ from textattack.constraints.pre_transformation import (
+     RepeatModification,
+     StopwordModification,
+ )
+ from textattack.constraints.semantics import WordEmbeddingDistance
+ from textattack.transformations import WordSwapEmbedding
+ from textattack.search_methods import GreedyWordSwapWIR
+
+ import numpy as np
+ import json
+ import random
+ import re
+ import textattack.shared.attacked_text as atk
+ import torch.nn.functional as F
+ import torch
+
+
+ class InvertedText:
+
+     def __init__(
+         self,
+         swapped_indexes,
+         score,
+         attacked_text,
+         new_class,
+     ):
+         self.attacked_text = attacked_text
+         self.swapped_indexes = (
+             swapped_indexes  # dict mapping each swapped index to its synonym
+         )
+         self.score = score  # probability of the original class
+         self.new_class = new_class  # class predicted after inversion
+
+     def __repr__(self):
+         return f"InvertedText:\n attacked_text='{self.attacked_text}', \n swapped_indexes={self.swapped_indexes},\n score={self.score}"
+
+
+ def count_matching_classes(original, corrected, perturbed_texts=None):
+     """Count corrected labels that match the original ones and split the
+     perturbed texts into easy (recovered) and hard (not recovered) samples."""
+     if len(original) != len(corrected):
+         raise ValueError("Arrays must have the same length")
+     hard_samples = []
+     easy_samples = []
+
+     matching_count = 0
+
+     for i in range(len(corrected)):
+         if original[i] == corrected[i]:
+             matching_count += 1
+             if perturbed_texts is not None:
+                 easy_samples.append(perturbed_texts[i])
+         elif perturbed_texts is not None:
+             hard_samples.append(perturbed_texts[i])
+
+     return matching_count, hard_samples, easy_samples
+
+
+ class Flow_Corrector:
+     def __init__(
+         self,
+         attack,
+         word_rank_file="en_full_ranked.json",
+         word_freq_file="en_full_freq.json",
+         wir_threshold=0.3,
+     ):
+         self.attack = attack
+         self.attack.cuda_()
+         self.wir_threshold = wir_threshold
+         with open(word_rank_file, "r") as f:
+             self.word_ranked_frequence = json.load(f)
+         with open(word_freq_file, "r") as f:
+             self.word_frequence = json.load(f)
+         self.victim_model = attack.goal_function.model
+
+     def wir_gradient(
+         self,
+         attack,
+         victim_model,
+         detected_text,
+     ):
+         """Order word indices by importance, measured as the L1 norm of the
+         loss gradient with respect to each word's tokens."""
+         _, indices_to_order = attack.get_indices_to_order(detected_text)
+
+         index_scores = np.zeros(len(indices_to_order))
+         grad_output = victim_model.get_grad(detected_text.tokenizer_input)
+         gradient = grad_output["gradient"]
+         word2token_mapping = detected_text.align_with_model_tokens(victim_model)
+         for i, index in enumerate(indices_to_order):
+             matched_tokens = word2token_mapping[index]
+             if not matched_tokens:
+                 index_scores[i] = 0.0
+             else:
+                 agg_grad = np.mean(gradient[matched_tokens], axis=0)
+                 index_scores[i] = np.linalg.norm(agg_grad, ord=1)
+         index_order = np.array(indices_to_order)[(-index_scores).argsort()]
+         return index_order
+
+     def get_syn_freq_dict(
+         self,
+         index_order,
+         detected_text,
+     ):
+         """For each important index, collect synonyms that are more frequent
+         than the current word according to the frequency-rank dictionary."""
+         most_frequent_syn_dict = {}
+
+         no_syn = []
+         freq_threshold = len(self.word_ranked_frequence) / 10
+
+         for idx in index_order:
+             # get the synonyms for the word at this index
+             try:
+                 synonyms = [
+                     attacked_text.words[idx]
+                     for attacked_text in self.attack.get_transformations(
+                         detected_text, detected_text, indices_to_modify=[idx]
+                     )
+                 ]
+                 # keep synonyms that exist in the frequency dictionary, with their rank
+                 ranked_synonyms = {
+                     syn: self.word_ranked_frequence[syn]
+                     for syn in synonyms
+                     if syn in self.word_ranked_frequence.keys()
+                     and self.word_ranked_frequence[syn] < freq_threshold
+                     and self.word_ranked_frequence[detected_text.words[idx]]
+                     > self.word_ranked_frequence[syn]
+                 }
+                 # keep the synonyms that passed the frequency filter
+                 if ranked_synonyms:
+                     most_frequent_syn_dict[idx] = list(ranked_synonyms.keys())
+             except KeyError:
+                 # the current word has no frequency rank, so no synonyms are available
+                 no_syn.append(idx)
+
+         return most_frequent_syn_dict
+
+     def build_candidates(
+         self, detected_text, most_frequent_syn_dict: dict, max_attempt: int
+     ):
+         """Build up to `max_attempt` candidate texts by randomly swapping each
+         important word with one of its selected synonyms."""
+         candidates = {}
+         for _ in range(max_attempt):
+             syn_dict = {}
+             current_text = detected_text
+             for index in most_frequent_syn_dict.keys():
+                 syn = random.choice(most_frequent_syn_dict[index])
+                 syn_dict[index] = syn
+                 current_text = current_text.replace_word_at_index(index, syn)
+
+             candidates[current_text] = syn_dict
+         return candidates
+
+     def find_dominant_class(self, inverted_texts):
+         class_counts = {}  # count of each new class
+
+         for text in inverted_texts:
+             new_class = text.new_class
+             class_counts[new_class] = class_counts.get(new_class, 0) + 1
+
+         # Find the most dominant class
+         most_dominant_class = max(class_counts, key=class_counts.get)
+
+         return most_dominant_class
+
+     def correct(self, detected_texts):
+         corrected_classes = []
+         for detected_text in detected_texts:
+
+             # convert to an AttackedText
+             detected_text = atk.AttackedText(detected_text)
+
+             # keep only the most important indexes (a wir_threshold fraction)
+             index_order = self.wir_gradient(
+                 self.attack, self.victim_model, detected_text
+             )
+             index_order = index_order[: int(len(index_order) * self.wir_threshold)]
+
+             # get synonyms that satisfy the frequency conditions
+             most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)
+
+             # generate M candidates
+             candidates = self.build_candidates(
+                 detected_text, most_frequent_syn_dict, max_attempt=100
+             )
+
+             original_probs = F.softmax(self.victim_model(detected_text.text), dim=1)
+             original_class = torch.argmax(original_probs).item()
+             original_golden_prob = float(original_probs[0][original_class])
+
+             nbr_inverted = 0
+             inverted_texts = []  # list of InvertedText candidates that flipped the class
+             bad, impr = 0, 0
+             dict_deltas = {}
+
+             batch_inputs = [candidate.text for candidate in candidates.keys()]
+
+             batch_outputs = self.victim_model(batch_inputs)
+
+             probabilities = F.softmax(batch_outputs, dim=1)
+             for i, (candidate, syn_dict) in enumerate(candidates.items()):
+
+                 corrected_class = torch.argmax(probabilities[i]).item()
+                 new_golden_probability = float(probabilities[i][corrected_class])
+                 if corrected_class != original_class:
+                     nbr_inverted += 1
+                     inverted_texts.append(
+                         InvertedText(
+                             syn_dict, new_golden_probability, candidate, corrected_class
+                         )
+                     )
+                 else:
+                     delta = new_golden_probability - original_golden_prob
+                     if delta <= 0:
+                         bad += 1
+                     else:
+                         impr += 1
+                     dict_deltas[candidate] = delta
+
+             num_classes = len(original_probs[0])
+             if num_classes > 2 and len(inverted_texts) >= len(candidates) / num_classes:
+                 # multi-class case: vote for the most dominant class among inversions
+                 dominant_class = self.find_dominant_class(inverted_texts)
+             elif len(inverted_texts) >= len(candidates) / 2:
+                 # binary case: a majority of candidates flipped the prediction
+                 dominant_class = self.find_dominant_class(inverted_texts)
+             else:
+                 dominant_class = original_class
+
+             corrected_classes.append(dominant_class)
+
+         return corrected_classes
+
+
+ def remove_brackets(text):
+     text = text.replace("[[", "")
+     text = text.replace("]]", "")
+     return text
+
+
+ def clean_text(text):
+     pattern = "[" + re.escape(string.punctuation) + "]"
+     cleaned_text = re.sub(pattern, " ", text)
+
+     return cleaned_text
+
+
+ # Load model, tokenizer, and model_wrapper
+ model = transformers.AutoModelForSequenceClassification.from_pretrained(
+     "textattack/bert-base-uncased-ag-news"
+ )
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+     "textattack/bert-base-uncased-ag-news"
+ )
+ model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
+
+
+ goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
+ constraints = [
+     RepeatModification(),
+     StopwordModification(),
+     WordEmbeddingDistance(min_cos_sim=0.9),
+ ]
+ transformation = WordSwapEmbedding(max_candidates=50)
+ search_method = GreedyWordSwapWIR(wir_method="gradient")
+
+ # Construct the actual attack
+ attack = textattack.Attack(goal_function, constraints, transformation, search_method)
+ attack.cuda_()
+
+
+ results = pd.read_csv("ag_news_results.csv")
+ perturbed_texts = [
+     results["perturbed_text"][i]
+     for i in range(len(results))
+     if results["result_type"][i] == "Successful"
+ ]
+ original_texts = [
+     results["original_text"][i]
+     for i in range(len(results))
+     if results["result_type"][i] == "Successful"
+ ]
+
+ perturbed_texts = [remove_brackets(text) for text in perturbed_texts]
+ original_texts = [remove_brackets(text) for text in original_texts]
+
+ perturbed_texts = [clean_text(text) for text in perturbed_texts]
+ original_texts = [clean_text(text) for text in original_texts]
+
+
+ victim_model = attack.goal_function.model
+
+ print("Getting corrected classes")
+ print("This may take a while ...")
+ # we could also read these labels directly from the results csv file
+ original_classes = [
+     torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()
+     for original_text in original_texts
+ ]
+
+ batch_size = 1000
+ num_batches = (len(perturbed_texts) + batch_size - 1) // batch_size
+ batched_perturbed_texts = []
+ batched_original_texts = []
+ batched_original_classes = []
+
+ for i in range(num_batches):
+     start = i * batch_size
+     end = min(start + batch_size, len(perturbed_texts))
+     batched_perturbed_texts.append(perturbed_texts[start:end])
+     batched_original_texts.append(original_texts[start:end])
+     batched_original_classes.append(original_classes[start:end])
+ print(batched_original_classes)
+ hard_samples_list = []
+ easy_samples_list = []
+
+
+ # Open a CSV file for writing
+ csv_filename = "flow_correction_results_ag_news.csv"
+ with open(csv_filename, "w", newline="") as csvfile:
+     fieldnames = ["freq_threshold", "batch_num", "match_perturbed", "match_original"]
+     writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+
+     # Write the header row
+     writer.writeheader()
+
+     # Iterate over batched lists
+     batch_num = 0
+     for perturbed, original, classes in zip(
+         batched_perturbed_texts, batched_original_texts, batched_original_classes
+     ):
+         batch_num += 1
+         print(f"Processing batch number: {batch_num}")
+
+         for i in range(2):
+             wir_threshold = 0.1 * (i + 1)
+             print(f"Setting WIR threshold to: {wir_threshold}")
+
+             corrector = Flow_Corrector(
+                 attack,
+                 word_rank_file="en_full_ranked.json",
+                 word_freq_file="en_full_freq.json",
+                 wir_threshold=wir_threshold,
+             )
+
+             # Correct perturbed texts
+             print("Correcting perturbed texts...")
+             corrected_perturbed_classes = corrector.correct(perturbed)
+
+             match_perturbed, hard_samples, easy_samples = count_matching_classes(
+                 classes, corrected_perturbed_classes, perturbed
+             )
+             hard_samples_list.extend(hard_samples)
+             easy_samples_list.extend(easy_samples)
+
+             print(f"Number of matching classes (perturbed): {match_perturbed}")
+
+             # Correct original texts
+             print("Correcting original texts...")
+             corrected_original_classes = corrector.correct(original)
+             match_original, hard_samples, easy_samples = count_matching_classes(
+                 classes, corrected_original_classes, perturbed
+             )
+             print(f"Number of matching classes (original): {match_original}")
+
+             # Write results to CSV file
+             print("Writing results to CSV file...")
+             writer.writerow(
+                 {
+                     "freq_threshold": wir_threshold,
+                     "batch_num": batch_num,
+                     "match_perturbed": match_perturbed / len(perturbed),
+                     "match_original": match_original / len(perturbed),
+                 }
+             )
+             print("-" * 20)
+
+ print("saving samples for further statistical study")
+
+ # Save hard_samples_list and easy_samples_list to files
+ with open("hard_samples.pkl", "wb") as f:
+     pickle.dump(hard_samples_list, f)
+
+ with open("easy_samples.pkl", "wb") as f:
+     pickle.dump(easy_samples_list, f)
flow_correction_imdb.py ADDED
@@ -0,0 +1,388 @@
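+ """Flow-correction evaluation on IMDB adversarial examples.
+
+ This script loads the `textattack/bert-base-uncased-imdb` victim model,
+ rebuilds a word-swap attack (WordSwapEmbedding + GreedyWordSwapWIR), and tries
+ to recover the original label of each successfully attacked text by replacing
+ its most gradient-important words with more frequent synonyms and
+ re-classifying the resulting candidates. Per-batch match rates are written to
+ `flow_correction_results_imdb.csv`; hard and easy samples are pickled.
+ """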
+ import textattack
+ import transformers
+ import pandas as pd
+ import csv
+ import string
+ import pickle
+
+ # Construct our four components for `Attack`
+ from textattack.constraints.pre_transformation import (
+     RepeatModification,
+     StopwordModification,
+ )
+ from textattack.constraints.semantics import WordEmbeddingDistance
+ from textattack.transformations import WordSwapEmbedding
+ from textattack.search_methods import GreedyWordSwapWIR
+
+ import numpy as np
+ import json
+ import random
+ import re
+ import textattack.shared.attacked_text as atk
+ import torch.nn.functional as F
+ import torch
+
+
+ class InvertedText:
+
+     def __init__(
+         self,
+         swapped_indexes,
+         score,
+         attacked_text,
+         new_class,
+     ):
+         self.attacked_text = attacked_text
+         self.swapped_indexes = (
+             swapped_indexes  # dict mapping each swapped index to its synonym
+         )
+         self.score = score  # probability of the original class
+         self.new_class = new_class  # class predicted after inversion
+
+     def __repr__(self):
+         return f"InvertedText:\n attacked_text='{self.attacked_text}', \n swapped_indexes={self.swapped_indexes},\n score={self.score}"
+
+
+ def count_matching_classes(original, corrected, perturbed_texts=None):
+     """Count corrected labels that match the original ones and split the
+     perturbed texts into easy (recovered) and hard (not recovered) samples."""
+     if len(original) != len(corrected):
+         raise ValueError("Arrays must have the same length")
+     hard_samples = []
+     easy_samples = []
+
+     matching_count = 0
+
+     for i in range(len(corrected)):
+         if original[i] == corrected[i]:
+             matching_count += 1
+             if perturbed_texts is not None:
+                 easy_samples.append(perturbed_texts[i])
+         elif perturbed_texts is not None:
+             hard_samples.append(perturbed_texts[i])
+
+     return matching_count, hard_samples, easy_samples
+
+
+ class Flow_Corrector:
+     def __init__(
+         self,
+         attack,
+         word_rank_file="en_full_ranked.json",
+         word_freq_file="en_full_freq.json",
+         wir_threshold=0.3,
+     ):
+         self.attack = attack
+         self.attack.cuda_()
+         self.wir_threshold = wir_threshold
+         with open(word_rank_file, "r") as f:
+             self.word_ranked_frequence = json.load(f)
+         with open(word_freq_file, "r") as f:
+             self.word_frequence = json.load(f)
+         self.victim_model = attack.goal_function.model
+
+     def wir_gradient(
+         self,
+         attack,
+         victim_model,
+         detected_text,
+     ):
+         """Order word indices by importance, measured as the L1 norm of the
+         loss gradient with respect to each word's tokens."""
+         _, indices_to_order = attack.get_indices_to_order(detected_text)
+
+         index_scores = np.zeros(len(indices_to_order))
+         grad_output = victim_model.get_grad(detected_text.tokenizer_input)
+         gradient = grad_output["gradient"]
+         word2token_mapping = detected_text.align_with_model_tokens(victim_model)
+         for i, index in enumerate(indices_to_order):
+             matched_tokens = word2token_mapping[index]
+             if not matched_tokens:
+                 index_scores[i] = 0.0
+             else:
+                 agg_grad = np.mean(gradient[matched_tokens], axis=0)
+                 index_scores[i] = np.linalg.norm(agg_grad, ord=1)
+         index_order = np.array(indices_to_order)[(-index_scores).argsort()]
+         return index_order
+
+     def get_syn_freq_dict(
+         self,
+         index_order,
+         detected_text,
+     ):
+         """For each important index, collect synonyms that are more frequent
+         than the current word according to the frequency-rank dictionary."""
+         most_frequent_syn_dict = {}
+
+         no_syn = []
+         freq_threshold = len(self.word_ranked_frequence) / 10
+
+         for idx in index_order:
+             # get the synonyms for the word at this index
+             try:
+                 synonyms = [
+                     attacked_text.words[idx]
+                     for attacked_text in self.attack.get_transformations(
+                         detected_text, detected_text, indices_to_modify=[idx]
+                     )
+                 ]
+                 # keep synonyms that exist in the frequency dictionary, with their rank
+                 ranked_synonyms = {
+                     syn: self.word_ranked_frequence[syn]
+                     for syn in synonyms
+                     if syn in self.word_ranked_frequence.keys()
+                     and self.word_ranked_frequence[syn] < freq_threshold
+                     and self.word_ranked_frequence[detected_text.words[idx]]
+                     > self.word_ranked_frequence[syn]
+                 }
+                 # keep the synonyms that passed the frequency filter
+                 if ranked_synonyms:
+                     most_frequent_syn_dict[idx] = list(ranked_synonyms.keys())
+             except KeyError:
+                 # the current word has no frequency rank, so no synonyms are available
+                 no_syn.append(idx)
+
+         return most_frequent_syn_dict
+
+     def build_candidates(
+         self, detected_text, most_frequent_syn_dict: dict, max_attempt: int
+     ):
+         """Build up to `max_attempt` candidate texts by randomly swapping each
+         important word with one of its selected synonyms."""
+         candidates = {}
+         for _ in range(max_attempt):
+             syn_dict = {}
+             current_text = detected_text
+             for index in most_frequent_syn_dict.keys():
+                 syn = random.choice(most_frequent_syn_dict[index])
+                 syn_dict[index] = syn
+                 current_text = current_text.replace_word_at_index(index, syn)
+
+             candidates[current_text] = syn_dict
+         return candidates
+
+     def find_dominant_class(self, inverted_texts):
+         class_counts = {}  # count of each new class
+
+         for text in inverted_texts:
+             new_class = text.new_class
+             class_counts[new_class] = class_counts.get(new_class, 0) + 1
+
+         # Find the most dominant class
+         most_dominant_class = max(class_counts, key=class_counts.get)
+
+         return most_dominant_class
+
+     def correct(self, detected_texts):
+         corrected_classes = []
+         for detected_text in detected_texts:
+
+             # convert to an AttackedText
+             detected_text = atk.AttackedText(detected_text)
+
+             # keep only the most important indexes (a wir_threshold fraction)
+             index_order = self.wir_gradient(
+                 self.attack, self.victim_model, detected_text
+             )
+             index_order = index_order[: int(len(index_order) * self.wir_threshold)]
+
+             # get synonyms that satisfy the frequency conditions
+             most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)
+
+             # generate M candidates
+             candidates = self.build_candidates(
+                 detected_text, most_frequent_syn_dict, max_attempt=100
+             )
+
+             original_probs = F.softmax(self.victim_model(detected_text.text), dim=1)
+             original_class = torch.argmax(original_probs).item()
+             original_golden_prob = float(original_probs[0][original_class])
+
+             nbr_inverted = 0
+             inverted_texts = []  # list of InvertedText candidates that flipped the class
+             bad, impr = 0, 0
+             dict_deltas = {}
+
+             batch_inputs = [candidate.text for candidate in candidates.keys()]
+
+             batch_outputs = self.victim_model(batch_inputs)
+
+             probabilities = F.softmax(batch_outputs, dim=1)
+             for i, (candidate, syn_dict) in enumerate(candidates.items()):
+
+                 corrected_class = torch.argmax(probabilities[i]).item()
+                 new_golden_probability = float(probabilities[i][corrected_class])
+                 if corrected_class != original_class:
+                     nbr_inverted += 1
+                     inverted_texts.append(
+                         InvertedText(
+                             syn_dict, new_golden_probability, candidate, corrected_class
+                         )
+                     )
+                 else:
+                     delta = new_golden_probability - original_golden_prob
+                     if delta <= 0:
+                         bad += 1
+                     else:
+                         impr += 1
+                     dict_deltas[candidate] = delta
+
+             num_classes = len(original_probs[0])
+             if num_classes > 2 and len(inverted_texts) >= len(candidates) / num_classes:
+                 # multi-class case: vote for the most dominant class among inversions
+                 dominant_class = self.find_dominant_class(inverted_texts)
+             elif len(inverted_texts) >= len(candidates) / 2:
+                 # binary case: a majority of candidates flipped the prediction
+                 dominant_class = self.find_dominant_class(inverted_texts)
+             else:
+                 dominant_class = original_class
+
+             corrected_classes.append(dominant_class)
+
+         return corrected_classes
+
+
+ def remove_brackets(text):
+     text = text.replace("[[", "")
+     text = text.replace("]]", "")
+     return text
+
+
+ def clean_text(text):
+     pattern = "[" + re.escape(string.punctuation) + "]"
+     cleaned_text = re.sub(pattern, " ", text)
+
+     return cleaned_text
+
+
+ # Load model, tokenizer, and model_wrapper
+ model = transformers.AutoModelForSequenceClassification.from_pretrained(
+     "textattack/bert-base-uncased-imdb"
+ )
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+     "textattack/bert-base-uncased-imdb"
+ )
+ model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
+
+
+ goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
+ constraints = [
+     RepeatModification(),
+     StopwordModification(),
+     WordEmbeddingDistance(min_cos_sim=0.9),
+ ]
+ transformation = WordSwapEmbedding(max_candidates=50)
+ search_method = GreedyWordSwapWIR(wir_method="gradient")
+
+ # Construct the actual attack
+ attack = textattack.Attack(goal_function, constraints, transformation, search_method)
+ attack.cuda_()
+
+
+ results = pd.read_csv("IMDB_results.csv")
+ perturbed_texts = [
+     results["perturbed_text"][i]
+     for i in range(len(results))
+     if results["result_type"][i] == "Successful"
+ ]
+ original_texts = [
+     results["original_text"][i]
+     for i in range(len(results))
+     if results["result_type"][i] == "Successful"
+ ]
+
+ perturbed_texts = [remove_brackets(text) for text in perturbed_texts]
+ original_texts = [remove_brackets(text) for text in original_texts]
+
+ perturbed_texts = [clean_text(text) for text in perturbed_texts]
+ original_texts = [clean_text(text) for text in original_texts]
+
+
+ victim_model = attack.goal_function.model
+
+ print("Getting corrected classes")
+ print("This may take a while ...")
+ # we could also read these labels directly from the results csv file
+ original_classes = [
+     torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()
+     for original_text in original_texts
+ ]
+
+ batch_size = 1000
+ num_batches = (len(perturbed_texts) + batch_size - 1) // batch_size
+ batched_perturbed_texts = []
+ batched_original_texts = []
+ batched_original_classes = []
+
+ for i in range(num_batches):
+     start = i * batch_size
+     end = min(start + batch_size, len(perturbed_texts))
+     batched_perturbed_texts.append(perturbed_texts[start:end])
+     batched_original_texts.append(original_texts[start:end])
+     batched_original_classes.append(original_classes[start:end])
+ print(batched_original_classes)
+ hard_samples_list = []
+ easy_samples_list = []
+
+
+ # Open a CSV file for writing
+ csv_filename = "flow_correction_results_imdb.csv"
+ with open(csv_filename, "w", newline="") as csvfile:
+     fieldnames = ["freq_threshold", "batch_num", "match_perturbed", "match_original"]
+     writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+
+     # Write the header row
+     writer.writeheader()
+
+     # Iterate over batched lists
+     batch_num = 0
+     for perturbed, original, classes in zip(
+         batched_perturbed_texts, batched_original_texts, batched_original_classes
+     ):
+         batch_num += 1
+         print(f"Processing batch number: {batch_num}")
+
+         for i in range(2):
+             wir_threshold = 0.1 * (i + 1)
+             print(f"Setting WIR threshold to: {wir_threshold}")
+
+             corrector = Flow_Corrector(
+                 attack,
+                 word_rank_file="en_full_ranked.json",
+                 word_freq_file="en_full_freq.json",
+                 wir_threshold=wir_threshold,
+             )
+
+             # Correct perturbed texts
+             print("Correcting perturbed texts...")
+             corrected_perturbed_classes = corrector.correct(perturbed)
+
+             match_perturbed, hard_samples, easy_samples = count_matching_classes(
+                 classes, corrected_perturbed_classes, perturbed
+             )
+             hard_samples_list.extend(hard_samples)
+             easy_samples_list.extend(easy_samples)
+
+             print(f"Number of matching classes (perturbed): {match_perturbed}")
+
+             # Correct original texts
+             print("Correcting original texts...")
+             corrected_original_classes = corrector.correct(original)
+             match_original, hard_samples, easy_samples = count_matching_classes(
+                 classes, corrected_original_classes, perturbed
+             )
+             print(f"Number of matching classes (original): {match_original}")
+
+             # Write results to CSV file
+             print("Writing results to CSV file...")
+             writer.writerow(
+                 {
+                     "freq_threshold": wir_threshold,
+                     "batch_num": batch_num,
+                     "match_perturbed": match_perturbed / len(perturbed),
+                     "match_original": match_original / len(perturbed),
+                 }
+             )
+             print("-" * 20)
+
+ print("saving samples for further statistical study")
+
+ # Save hard_samples_list and easy_samples_list to files
+ with open("hard_samples.pkl", "wb") as f:
+     pickle.dump(hard_samples_list, f)
+
+ with open("easy_samples.pkl", "wb") as f:
+     pickle.dump(easy_samples_list, f)