# DCWIR-Demo / flow_correction_imdb.py
# Source page header: PFEemp2024 - "Upload 2 files" - commit 1ef6bf0 (verified)
import textattack
import transformers
import pandas as pd
import csv
import string
import pickle
# Construct our four components for `Attack`
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.transformations import WordSwapEmbedding
from textattack.search_methods import GreedyWordSwapWIR
import numpy as np
import json
import random
import re
import textattack.shared.attacked_text as atk
import torch.nn.functional as F
import torch
class InvertedText:
    """Record describing one candidate text whose predicted class changed.

    Attributes:
        attacked_text: the candidate text object the record refers to.
        swapped_indexes: dict mapping a word index to the synonym placed there.
        score: numeric score supplied by the caller; in this file it is the
            softmax probability of the newly predicted class.
        new_class: class predicted for the candidate after the substitutions.
    """

    def __init__(self, swapped_indexes, score, attacked_text, new_class):
        # Plain attribute storage; no validation or copying is performed.
        self.new_class = new_class
        self.score = score
        self.swapped_indexes = swapped_indexes
        self.attacked_text = attacked_text

    def __repr__(self):
        # NOTE: `new_class` is intentionally not part of the representation,
        # matching the historical output format.
        description = f"InvertedText:\n attacked_text='{self.attacked_text}', \n swapped_indexes={self.swapped_indexes},\n score={self.score}"
        return description
def count_matching_classes(original, corrected, perturbed_texts=None):
    """Count positions where ``original`` and ``corrected`` agree.

    Args:
        original: sequence of reference class labels.
        corrected: sequence of predicted class labels, same length as
            ``original``.
        perturbed_texts: optional sequence of texts aligned with the labels.
            When given, each text is sorted into the "easy" list (labels
            match) or the "hard" list (labels differ). When ``None``, both
            returned lists are empty.

    Returns:
        Tuple ``(matching_count, hard_samples, easy_samples)``.

    Raises:
        ValueError: if ``original`` and ``corrected`` differ in length.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")

    hard_samples = []
    easy_samples = []
    matching_count = 0
    for i, (expected, predicted) in enumerate(zip(original, corrected)):
        is_match = expected == predicted
        if is_match:
            matching_count += 1
        # Without texts there is nothing to bucket; previously a match with
        # perturbed_texts=None crashed on `None[i]`.
        if perturbed_texts is None:
            continue
        bucket = easy_samples if is_match else hard_samples
        bucket.append(perturbed_texts[i])
    return matching_count, hard_samples, easy_samples
class Flow_Corrector:
    """Re-classify suspected adversarial texts by voting over synonym swaps.

    For each input text the corrector:
      1. ranks word positions by gradient magnitude (``wir_gradient``),
      2. keeps the top ``wir_threshold`` fraction of those positions,
      3. collects, per position, synonyms that rank as more frequent than the
         original word (``get_syn_freq_dict``),
      4. samples random combinations of those synonyms (``build_candidates``),
      5. classifies all candidates and lets them vote on the final class.

    Attributes:
        attack: attack object used for its transformation/constraint
            machinery and as the source of the victim model.
        wir_threshold: fraction of the ranked word indices that are
            considered for substitution.
        word_ranked_frequence: dict loaded from ``word_rank_file``; maps a
            word to its frequency rank (presumably lower = more frequent --
            confirm against the file that produces it).
        word_frequence: dict loaded from ``word_freq_file``; not read by any
            method of this class.
        victim_model: ``attack.goal_function.model``.
    """

    def __init__(
        self,
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
        wir_threshold=0.3,
    ):
        """Store the attack, move it to GPU and load the word statistics.

        Args:
            attack: attack object exposing ``cuda_()``, ``goal_function``,
                ``get_indices_to_order`` and ``get_transformations``.
            word_rank_file: path of a JSON file with word -> rank.
            word_freq_file: path of a JSON file with word frequencies.
            wir_threshold: fraction of ranked indices kept by ``correct``.
        """
        self.attack = attack
        self.attack.cuda_()
        self.wir_threshold = wir_threshold
        with open(word_rank_file, "r") as f:
            self.word_ranked_frequence = json.load(f)
        with open(word_freq_file, "r") as f:
            self.word_frequence = json.load(f)
        self.victim_model = attack.goal_function.model

    def wir_gradient(
        self,
        attack,
        victim_model,
        detected_text,
    ):
        """Order candidate word indices by decreasing gradient magnitude.

        For every index returned by ``attack.get_indices_to_order`` the
        gradient rows of the tokens aligned with that word are averaged and
        the L1 norm of the average is used as the word's score. Words with no
        aligned token score 0.

        Args:
            attack: object providing ``get_indices_to_order``.
            victim_model: model wrapper providing ``get_grad``.
            detected_text: text object providing ``tokenizer_input`` and
                ``align_with_model_tokens``.

        Returns:
            Array of word indices, highest score first.
        """
        _, indices_to_order = attack.get_indices_to_order(detected_text)
        index_scores = np.zeros(len(indices_to_order))

        grad_output = victim_model.get_grad(detected_text.tokenizer_input)
        gradient = grad_output["gradient"]
        word2token_mapping = detected_text.align_with_model_tokens(victim_model)

        for i, index in enumerate(indices_to_order):
            matched_tokens = word2token_mapping[index]
            if not matched_tokens:
                index_scores[i] = 0.0
            else:
                agg_grad = np.mean(gradient[matched_tokens], axis=0)
                index_scores[i] = np.linalg.norm(agg_grad, ord=1)

        # Negate so that argsort (ascending) yields a descending ranking.
        index_order = np.array(indices_to_order)[(-index_scores).argsort()]
        return index_order

    def get_syn_freq_dict(
        self,
        index_order,
        detected_text,
    ):
        """Collect, per word index, the synonyms that are more frequent.

        A synonym is kept when it is present in ``word_ranked_frequence``,
        its rank is below one tenth of the vocabulary size, and its rank is
        lower than the rank of the word currently at that index.

        Args:
            index_order: iterable of word indices to examine.
            detected_text: text object whose words are being replaced.

        Returns:
            Dict mapping index -> list of accepted synonyms. Indices with no
            accepted synonym (or for which the lookup failed) are absent.
        """
        ranks = self.word_ranked_frequence
        freq_threshold = len(ranks) / 10
        most_frequent_syn_dict = {}

        for idx in index_order:
            try:
                # Words that the attack's transformation proposes at `idx`.
                synonyms = [
                    attacked_text.words[idx]
                    for attacked_text in self.attack.get_transformations(
                        detected_text, detected_text, indices_to_modify=[idx]
                    )
                ]
                ranked_synonyms = {
                    syn: ranks[syn]
                    for syn in synonyms
                    if syn in ranks
                    and ranks[syn] < freq_threshold
                    and ranks[detected_text.words[idx]] > ranks[syn]
                }
            except Exception:
                # Best effort: e.g. the original word is missing from the rank
                # table (KeyError). Skip this index, but no longer swallow
                # KeyboardInterrupt/SystemExit as the bare `except:` did.
                continue

            if ranked_synonyms:
                most_frequent_syn_dict[idx] = list(ranked_synonyms)
        return most_frequent_syn_dict

    def build_candidates(
        self, detected_text, most_frequent_syn_dict: dict, max_attempt: int
    ):
        """Sample up to ``max_attempt`` random synonym combinations.

        Each attempt picks one random synonym for every index in
        ``most_frequent_syn_dict`` and applies all picks to ``detected_text``.

        Returns:
            Dict mapping the resulting text object -> ``{index: synonym}``
            used to build it. Attempts that produce an equal key overwrite
            each other, so the dict may hold fewer than ``max_attempt`` items.
        """
        candidates = {}
        for _ in range(max_attempt):
            syn_dict = {}
            current_text = detected_text
            for index, synonyms in most_frequent_syn_dict.items():
                syn = random.choice(synonyms)
                syn_dict[index] = syn
                current_text = current_text.replace_word_at_index(index, syn)
            candidates[current_text] = syn_dict
        return candidates

    def find_dominant_class(self, inverted_texts):
        """Return the ``new_class`` value that occurs most often.

        Ties are resolved in favour of the class encountered first.

        Raises:
            ValueError: if ``inverted_texts`` is empty.
        """
        class_counts = {}
        for text in inverted_texts:
            new_class = text.new_class
            class_counts[new_class] = class_counts.get(new_class, 0) + 1
        return max(class_counts, key=class_counts.get)

    def _select_class(self, original_class, inverted_texts, num_candidates, num_classes):
        """Pick the final class from the candidates' votes.

        The prediction is changed only when the number of inverted candidates
        reaches ``num_candidates / num_classes`` (``num_candidates / 2`` for
        models with at most two classes); in that case the most common class
        among the inverted candidates is returned. Otherwise the original
        class is kept.
        """
        if not inverted_texts:
            return original_class
        vote_threshold = num_candidates / max(num_classes, 2)
        if len(inverted_texts) >= vote_threshold:
            # Previously the two-class branch returned the class of whichever
            # candidate happened to be iterated last, which could be the
            # original class even when the majority had flipped.
            return self.find_dominant_class(inverted_texts)
        return original_class

    def correct(self, detected_texts):
        """Return one corrected class per input text.

        Args:
            detected_texts: iterable of raw text strings.

        Returns:
            List of class indices, aligned with ``detected_texts``.
        """
        corrected_classes = []
        for raw_text in detected_texts:
            detected_text = atk.AttackedText(raw_text)

            # Keep only the top `wir_threshold` fraction of ranked indices.
            index_order = self.wir_gradient(
                self.attack, self.victim_model, detected_text
            )
            index_order = index_order[: int(len(index_order) * self.wir_threshold)]

            most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)
            candidates = self.build_candidates(
                detected_text, most_frequent_syn_dict, max_attempt=100
            )

            original_probs = F.softmax(self.victim_model(detected_text.text), dim=1)
            original_class = torch.argmax(original_probs).item()

            # Classify every candidate in one batch; dict order is preserved,
            # so row i of `probabilities` belongs to the i-th candidate.
            batch_inputs = [candidate.text for candidate in candidates]
            batch_outputs = self.victim_model(batch_inputs)
            probabilities = F.softmax(batch_outputs, dim=1)

            inverted_texts = []
            for i, (candidate, syn_dict) in enumerate(candidates.items()):
                corrected_class = torch.argmax(probabilities[i]).item()
                if corrected_class == original_class:
                    continue
                new_golden_probability = float(probabilities[i][corrected_class])
                inverted_texts.append(
                    InvertedText(
                        syn_dict, new_golden_probability, candidate, corrected_class
                    )
                )

            corrected_classes.append(
                self._select_class(
                    original_class,
                    inverted_texts,
                    len(candidates),
                    len(original_probs[0]),
                )
            )
        return corrected_classes
def remove_brackets(text):
text = text.replace("[[", "")
text = text.replace("]]", "")
return text
def clean_text(text):
    """Replace every ASCII punctuation character in ``text`` with a space.

    The set of characters is ``string.punctuation``; each occurrence is
    replaced one-for-one, so the length of the text is unchanged.
    """
    punctuation_class = re.compile("[" + re.escape(string.punctuation) + "]")
    return punctuation_class.sub(" ", text)
# Victim model: the checkpoint is loaded twice below (weights and tokenizer),
# so its name is kept in one place.
MODEL_NAME = "textattack/bert-base-uncased-imdb"

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

# The four components of a TextAttack `Attack`:
# goal function, constraints, transformation and search method.
search_method = GreedyWordSwapWIR(wir_method="gradient")
transformation = WordSwapEmbedding(max_candidates=50)
constraints = [
    RepeatModification(),
    StopwordModification(),
    WordEmbeddingDistance(min_cos_sim=0.9),
]
goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)

# Assemble the attack and move it to the GPU.
attack = textattack.Attack(goal_function, constraints, transformation, search_method)
attack.cuda_()
# Attack log: one row per attacked sample.
results = pd.read_csv("IMDB_results.csv")

# Only rows whose attack succeeded are evaluated. Each text has its [[ ]]
# markers removed first and its punctuation replaced by spaces afterwards.
is_successful = results["result_type"] == "Successful"
perturbed_texts = [
    clean_text(remove_brackets(text))
    for text in results["perturbed_text"][is_successful]
]
original_texts = [
    clean_text(remove_brackets(text))
    for text in results["original_text"][is_successful]
]

victim_model = attack.goal_function.model
print("Getting corrected classes")
print("This may take a while ...")

# Reference labels are the victim model's predictions on the cleaned original
# texts (they could alternatively be read from the CSV file).
original_classes = []
for original_text in original_texts:
    class_probs = F.softmax(victim_model(original_text), dim=1)
    original_classes.append(torch.argmax(class_probs).item())

# Split the three aligned lists into consecutive chunks of `batch_size`.
batch_size = 1000
num_batches = (len(perturbed_texts) + batch_size - 1) // batch_size
batch_starts = range(0, len(perturbed_texts), batch_size)
batched_perturbed_texts = [
    perturbed_texts[start : start + batch_size] for start in batch_starts
]
batched_original_texts = [
    original_texts[start : start + batch_size] for start in batch_starts
]
batched_original_classes = [
    original_classes[start : start + batch_size] for start in batch_starts
]
print(batched_original_classes)
# Perturbed texts whose corrected class differs from / equals the reference
# class, accumulated over all batches and thresholds.
hard_samples_list = []
easy_samples_list = []

# Build the corrector once: its constructor reads both JSON word tables, which
# do not depend on the batch or on the threshold being evaluated.
corrector = Flow_Corrector(
    attack,
    word_rank_file="en_full_ranked.json",
    word_freq_file="en_full_freq.json",
)

# One CSV row is written per (batch, threshold) pair.
csv_filename = "flow_correction_results_imdb.csv"
with open(csv_filename, "w", newline="") as csvfile:
    fieldnames = ["freq_threshold", "batch_num", "match_perturbed", "match_original"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

    batches = zip(
        batched_perturbed_texts, batched_original_texts, batched_original_classes
    )
    for batch_num, (perturbed, original, classes) in enumerate(batches, start=1):
        print(f"Processing batch number: {batch_num}")
        for i in range(2):
            # Evaluated thresholds: 0.1 and 0.2.
            wir_threshold = 0.1 * (i + 1)
            print(f"Setting Word threshold to: {wir_threshold}")
            corrector.wir_threshold = wir_threshold

            # Correct perturbed texts
            print("Correcting perturbed texts...")
            corrected_perturbed_classes = corrector.correct(perturbed)
            match_perturbed, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_perturbed_classes, perturbed
            )
            hard_samples_list.extend(hard_samples)
            easy_samples_list.extend(easy_samples)
            print(f"Number of matching classes (perturbed): {match_perturbed}")

            # Correct original texts; only the match count is used here, so
            # the sample lists are discarded.
            print("Correcting original texts...")
            corrected_original_classes = corrector.correct(original)
            match_original, _, _ = count_matching_classes(
                classes, corrected_original_classes, original
            )
            print(f"Number of matching classes (original): {match_original}")

            # Write results to CSV file as fractions of the batch size.
            print("Writing results to CSV file...")
            writer.writerow(
                {
                    "freq_threshold": wir_threshold,
                    "batch_num": batch_num,
                    "match_perturbed": match_perturbed / len(perturbed),
                    "match_original": match_original / len(original),
                }
            )
            print("-" * 20)

print("savig samples for more statistics studies")
# Save hard_samples_list and easy_samples_list to files
with open('hard_samples.pkl', 'wb') as f:
    pickle.dump(hard_samples_list, f)
with open('easy_samples.pkl', 'wb') as f:
    pickle.dump(easy_samples_list, f)