import textattack
import torch
import torch.nn.functional as F
import transformers

from FlowCorrector import Flow_Corrector

# Hugging Face checkpoint used as the victim classifier.
MODEL_NAME = "textattack/bert-base-uncased-ag-news"

# Adversarial (perturbed) texts, one per line.
PERTURBED_TEXTS_FILE = "perturbed_texts_ag_news.txt"
# Original texts, one per line, in the same order as PERTURBED_TEXTS_FILE.
ORIGINAL_TEXTS_FILE = "original_texts_ag_news.txt"

# Class ids produced by the victim model.
AG_NEWS_LABELS = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}


def count_matching_classes(original, corrected):
    """Return how many positions hold the same class in both sequences.

    Args:
        original: Reference class labels (e.g. predictions on clean texts).
        corrected: Class labels to compare, aligned index-by-index with
            ``original``.

    Returns:
        The number of indices ``i`` for which ``original[i] == corrected[i]``.

    Raises:
        ValueError: If the two sequences have different lengths.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")
    # Count with an explicit truth test so element types whose ``==`` returns
    # a non-bool (e.g. scalar tensors) still contribute plain ints.
    return sum(1 for a, b in zip(original, corrected) if a == b)


def _read_lines(path):
    """Return the whitespace-stripped lines of the UTF-8 text file at ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def _build_attack(model_wrapper):
    """Build the greedy word-swap ``textattack.Attack`` against ``model_wrapper``."""
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.constraints.semantics import WordEmbeddingDistance
    from textattack.search_methods import GreedyWordSwapWIR
    from textattack.transformations import WordSwapEmbedding

    # The four components of an Attack: goal, constraints, transformation,
    # and search method.
    goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
    constraints = [
        RepeatModification(),
        StopwordModification(),
        WordEmbeddingDistance(min_cos_sim=0.9),
    ]
    transformation = WordSwapEmbedding(max_candidates=50)
    search_method = GreedyWordSwapWIR(wir_method="weighted-saliency")
    return textattack.Attack(goal_function, constraints, transformation, search_method)


def _predict_classes(model_wrapper, texts):
    """Return the model's predicted class id for each text in ``texts``.

    Texts are scored one at a time to keep memory use bounded.
    """
    predictions = []
    for text in texts:
        scores = model_wrapper([text])
        # argmax over the label axis; softmax is monotonic, so applying it
        # first would not change which class wins.
        predictions.append(torch.argmax(scores, dim=1).item())
    return predictions


def main():
    """Benchmark Flow_Corrector on adversarial AG News texts.

    Predicts the victim model's class for every original text, runs the
    corrector on the aligned adversarial texts, and prints how many of the
    corrected results match the original predictions.
    """
    # Load model, tokenizer, and model_wrapper.
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
        model, tokenizer
    )

    attack = _build_attack(model_wrapper)
    attack.cuda_()

    # Initialise the corrector.
    corrector = Flow_Corrector(
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
    )

    detected_texts = _read_lines(PERTURBED_TEXTS_FILE)
    original_texts = _read_lines(ORIGINAL_TEXTS_FILE)
    # Fail fast: the two files must be aligned line-by-line before any
    # expensive model work is done.
    if len(detected_texts) != len(original_texts):
        raise ValueError(
            f"{PERTURBED_TEXTS_FILE} has {len(detected_texts)} lines but "
            f"{ORIGINAL_TEXTS_FILE} has {len(original_texts)}"
        )

    # Predictions on the clean texts serve as the benchmark reference.
    victim_model = attack.goal_function.model
    original_classes = _predict_classes(victim_model, original_texts)

    # Correct the adversarial texts, not the originals: the benchmark asks
    # whether correction restores the prediction made on the clean text.
    # NOTE(review): assumes Flow_Corrector.correct returns one class id per
    # input text, in input order — TODO confirm against FlowCorrector.
    corrected_classes = corrector.correct(detected_texts)

    matches = count_matching_classes(original_classes, corrected_classes)
    print(f"match {matches}/{len(original_classes)}")


if __name__ == "__main__":
    main()