Spaces:
PFEemp2024 committed
Commit 1ef6bf0
Parent(s): e3f5f4d
Upload 2 files
- flow_correction_ag_news.py +388 -0
- flow_correction_imdb.py +388 -0
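The two added scripts evaluate the flow-correction defense on adversarial examples generated against BERT classifiers fine-tuned on AG News and IMDB; they are identical apart from the model checkpoint and the input/output file names.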
flow_correction_ag_news.py
ADDED
@@ -0,0 +1,388 @@
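# Flow-correction evaluation on AG News.
#
# Rebuilds the TextAttack word-swap-embedding attack against the
# textattack/bert-base-uncased-ag-news classifier, reads previously generated
# attack results from ag_news_results.csv (columns: original_text,
# perturbed_text, result_type), and measures how often Flow_Corrector recovers
# the victim model's prediction on the clean text, for both perturbed and
# original inputs. Expects en_full_ranked.json and en_full_freq.json next to
# the script; writes match rates to flow_correction_results_ag_news.csv and
# pickles hard/easy samples. attack.cuda_() and the gradient-based word
# ranking assume a CUDA-capable GPU.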
import textattack
import transformers
import pandas as pd
import csv
import string
import pickle
# Construct our four components for `Attack`
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.transformations import WordSwapEmbedding
from textattack.search_methods import GreedyWordSwapWIR

import numpy as np
import json
import random
import re
import textattack.shared.attacked_text as atk
import torch.nn.functional as F
import torch


class InvertedText:

    def __init__(
        self,
        swapped_indexes,
        score,
        attacked_text,
        new_class,
    ):
        self.attacked_text = attacked_text
        self.swapped_indexes = (
            swapped_indexes  # dict of swapped indexes with their synonym
        )
        self.score = score  # probability of the original class
        self.new_class = new_class  # class after inversion

    def __repr__(self):
        return f"InvertedText:\n attacked_text='{self.attacked_text}', \n swapped_indexes={self.swapped_indexes},\n score={self.score}"


def count_matching_classes(original, corrected, perturbed_texts=None):
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")
    hard_samples = []
    easy_samples = []

    matching_count = 0

    for i in range(len(corrected)):
        if original[i] == corrected[i]:
            matching_count += 1
            easy_samples.append(perturbed_texts[i])
        elif perturbed_texts is not None:
            hard_samples.append(perturbed_texts[i])

    return matching_count, hard_samples, easy_samples


class Flow_Corrector:
    def __init__(
        self,
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
        wir_threshold=0.3,
    ):
        self.attack = attack
        self.attack.cuda_()
        self.wir_threshold = wir_threshold
        with open(word_rank_file, "r") as f:
            self.word_ranked_frequence = json.load(f)
        with open(word_freq_file, "r") as f:
            self.word_frequence = json.load(f)
        self.victim_model = attack.goal_function.model

    def wir_gradient(
        self,
        attack,
        victim_model,
        detected_text,
    ):
        # rank word indices by the L1 norm of the gradient of their tokens
        _, indices_to_order = attack.get_indices_to_order(detected_text)

        index_scores = np.zeros(len(indices_to_order))
        grad_output = victim_model.get_grad(detected_text.tokenizer_input)
        gradient = grad_output["gradient"]
        word2token_mapping = detected_text.align_with_model_tokens(victim_model)
        for i, index in enumerate(indices_to_order):
            matched_tokens = word2token_mapping[index]
            if not matched_tokens:
                index_scores[i] = 0.0
            else:
                agg_grad = np.mean(gradient[matched_tokens], axis=0)
                index_scores[i] = np.linalg.norm(agg_grad, ord=1)
        index_order = np.array(indices_to_order)[(-index_scores).argsort()]
        return index_order

    def get_syn_freq_dict(
        self,
        index_order,
        detected_text,
    ):
        most_frequent_syn_dict = {}

        no_syn = []
        freq_threshold = len(self.word_ranked_frequence) / 10

        for idx in index_order:
            # get the synonyms of the word at a specific index
            try:
                synonyms = [
                    attacked_text.words[idx]
                    for attacked_text in self.attack.get_transformations(
                        detected_text, detected_text, indices_to_modify=[idx]
                    )
                ]
                # keep synonyms that exist in the dataset, with their frequency
                # rank, and that are more frequent than the current word
                ranked_synonyms = {
                    syn: self.word_ranked_frequence[syn]
                    for syn in synonyms
                    if syn in self.word_ranked_frequence.keys()
                    and self.word_ranked_frequence[syn] < freq_threshold
                    and self.word_ranked_frequence[detected_text.words[idx]]
                    > self.word_ranked_frequence[syn]
                }
                # keep the most frequent synonyms for this index
                if list(ranked_synonyms.keys()) != []:
                    most_frequent_syn_dict[idx] = list(ranked_synonyms.keys())
            except Exception:
                # no synonyms available in the dataset
                no_syn.append(idx)

        return most_frequent_syn_dict

    def build_candidates(
        self, detected_text, most_frequent_syn_dict: dict, max_attempt: int
    ):
        candidates = {}
        for _ in range(max_attempt):
            syn_dict = {}
            current_text = detected_text
            for index in most_frequent_syn_dict.keys():
                syn = random.choice(most_frequent_syn_dict[index])
                syn_dict[index] = syn
                current_text = current_text.replace_word_at_index(index, syn)

            candidates[current_text] = syn_dict
        return candidates

    def find_dominant_class(self, inverted_texts):
        class_counts = {}  # Dictionary to store the count of each new class

        for text in inverted_texts:
            new_class = text.new_class
            class_counts[new_class] = class_counts.get(new_class, 0) + 1

        # Find the most dominant class
        most_dominant_class = max(class_counts, key=class_counts.get)

        return most_dominant_class

    def correct(self, detected_texts):
        corrected_classes = []
        for detected_text in detected_texts:

            # convert to an AttackedText
            detected_text = atk.AttackedText(detected_text)

            # keep only the most important indexes (top wir_threshold fraction)
            index_order = self.wir_gradient(
                self.attack, self.victim_model, detected_text
            )
            index_order = index_order[: int(len(index_order) * self.wir_threshold)]

            # get synonyms that satisfy the frequency conditions
            most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)

            # generate M candidates
            candidates = self.build_candidates(
                detected_text, most_frequent_syn_dict, max_attempt=100
            )

            original_probs = F.softmax(self.victim_model(detected_text.text), dim=1)
            original_class = torch.argmax(original_probs).item()
            original_golden_prob = float(original_probs[0][original_class])

            nbr_inverted = 0
            inverted_texts = []  # list of InvertedText objects for flipped candidates
            bad, impr = 0, 0
            dict_deltas = {}

            batch_inputs = [candidate.text for candidate in candidates.keys()]

            batch_outputs = self.victim_model(batch_inputs)

            probabilities = F.softmax(batch_outputs, dim=1)
            for i, (candidate, syn_dict) in enumerate(candidates.items()):

                corrected_class = torch.argmax(probabilities[i]).item()
                new_golden_probability = float(probabilities[i][corrected_class])
                if corrected_class != original_class:
                    nbr_inverted += 1
                    inverted_texts.append(
                        InvertedText(
                            syn_dict, new_golden_probability, candidate, corrected_class
                        )
                    )
                else:
                    delta = new_golden_probability - original_golden_prob
                    if delta <= 0:
                        bad += 1
                    else:
                        impr += 1
                    dict_deltas[candidate] = delta

            # decision rule: for multi-class tasks, flip the label when at least
            # 1/num_classes of the candidates were inverted; for binary tasks,
            # require at least half; otherwise keep the original prediction
            if len(original_probs[0]) > 2 and len(inverted_texts) >= len(candidates) / (
                len(original_probs[0])
            ):
                # select the most dominant class among the inverted candidates
                dominant_class = self.find_dominant_class(inverted_texts)
            elif len(inverted_texts) >= len(candidates) / 2:
                dominant_class = corrected_class
            else:
                dominant_class = original_class

            corrected_classes.append(dominant_class)

        return corrected_classes


def remove_brackets(text):
    text = text.replace("[[", "")
    text = text.replace("]]", "")
    return text


def clean_text(text):
    pattern = "[" + re.escape(string.punctuation) + "]"
    cleaned_text = re.sub(pattern, " ", text)

    return cleaned_text


# Load model, tokenizer, and model_wrapper
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-ag-news"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-ag-news"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)


goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
constraints = [
    RepeatModification(),
    StopwordModification(),
    WordEmbeddingDistance(min_cos_sim=0.9),
]
transformation = WordSwapEmbedding(max_candidates=50)
search_method = GreedyWordSwapWIR(wir_method="gradient")

# Construct the actual attack
attack = textattack.Attack(goal_function, constraints, transformation, search_method)
attack.cuda_()


results = pd.read_csv("ag_news_results.csv")
perturbed_texts = [
    results["perturbed_text"][i]
    for i in range(len(results))
    if results["result_type"][i] == "Successful"
]
original_texts = [
    results["original_text"][i]
    for i in range(len(results))
    if results["result_type"][i] == "Successful"
]

perturbed_texts = [remove_brackets(text) for text in perturbed_texts]
original_texts = [remove_brackets(text) for text in original_texts]

perturbed_texts = [clean_text(text) for text in perturbed_texts]
original_texts = [clean_text(text) for text in original_texts]


victim_model = attack.goal_function.model

print("Getting corrected classes")
print("This may take a while ...")
# we could also reuse the predictions already stored in the results csv file
original_classes = [
    torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()
    for original_text in original_texts
]

batch_size = 1000
num_batches = (len(perturbed_texts) + batch_size - 1) // batch_size
batched_perturbed_texts = []
batched_original_texts = []
batched_original_classes = []

for i in range(num_batches):
    start = i * batch_size
    end = min(start + batch_size, len(perturbed_texts))
    batched_perturbed_texts.append(perturbed_texts[start:end])
    batched_original_texts.append(original_texts[start:end])
    batched_original_classes.append(original_classes[start:end])
print(batched_original_classes)
hard_samples_list = []
easy_samples_list = []


# Open a CSV file for writing
csv_filename = "flow_correction_results_ag_news.csv"
with open(csv_filename, "w", newline="") as csvfile:
    fieldnames = ["freq_threshold", "batch_num", "match_perturbed", "match_original"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    # Write the header row
    writer.writeheader()

    # Iterate over batched lists
    batch_num = 0
    for perturbed, original, classes in zip(
        batched_perturbed_texts, batched_original_texts, batched_original_classes
    ):
        batch_num += 1
        print(f"Processing batch number: {batch_num}")

        for i in range(2):
            wir_threshold = 0.1 * (i + 1)
            print(f"Setting word importance threshold to: {wir_threshold}")

            corrector = Flow_Corrector(
                attack,
                word_rank_file="en_full_ranked.json",
                word_freq_file="en_full_freq.json",
                wir_threshold=wir_threshold,
            )

            # Correct perturbed texts
            print("Correcting perturbed texts...")
            corrected_perturbed_classes = corrector.correct(perturbed)

            match_perturbed, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_perturbed_classes, perturbed
            )
            hard_samples_list.extend(hard_samples)
            easy_samples_list.extend(easy_samples)

            print(f"Number of matching classes (perturbed): {match_perturbed}")

            # Correct original texts
            print("Correcting original texts...")
            corrected_original_classes = corrector.correct(original)
            match_original, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_original_classes, perturbed
            )
            print(f"Number of matching classes (original): {match_original}")

            # Write results to CSV file
            print("Writing results to CSV file...")
            writer.writerow(
                {
                    "freq_threshold": wir_threshold,
                    "batch_num": batch_num,
                    "match_perturbed": match_perturbed / len(perturbed),
                    "match_original": match_original / len(perturbed),
                }
            )
            print("-" * 20)

print("Saving samples for further statistical analysis")

# Save hard_samples_list and easy_samples_list to files
with open("hard_samples.pkl", "wb") as f:
    pickle.dump(hard_samples_list, f)

with open("easy_samples.pkl", "wb") as f:
    pickle.dump(easy_samples_list, f)
flow_correction_imdb.py
ADDED
@@ -0,0 +1,388 @@
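# Flow-correction evaluation on IMDB.
#
# Mirrors flow_correction_ag_news.py: only the fine-tuned checkpoint
# (textattack/bert-base-uncased-imdb), the attack-results file
# (IMDB_results.csv), and the output CSV (flow_correction_results_imdb.csv)
# differ. It expects the same en_full_ranked.json / en_full_freq.json files
# and also writes hard_samples.pkl / easy_samples.pkl, so running both
# scripts from the same directory overwrites those pickles.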
import textattack
import transformers
import pandas as pd
import csv
import string
import pickle
# Construct our four components for `Attack`
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.transformations import WordSwapEmbedding
from textattack.search_methods import GreedyWordSwapWIR

import numpy as np
import json
import random
import re
import textattack.shared.attacked_text as atk
import torch.nn.functional as F
import torch


class InvertedText:

    def __init__(
        self,
        swapped_indexes,
        score,
        attacked_text,
        new_class,
    ):
        self.attacked_text = attacked_text
        self.swapped_indexes = (
            swapped_indexes  # dict of swapped indexes with their synonym
        )
        self.score = score  # probability of the original class
        self.new_class = new_class  # class after inversion

    def __repr__(self):
        return f"InvertedText:\n attacked_text='{self.attacked_text}', \n swapped_indexes={self.swapped_indexes},\n score={self.score}"


def count_matching_classes(original, corrected, perturbed_texts=None):
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")
    hard_samples = []
    easy_samples = []

    matching_count = 0

    for i in range(len(corrected)):
        if original[i] == corrected[i]:
            matching_count += 1
            easy_samples.append(perturbed_texts[i])
        elif perturbed_texts is not None:
            hard_samples.append(perturbed_texts[i])

    return matching_count, hard_samples, easy_samples


class Flow_Corrector:
    def __init__(
        self,
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
        wir_threshold=0.3,
    ):
        self.attack = attack
        self.attack.cuda_()
        self.wir_threshold = wir_threshold
        with open(word_rank_file, "r") as f:
            self.word_ranked_frequence = json.load(f)
        with open(word_freq_file, "r") as f:
            self.word_frequence = json.load(f)
        self.victim_model = attack.goal_function.model

    def wir_gradient(
        self,
        attack,
        victim_model,
        detected_text,
    ):
        # rank word indices by the L1 norm of the gradient of their tokens
        _, indices_to_order = attack.get_indices_to_order(detected_text)

        index_scores = np.zeros(len(indices_to_order))
        grad_output = victim_model.get_grad(detected_text.tokenizer_input)
        gradient = grad_output["gradient"]
        word2token_mapping = detected_text.align_with_model_tokens(victim_model)
        for i, index in enumerate(indices_to_order):
            matched_tokens = word2token_mapping[index]
            if not matched_tokens:
                index_scores[i] = 0.0
            else:
                agg_grad = np.mean(gradient[matched_tokens], axis=0)
                index_scores[i] = np.linalg.norm(agg_grad, ord=1)
        index_order = np.array(indices_to_order)[(-index_scores).argsort()]
        return index_order

    def get_syn_freq_dict(
        self,
        index_order,
        detected_text,
    ):
        most_frequent_syn_dict = {}

        no_syn = []
        freq_threshold = len(self.word_ranked_frequence) / 10

        for idx in index_order:
            # get the synonyms of the word at a specific index
            try:
                synonyms = [
                    attacked_text.words[idx]
                    for attacked_text in self.attack.get_transformations(
                        detected_text, detected_text, indices_to_modify=[idx]
                    )
                ]
                # keep synonyms that exist in the dataset, with their frequency
                # rank, and that are more frequent than the current word
                ranked_synonyms = {
                    syn: self.word_ranked_frequence[syn]
                    for syn in synonyms
                    if syn in self.word_ranked_frequence.keys()
                    and self.word_ranked_frequence[syn] < freq_threshold
                    and self.word_ranked_frequence[detected_text.words[idx]]
                    > self.word_ranked_frequence[syn]
                }
                # keep the most frequent synonyms for this index
                if list(ranked_synonyms.keys()) != []:
                    most_frequent_syn_dict[idx] = list(ranked_synonyms.keys())
            except Exception:
                # no synonyms available in the dataset
                no_syn.append(idx)

        return most_frequent_syn_dict

    def build_candidates(
        self, detected_text, most_frequent_syn_dict: dict, max_attempt: int
    ):
        candidates = {}
        for _ in range(max_attempt):
            syn_dict = {}
            current_text = detected_text
            for index in most_frequent_syn_dict.keys():
                syn = random.choice(most_frequent_syn_dict[index])
                syn_dict[index] = syn
                current_text = current_text.replace_word_at_index(index, syn)

            candidates[current_text] = syn_dict
        return candidates

    def find_dominant_class(self, inverted_texts):
        class_counts = {}  # Dictionary to store the count of each new class

        for text in inverted_texts:
            new_class = text.new_class
            class_counts[new_class] = class_counts.get(new_class, 0) + 1

        # Find the most dominant class
        most_dominant_class = max(class_counts, key=class_counts.get)

        return most_dominant_class

    def correct(self, detected_texts):
        corrected_classes = []
        for detected_text in detected_texts:

            # convert to an AttackedText
            detected_text = atk.AttackedText(detected_text)

            # keep only the most important indexes (top wir_threshold fraction)
            index_order = self.wir_gradient(
                self.attack, self.victim_model, detected_text
            )
            index_order = index_order[: int(len(index_order) * self.wir_threshold)]

            # get synonyms that satisfy the frequency conditions
            most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)

            # generate M candidates
            candidates = self.build_candidates(
                detected_text, most_frequent_syn_dict, max_attempt=100
            )

            original_probs = F.softmax(self.victim_model(detected_text.text), dim=1)
            original_class = torch.argmax(original_probs).item()
            original_golden_prob = float(original_probs[0][original_class])

            nbr_inverted = 0
            inverted_texts = []  # list of InvertedText objects for flipped candidates
            bad, impr = 0, 0
            dict_deltas = {}

            batch_inputs = [candidate.text for candidate in candidates.keys()]

            batch_outputs = self.victim_model(batch_inputs)

            probabilities = F.softmax(batch_outputs, dim=1)
            for i, (candidate, syn_dict) in enumerate(candidates.items()):

                corrected_class = torch.argmax(probabilities[i]).item()
                new_golden_probability = float(probabilities[i][corrected_class])
                if corrected_class != original_class:
                    nbr_inverted += 1
                    inverted_texts.append(
                        InvertedText(
                            syn_dict, new_golden_probability, candidate, corrected_class
                        )
                    )
                else:
                    delta = new_golden_probability - original_golden_prob
                    if delta <= 0:
                        bad += 1
                    else:
                        impr += 1
                    dict_deltas[candidate] = delta

            # decision rule: for multi-class tasks, flip the label when at least
            # 1/num_classes of the candidates were inverted; for binary tasks,
            # require at least half; otherwise keep the original prediction
            if len(original_probs[0]) > 2 and len(inverted_texts) >= len(candidates) / (
                len(original_probs[0])
            ):
                # select the most dominant class among the inverted candidates
                dominant_class = self.find_dominant_class(inverted_texts)
            elif len(inverted_texts) >= len(candidates) / 2:
                dominant_class = corrected_class
            else:
                dominant_class = original_class

            corrected_classes.append(dominant_class)

        return corrected_classes


def remove_brackets(text):
    text = text.replace("[[", "")
    text = text.replace("]]", "")
    return text


def clean_text(text):
    pattern = "[" + re.escape(string.punctuation) + "]"
    cleaned_text = re.sub(pattern, " ", text)

    return cleaned_text


# Load model, tokenizer, and model_wrapper
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)


goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
constraints = [
    RepeatModification(),
    StopwordModification(),
    WordEmbeddingDistance(min_cos_sim=0.9),
]
transformation = WordSwapEmbedding(max_candidates=50)
search_method = GreedyWordSwapWIR(wir_method="gradient")

# Construct the actual attack
attack = textattack.Attack(goal_function, constraints, transformation, search_method)
attack.cuda_()


results = pd.read_csv("IMDB_results.csv")
perturbed_texts = [
    results["perturbed_text"][i]
    for i in range(len(results))
    if results["result_type"][i] == "Successful"
]
original_texts = [
    results["original_text"][i]
    for i in range(len(results))
    if results["result_type"][i] == "Successful"
]

perturbed_texts = [remove_brackets(text) for text in perturbed_texts]
original_texts = [remove_brackets(text) for text in original_texts]

perturbed_texts = [clean_text(text) for text in perturbed_texts]
original_texts = [clean_text(text) for text in original_texts]


victim_model = attack.goal_function.model

print("Getting corrected classes")
print("This may take a while ...")
# we could also reuse the predictions already stored in the results csv file
original_classes = [
    torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()
    for original_text in original_texts
]

batch_size = 1000
num_batches = (len(perturbed_texts) + batch_size - 1) // batch_size
batched_perturbed_texts = []
batched_original_texts = []
batched_original_classes = []

for i in range(num_batches):
    start = i * batch_size
    end = min(start + batch_size, len(perturbed_texts))
    batched_perturbed_texts.append(perturbed_texts[start:end])
    batched_original_texts.append(original_texts[start:end])
    batched_original_classes.append(original_classes[start:end])
print(batched_original_classes)
hard_samples_list = []
easy_samples_list = []


# Open a CSV file for writing
csv_filename = "flow_correction_results_imdb.csv"
with open(csv_filename, "w", newline="") as csvfile:
    fieldnames = ["freq_threshold", "batch_num", "match_perturbed", "match_original"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    # Write the header row
    writer.writeheader()

    # Iterate over batched lists
    batch_num = 0
    for perturbed, original, classes in zip(
        batched_perturbed_texts, batched_original_texts, batched_original_classes
    ):
        batch_num += 1
        print(f"Processing batch number: {batch_num}")

        for i in range(2):
            wir_threshold = 0.1 * (i + 1)
            print(f"Setting word importance threshold to: {wir_threshold}")

            corrector = Flow_Corrector(
                attack,
                word_rank_file="en_full_ranked.json",
                word_freq_file="en_full_freq.json",
                wir_threshold=wir_threshold,
            )

            # Correct perturbed texts
            print("Correcting perturbed texts...")
            corrected_perturbed_classes = corrector.correct(perturbed)

            match_perturbed, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_perturbed_classes, perturbed
            )
            hard_samples_list.extend(hard_samples)
            easy_samples_list.extend(easy_samples)

            print(f"Number of matching classes (perturbed): {match_perturbed}")

            # Correct original texts
            print("Correcting original texts...")
            corrected_original_classes = corrector.correct(original)
            match_original, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_original_classes, perturbed
            )
            print(f"Number of matching classes (original): {match_original}")

            # Write results to CSV file
            print("Writing results to CSV file...")
            writer.writerow(
                {
                    "freq_threshold": wir_threshold,
                    "batch_num": batch_num,
                    "match_perturbed": match_perturbed / len(perturbed),
                    "match_original": match_original / len(perturbed),
                }
            )
            print("-" * 20)

print("Saving samples for further statistical analysis")

# Save hard_samples_list and easy_samples_list to files
with open("hard_samples.pkl", "wb") as f:
    pickle.dump(hard_samples_list, f)

with open("easy_samples.pkl", "wb") as f:
    pickle.dump(easy_samples_list, f)