PFEemp2024 committed on
Commit
c2c01a0
1 Parent(s): e8a3717

Upload 2 files

Files changed (2)
  1. .gitignore +143 -0
  2. utils.py +234 -0
.gitignore ADDED
@@ -0,0 +1,143 @@
+ # dev files
+ *.cache
+ *.dev.py
+ state_dict/
+ TAD*/
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.pyc
+ tests/
+ *.result.json
+ .idea/
+
+ # Embedding
+ glove.840B.300d.txt
+ glove.42B.300d.txt
+ glove.twitter.27B.txt
+
+ # project main files
+ release_note.json
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+ .DS_Store
+ examples/.DS_Store
utils.py ADDED
@@ -0,0 +1,234 @@
+ import random
+ from difflib import Differ
+
+ from findfile import find_files
+ from textattack import Attacker
+ from textattack.attack_recipes import BAEGarg2019
+ from textattack.datasets import Dataset
+ from textattack.models.wrappers import HuggingFaceModelWrapper
+
+
+ class ModelWrapper(HuggingFaceModelWrapper):
+     """Adapt a model exposing infer() to the TextAttack model-wrapper interface."""
+
+     def __init__(self, model):
+         self.model = model
+
+     def __call__(self, text_inputs, **kwargs):
+         # Return the class probabilities for each input text.
+         outputs = []
+         for text_input in text_inputs:
+             raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
+             outputs.append(raw_outputs["probs"])
+         return outputs
+
+
+ class SentAttacker:
+     def __init__(self, model, recipe_class=BAEGarg2019):
+         model_wrapper = ModelWrapper(model)
+         recipe = recipe_class.build(model_wrapper)
+         # The transformation language can be overridden here if needed, e.g.:
+         # recipe.transformation.language = "en"
+
+         # Attacker requires a dataset; a single placeholder example suffices
+         # because the texts to attack are supplied at attack time.
+         _dataset = Dataset([("", 0)])
+         self.attacker = Attacker(recipe, _dataset)
+
+
+ def diff_texts(text1, text2):
+     # Character-level diff: (char, tag) pairs where tag is "+", "-",
+     # or None for unchanged characters.
+     d = Differ()
+     return [
+         (token[2:], token[0] if token[0] != " " else None)
+         for token in d.compare(text1, text2)
+     ]
+
+
+ def get_ensembled_tad_results(results):
+     # Majority vote over the predicted labels of an ensemble.
+     target_dict = {}
+     for r in results:
+         target_dict[r["label"]] = target_dict.get(r["label"], 0) + 1
+     return max(target_dict, key=target_dict.get)
+
+
+ # File-name fragments that mark non-dataset files to exclude from the search.
+ FILTER_KEY_WORDS = [
+     ".py",
+     ".md",
+     "readme",
+     "log",
+     "result",
+     "zip",
+     ".state_dict",
+     ".model",
+     ".png",
+     "acc_",
+     "f1_",
+     ".origin",
+     ".adv",
+     ".csv",
+ ]
+
+
+ def _get_example(dataset):
+     # Locate the local "<dataset> test" files for the text_defense task
+     # and return one random (text, label) pair.
+     search_path = "./"
+     task = "text_defense"
+     dataset_files = find_files(
+         search_path,
+         [dataset, "test", task],
+         exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
+         + FILTER_KEY_WORDS,
+     )
+
+     data = []
+     for data_file in dataset_files:
+         with open(data_file, mode="r", encoding="utf8") as fin:
+             for line in fin:
+                 # Each line is formatted as "<text>$LABEL$<label>".
+                 text, label = line.split("$LABEL$")
+                 data.append((text.strip(), int(label.strip())))
+     return random.choice(data)
+
+
+ def get_sst2_example():
+     return _get_example("sst2")
+
+
+ def get_agnews_example():
+     return _get_example("agnews")
+
+
+ def get_amazon_example():
+     return _get_example("amazon")
+
+
+ def get_imdb_example():
+     return _get_example("imdb")
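
For reference, a minimal usage sketch of these helpers (not part of the commit). It assumes a `model` object implementing infer(text, print_result=False) that returns a dict with "probs" and "label" entries, which is the interface ModelWrapper relies on, and it assumes the standard textattack Attacker/AttackResult API (attacker.attack.attack(), result.original_text(), result.perturbed_text()):

# Hypothetical usage sketch; `model` is an assumed classifier with
# infer(text, print_result=False) -> {"probs": [...], "label": ...}.
from utils import SentAttacker, diff_texts, get_ensembled_tad_results, get_sst2_example

text, label = get_sst2_example()  # needs local "sst2 test text_defense" files
sent_attacker = SentAttacker(model)

# Attacker wraps a textattack Attack; attack a single example directly.
result = sent_attacker.attacker.attack.attack(text, label)

# (char, tag) pairs highlighting what the attack changed.
print(diff_texts(result.original_text(), result.perturbed_text()))

# Majority vote across several per-model predictions.
print(get_ensembled_tad_results([{"label": 1}, {"label": 1}, {"label": 0}]))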