# DCWIR-Demo / main_correction.py
# Author: PFEemp2024
# Commit b16fdae (verified): add the main file for the correction process.
import textattack
import transformers
from FlowCorrector import Flow_Corrector
import torch
import torch.nn.functional as F
def count_matching_classes(original, corrected):
    """Count the positions at which two equal-length label sequences agree.

    Args:
        original: Sequence of reference class labels.
        corrected: Sequence of class labels to compare against ``original``,
            in the same order.

    Returns:
        The number of indices ``i`` for which ``original[i] == corrected[i]``.

    Raises:
        ValueError: If the two sequences have different lengths.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")
    # The length check above guarantees zip() does not silently truncate
    # the longer sequence.
    return sum(
        1
        for original_label, corrected_label in zip(original, corrected)
        if original_label == corrected_label
    )
if __name__ == "__main__":
    # Load the victim model and tokenizer, and wrap them for TextAttack.
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "textattack/bert-base-uncased-ag-news"
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "textattack/bert-base-uncased-ag-news"
    )
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    # Construct the four components of an `Attack`: goal function,
    # constraints, transformation and search method.
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.constraints.semantics import WordEmbeddingDistance
    from textattack.transformations import WordSwapEmbedding
    from textattack.search_methods import GreedyWordSwapWIR

    goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
    constraints = [
        RepeatModification(),
        StopwordModification(),
        WordEmbeddingDistance(min_cos_sim=0.9),
    ]
    transformation = WordSwapEmbedding(max_candidates=50)
    search_method = GreedyWordSwapWIR(wir_method="weighted-saliency")

    # Construct the actual attack.
    attack = textattack.Attack(goal_function, constraints, transformation, search_method)
    attack.cuda_()

    # Initialise the corrector around the attack.
    corrector = Flow_Corrector(
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
    )

    # Adversarial (perturbed) texts, one per line.
    with open("perturbed_texts_ag_news.txt", "r", encoding="utf-8") as f:
        detected_texts = [line.strip() for line in f]
    # Original texts, one per line, in the same order as the adversarial ones.
    with open("original_texts_ag_news.txt", "r", encoding="utf-8") as f:
        original_texts = [line.strip() for line in f]

    # The two files are compared line-by-line, so fail fast on a mismatch
    # rather than after the (expensive) correction step.
    if len(detected_texts) != len(original_texts):
        raise ValueError(
            f"Expected as many perturbed texts as original texts, got "
            f"{len(detected_texts)} and {len(original_texts)}"
        )

    victim_model = attack.goal_function.model
    # Labels the victim model assigns to the original texts; these are the
    # reference for benchmarking the corrected texts.
    # AG News label ids: 0 = World, 1 = Sports, 2 = Business, 3 = Sci/Tech.
    # The model wrapper takes a list of strings, so each text is wrapped in a
    # one-element batch.
    original_classes = [
        torch.argmax(F.softmax(victim_model([original_text]), dim=1)).item()
        for original_text in original_texts
    ]

    # Correct the adversarial texts and compare the result with the labels
    # of the original texts.
    # NOTE(review): assumes Flow_Corrector.correct returns one class label per
    # input text, in input order — confirm against FlowCorrector.
    corrected_classes = corrector.correct(detected_texts)
    print(f"match {count_matching_classes(original_classes, corrected_classes)}")