import textattack
import transformers
import pandas as pd
import csv
import string
import pickle

# Construct our four components for `Attack`
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.transformations import WordSwapEmbedding
from textattack.search_methods import GreedyWordSwapWIR
import numpy as np
import json
import random
import re
import textattack.shared.attacked_text as atk
import torch.nn.functional as F
import torch


class InvertedText:
    def __init__(
        self,
        swapped_indexes,
        score,
        attacked_text,
        new_class,
    ):
        self.attacked_text = attacked_text
        self.swapped_indexes = (
            swapped_indexes  # dict of swapped indexes with their synonym
        )
        self.score = score  # probability of the original class
        self.new_class = new_class  # class after inversion

    def __repr__(self):
        return (
            f"InvertedText:\n attacked_text='{self.attacked_text}',\n "
            f"swapped_indexes={self.swapped_indexes},\n score={self.score}"
        )


def count_matching_classes(original, corrected, perturbed_texts=None):
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")
    hard_samples = []
    easy_samples = []
    matching_count = 0

    for i in range(len(corrected)):
        if original[i] == corrected[i]:
            matching_count += 1
            easy_samples.append(perturbed_texts[i])
        elif perturbed_texts is not None:
            hard_samples.append(perturbed_texts[i])

    return matching_count, hard_samples, easy_samples


class Flow_Corrector:
    def __init__(
        self,
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
        wir_threshold=0.3,
    ):
        self.attack = attack
        self.attack.cuda_()
        self.wir_threshold = wir_threshold
        with open(word_rank_file, "r") as f:
            self.word_ranked_frequence = json.load(f)
        with open(word_freq_file, "r") as f:
            self.word_frequence = json.load(f)
        self.victim_model = attack.goal_function.model

    def wir_gradient(
        self,
        attack,
        victim_model,
        detected_text,
    ):
        # rank word indices by the L1 norm of the gradient of the loss
        # with respect to their tokens (word importance ranking)
        _, indices_to_order = attack.get_indices_to_order(detected_text)

        index_scores = np.zeros(len(indices_to_order))
        grad_output = victim_model.get_grad(detected_text.tokenizer_input)
        gradient = grad_output["gradient"]
        word2token_mapping = detected_text.align_with_model_tokens(victim_model)
        for i, index in enumerate(indices_to_order):
            matched_tokens = word2token_mapping[index]
            if not matched_tokens:
                index_scores[i] = 0.0
            else:
                agg_grad = np.mean(gradient[matched_tokens], axis=0)
                index_scores[i] = np.linalg.norm(agg_grad, ord=1)
        index_order = np.array(indices_to_order)[(-index_scores).argsort()]
        return index_order

    def get_syn_freq_dict(
        self,
        index_order,
        detected_text,
    ):
        most_frequent_syn_dict = {}

        no_syn = []
        freq_threshold = len(self.word_ranked_frequence) / 10

        for idx in index_order:
            # get the synonyms of the word at a specific index
            try:
                synonyms = [
                    attacked_text.words[idx]
                    for attacked_text in self.attack.get_transformations(
                        detected_text, detected_text, indices_to_modify=[idx]
                    )
                ]
                # keep synonyms that exist in the dataset, with their frequency rank
                ranked_synonyms = {
                    syn: self.word_ranked_frequence[syn]
                    for syn in synonyms
                    if syn in self.word_ranked_frequence.keys()
                    and self.word_ranked_frequence[syn] < freq_threshold
                    and self.word_ranked_frequence[detected_text.words[idx]]
                    > self.word_ranked_frequence[syn]
                }
                # selecting the most frequent synonyms
                if list(ranked_synonyms.keys()) != []:
                    most_frequent_syn_dict[idx] = list(ranked_synonyms.keys())
            except KeyError:
                # no synonyms available in the dataset for this word
                no_syn.append(idx)

        return most_frequent_syn_dict
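    # build_candidates (below) randomly samples one synonym per important index
    # on each attempt and applies the swaps to produce up to `max_attempt`
    # candidate texts, each stored together with its index -> synonym mapping.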
    def build_candidates(
        self, detected_text, most_frequent_syn_dict: dict, max_attempt: int
    ):
        candidates = {}
        for _ in range(max_attempt):
            syn_dict = {}
            current_text = detected_text
            for index in most_frequent_syn_dict.keys():
                syn = random.choice(most_frequent_syn_dict[index])
                syn_dict[index] = syn
                current_text = current_text.replace_word_at_index(index, syn)
            candidates[current_text] = syn_dict
        return candidates

    def find_dominant_class(self, inverted_texts):
        class_counts = {}  # dictionary to store the count of each new class

        for text in inverted_texts:
            new_class = text.new_class
            class_counts[new_class] = class_counts.get(new_class, 0) + 1

        # find the most dominant class
        most_dominant_class = max(class_counts, key=class_counts.get)

        return most_dominant_class

    def correct(self, detected_texts):
        corrected_classes = []
        for detected_text in detected_texts:
            # convert to AttackedText
            detected_text = atk.AttackedText(detected_text)

            # getting the most important indexes (top wir_threshold fraction)
            index_order = self.wir_gradient(
                self.attack, self.victim_model, detected_text
            )
            index_order = index_order[: int(len(index_order) * self.wir_threshold)]

            # getting synonyms according to the frequency conditions
            most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)

            # generate M candidates
            candidates = self.build_candidates(
                detected_text, most_frequent_syn_dict, max_attempt=100
            )

            original_probs = F.softmax(self.victim_model(detected_text.text), dim=1)
            original_class = torch.argmax(original_probs).item()
            original_golden_prob = float(original_probs[0][original_class])

            nbr_inverted = 0
            inverted_texts = []  # list of InvertedText objects

            bad, impr = 0, 0
            dict_deltas = {}

            batch_inputs = [candidate.text for candidate in candidates.keys()]
            batch_outputs = self.victim_model(batch_inputs)
            probabilities = F.softmax(batch_outputs, dim=1)

            for i, (candidate, syn_dict) in enumerate(candidates.items()):
                corrected_class = torch.argmax(probabilities[i]).item()
                new_golden_probability = float(probabilities[i][corrected_class])
                if corrected_class != original_class:
                    nbr_inverted += 1
                    inverted_texts.append(
                        InvertedText(
                            syn_dict, new_golden_probability, candidate, corrected_class
                        )
                    )
                else:
                    delta = new_golden_probability - original_golden_prob
                    if delta <= 0:
                        bad += 1
                    else:
                        impr += 1
                    dict_deltas[candidate] = delta

            if len(original_probs[0]) > 2 and len(inverted_texts) >= len(candidates) / (
                len(original_probs[0])
            ):
                # selecting the most dominant class
                dominant_class = self.find_dominant_class(inverted_texts)
            elif len(inverted_texts) >= len(candidates) / 2:
                dominant_class = corrected_class
            else:
                dominant_class = original_class

            corrected_classes.append(dominant_class)

        return corrected_classes


def remove_brackets(text):
    text = text.replace("[[", "")
    text = text.replace("]]", "")
    return text


def clean_text(text):
    pattern = "[" + re.escape(string.punctuation) + "]"
    cleaned_text = re.sub(pattern, " ", text)
    return cleaned_text


# Load model, tokenizer, and model_wrapper
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
constraints = [
    RepeatModification(),
    StopwordModification(),
    WordEmbeddingDistance(min_cos_sim=0.9),
]
transformation = WordSwapEmbedding(max_candidates=50)
search_method = GreedyWordSwapWIR(wir_method="gradient")
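# Note: the Attack object assembled below is not run to generate new adversarial
# examples here; Flow_Corrector reuses its transformation, constraints, and
# goal-function model to propose synonym replacements when correcting texts.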
# Construct the actual attack
attack = textattack.Attack(goal_function, constraints, transformation, search_method)
attack.cuda_()

results = pd.read_csv("IMDB_results.csv")

perturbed_texts = [
    results["perturbed_text"][i]
    for i in range(len(results))
    if results["result_type"][i] == "Successful"
]
original_texts = [
    results["original_text"][i]
    for i in range(len(results))
    if results["result_type"][i] == "Successful"
]

perturbed_texts = [remove_brackets(text) for text in perturbed_texts]
original_texts = [remove_brackets(text) for text in original_texts]

perturbed_texts = [clean_text(text) for text in perturbed_texts]
original_texts = [clean_text(text) for text in original_texts]

victim_model = attack.goal_function.model

print("Getting corrected classes")
print("This may take a while ...")
# we could also use the results already stored in the csv file directly
original_classes = [
    torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()
    for original_text in original_texts
]

batch_size = 1000
num_batches = (len(perturbed_texts) + batch_size - 1) // batch_size

batched_perturbed_texts = []
batched_original_texts = []
batched_original_classes = []

for i in range(num_batches):
    start = i * batch_size
    end = min(start + batch_size, len(perturbed_texts))
    batched_perturbed_texts.append(perturbed_texts[start:end])
    batched_original_texts.append(original_texts[start:end])
    batched_original_classes.append(original_classes[start:end])

print(batched_original_classes)

hard_samples_list = []
easy_samples_list = []

# Open a CSV file for writing
csv_filename = "flow_correction_results_imdb.csv"
with open(csv_filename, "w", newline="") as csvfile:
    fieldnames = ["freq_threshold", "batch_num", "match_perturbed", "match_original"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    # Write the header row
    writer.writeheader()

    # Iterate over batched lists
    batch_num = 0
    for perturbed, original, classes in zip(
        batched_perturbed_texts, batched_original_texts, batched_original_classes
    ):
        batch_num += 1
        print(f"Processing batch number: {batch_num}")

        for i in range(2):
            wir_threshold = 0.1 * (i + 1)
            print(f"Setting word importance ranking threshold to: {wir_threshold}")
            corrector = Flow_Corrector(
                attack,
                word_rank_file="en_full_ranked.json",
                word_freq_file="en_full_freq.json",
                wir_threshold=wir_threshold,
            )

            # Correct perturbed texts
            print("Correcting perturbed texts...")
            corrected_perturbed_classes = corrector.correct(perturbed)
            match_perturbed, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_perturbed_classes, perturbed
            )
            hard_samples_list.extend(hard_samples)
            easy_samples_list.extend(easy_samples)
            print(f"Number of matching classes (perturbed): {match_perturbed}")

            # Correct original texts
            print("Correcting original texts...")
            corrected_original_classes = corrector.correct(original)
            match_original, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_original_classes, original
            )
            print(f"Number of matching classes (original): {match_original}")

            # Write results to CSV file
            print("Writing results to CSV file...")
            writer.writerow(
                {
                    "freq_threshold": wir_threshold,
                    "batch_num": batch_num,
                    "match_perturbed": match_perturbed / len(perturbed),
                    "match_original": match_original / len(original),
                }
            )
            print("-" * 20)

print("Saving samples for further statistical analysis")

# Save hard_samples_list and easy_samples_list to files
with open("hard_samples.pkl", "wb") as f:
    pickle.dump(hard_samples_list, f)
with open("easy_samples.pkl", "wb") as f:
    pickle.dump(easy_samples_list, f)
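# Example (not executed): the saved pickle files can be reloaded later for
# further analysis of which perturbed texts resisted correction.
# with open("hard_samples.pkl", "rb") as f:
#     hard_samples_list = pickle.load(f)
# with open("easy_samples.pkl", "rb") as f:
#     easy_samples_list = pickle.load(f)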