PFEemp2024 commited on
Commit
b16fdae
1 Parent(s): 18b7cae

adding the main file for the correction process

Browse files
Files changed (1) hide show
  1. main_correction.py +89 -0
main_correction.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import textattack
2
+ import transformers
3
+ from FlowCorrector import Flow_Corrector
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ def count_matching_classes(original, corrected):
8
+ if len(original) != len(corrected):
9
+ raise ValueError("Arrays must have the same length")
10
+
11
+ matching_count = 0
12
+
13
+ for i in range(len(corrected)):
14
+ if original[i] == corrected[i]:
15
+ matching_count += 1
16
+
17
+ return matching_count
18
+
19
+ if __name__ == "main" :
20
+
21
+ # Load model, tokenizer, and model_wrapper
22
+ model = transformers.AutoModelForSequenceClassification.from_pretrained(
23
+ "textattack/bert-base-uncased-ag-news"
24
+ )
25
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
26
+ "textattack/bert-base-uncased-ag-news"
27
+ )
28
+ model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
29
+
30
+ # Construct our four components for `Attack`
31
+ from textattack.constraints.pre_transformation import (
32
+ RepeatModification,
33
+ StopwordModification,
34
+ )
35
+ from textattack.constraints.semantics import WordEmbeddingDistance
36
+ from textattack.transformations import WordSwapEmbedding
37
+ from textattack.search_methods import GreedyWordSwapWIR
38
+
39
+ goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
40
+ constraints = [
41
+ RepeatModification(),
42
+ StopwordModification(),
43
+ WordEmbeddingDistance(min_cos_sim=0.9),
44
+ ]
45
+ transformation = WordSwapEmbedding(max_candidates=50)
46
+ search_method = GreedyWordSwapWIR(wir_method="weighted-saliency")
47
+
48
+ # Construct the actual attack
49
+ attack = textattack.Attack(goal_function, constraints, transformation, search_method)
50
+ attack.cuda_()
51
+
52
+ # intialisation de coreecteur
53
+ corrector = Flow_Corrector(
54
+ attack,
55
+ word_rank_file="en_full_ranked.json",
56
+ word_freq_file="en_full_freq.json",
57
+ )
58
+
59
+ # All these texts are adverserial ones
60
+
61
+ with open('perturbed_texts_ag_news.txt', 'r') as f:
62
+ detected_texts = [line.strip() for line in f]
63
+
64
+
65
+ #These are orginal texts in same order of adverserial ones
66
+
67
+ with open("original_texts_ag_news.txt", "r") as f:
68
+ original_texts = [line.strip() for line in f]
69
+
70
+ victim_model = attack.goal_function.model
71
+
72
+ # getting original labels for benchmarking later
73
+ original_classes = [
74
+ torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()
75
+ for original_text in original_texts
76
+ ]
77
+
78
+ """ 0 :World
79
+ 1 : Sports
80
+ 2 : Business
81
+ 3 : Sci/Tech"""
82
+
83
+ corrected_classes = corrector.correct(original_texts)
84
+ print(f"match {count_matching_classes()}")
85
+
86
+
87
+
88
+
89
+