anonymous8/RPD-Demo committed on
Commit 4943752 (0 parents)

initial commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +31 -0
  2. .gitignore +143 -0
  3. README.md +13 -0
  4. anonymous_demo/__init__.py +5 -0
  5. anonymous_demo/core/__init__.py +0 -0
  6. anonymous_demo/core/tad/__init__.py +0 -0
  7. anonymous_demo/core/tad/classic/__bert__/README.MD +3 -0
  8. anonymous_demo/core/tad/classic/__bert__/__init__.py +1 -0
  9. anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py +0 -0
  10. anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py +116 -0
  11. anonymous_demo/core/tad/classic/__bert__/models/__init__.py +1 -0
  12. anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py +43 -0
  13. anonymous_demo/core/tad/classic/__init__.py +0 -0
  14. anonymous_demo/core/tad/models/__init__.py +9 -0
  15. anonymous_demo/core/tad/prediction/__init__.py +0 -0
  16. anonymous_demo/core/tad/prediction/tad_classifier.py +390 -0
  17. anonymous_demo/functional/__init__.py +3 -0
  18. anonymous_demo/functional/checkpoint/__init__.py +1 -0
  19. anonymous_demo/functional/checkpoint/checkpoint_manager.py +20 -0
  20. anonymous_demo/functional/config/__init__.py +1 -0
  21. anonymous_demo/functional/config/config_manager.py +66 -0
  22. anonymous_demo/functional/config/tad_config_manager.py +221 -0
  23. anonymous_demo/functional/dataset/__init__.py +1 -0
  24. anonymous_demo/functional/dataset/dataset_manager.py +21 -0
  25. anonymous_demo/network/__init__.py +0 -0
  26. anonymous_demo/network/lcf_pooler.py +26 -0
  27. anonymous_demo/network/lsa.py +52 -0
  28. anonymous_demo/network/sa_encoder.py +159 -0
  29. anonymous_demo/utils/__init__.py +0 -0
  30. anonymous_demo/utils/demo_utils.py +209 -0
  31. anonymous_demo/utils/logger.py +38 -0
  32. app.py +271 -0
  33. checkpoints.zip +3 -0
  34. requirements.txt +19 -0
  35. text_defense/201.SST2/stsa.binary.dev.dat +0 -0
  36. text_defense/201.SST2/stsa.binary.test.dat +0 -0
  37. text_defense/201.SST2/stsa.binary.train.dat +0 -0
  38. text_defense/204.AGNews10K/AGNews10K.test.dat +0 -0
  39. text_defense/204.AGNews10K/AGNews10K.train.dat +0 -0
  40. text_defense/204.AGNews10K/AGNews10K.valid.dat +0 -0
  41. text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat +0 -0
  42. text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat +0 -0
  43. textattack/__init__.py +39 -0
  44. textattack/__main__.py +6 -0
  45. textattack/attack.py +492 -0
  46. textattack/attack_args.py +763 -0
  47. textattack/attack_recipes/__init__.py +43 -0
  48. textattack/attack_recipes/a2t_yoo_2021.py +74 -0
  49. textattack/attack_recipes/attack_recipe.py +30 -0
  50. textattack/attack_recipes/bae_garg_2019.py +123 -0
.gitattributes ADDED
@@ -0,0 +1,31 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,143 @@
1
+ # dev files
2
+ *.cache
3
+ *.dev.py
4
+ state_dict/
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+ *.pyc
11
+ tests/
12
+ *.result.json
13
+ .idea/
14
+
15
+ # Embedding
16
+ glove.840B.300d.txt
17
+ glove.42B.300d.txt
18
+ glove.twitter.27B.txt
19
+
20
+ # project main files
21
+ release_note.json
22
+
23
+ # C extensions
24
+ *.so
25
+
26
+ # Distribution / packaging
27
+ .Python
28
+ build/
29
+ develop-eggs/
30
+ dist/
31
+ downloads/
32
+ eggs/
33
+ .eggs/
34
+ lib64/
35
+ parts/
36
+ sdist/
37
+ var/
38
+ wheels/
39
+ pip-wheel-metadata/
40
+ share/python-wheels/
41
+ *.egg-info/
42
+ .installed.cfg
43
+ *.egg
44
+ MANIFEST
45
+
46
+ # PyInstaller
47
+ # Usually these files are written by a python script from a template
48
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
+ *.manifest
50
+ *.spec
51
+
52
+ # Installer training_logs
53
+ pip-log.txt
54
+ pip-delete-this-directory.txt
55
+
56
+ # Unit test / coverage reports
57
+ htmlcov/
58
+ .tox/
59
+ .nox/
60
+ .coverage
61
+ .coverage.*
62
+ .cache
63
+ nosetests.xml
64
+ coverage.xml
65
+ *.cover
66
+ *.py,cover
67
+ .hypothesis/
68
+ .pytest_cache/
69
+
70
+ # Translations
71
+ *.mo
72
+ *.pot
73
+
74
+ # Django stuff:
75
+ *.log
76
+ local_settings.py
77
+ db.sqlite3
78
+ db.sqlite3-journal
79
+
80
+ # Flask stuff:
81
+ instance/
82
+ .webassets-cache
83
+
84
+ # Scrapy stuff:
85
+ .scrapy
86
+
87
+ # Sphinx documentation
88
+ docs/_build/
89
+
90
+ # PyBuilder
91
+ target/
92
+
93
+ # Jupyter Notebook
94
+ .ipynb_checkpoints
95
+
96
+ # IPython
97
+ profile_default/
98
+ ipython_config.py
99
+
100
+ # pyenv
101
+ .python-version
102
+
103
+ # pipenv
104
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
105
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
106
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
107
+ # install all needed dependencies.
108
+ #Pipfile.lock
109
+
110
+ # celery beat schedule file
111
+ celerybeat-schedule
112
+
113
+ # SageMath parsed files
114
+ *.sage.py
115
+
116
+ # Environments
117
+ .env
118
+ .venv
119
+ env/
120
+ venv/
121
+ ENV/
122
+ env.bak/
123
+ venv.bak/
124
+
125
+ # Spyder project settings
126
+ .spyderproject
127
+ .spyproject
128
+
129
+ # Rope project settings
130
+ .ropeproject
131
+
132
+ # mkdocs documentation
133
+ /site
134
+
135
+ # mypy
136
+ .mypy_cache/
137
+ .dmypy.json
138
+ dmypy.json
139
+
140
+ # Pyre type checker
141
+ .pyre/
142
+ .DS_Store
143
+ examples/.DS_Store
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: RPD-Demo
+ emoji: 🛡️
+ colorFrom: gray
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.0.19
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
anonymous_demo/__init__.py ADDED
@@ -0,0 +1,5 @@
+ __version__ = '1.0.0'
+
+ __name__ = 'anonymous_demo'
+
+ from anonymous_demo.functional import TADCheckpointManager
anonymous_demo/core/__init__.py ADDED
File without changes
anonymous_demo/core/tad/__init__.py ADDED
File without changes
anonymous_demo/core/tad/classic/__bert__/README.MD ADDED
@@ -0,0 +1,3 @@
+ ## This is a simple migration from ABSA-PyTorch, under the MIT license
+
+ Project Address: https://github.com/songyouwei/ABSA-PyTorch
anonymous_demo/core/tad/classic/__bert__/__init__.py ADDED
@@ -0,0 +1 @@
+ from .models import *
anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py ADDED
File without changes
anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py ADDED
@@ -0,0 +1,116 @@
+ import tqdm
+ from findfile import find_cwd_dir
+ from torch.utils.data import Dataset
+ from transformers import AutoTokenizer
+
+
+ class Tokenizer4Pretraining:
+     def __init__(self, max_seq_len, opt, **kwargs):
+         if kwargs.pop('offline', False):
+             self.tokenizer = AutoTokenizer.from_pretrained(find_cwd_dir(opt.pretrained_bert.split('/')[-1]),
+                                                            do_lower_case='uncased' in opt.pretrained_bert)
+         else:
+             self.tokenizer = AutoTokenizer.from_pretrained(opt.pretrained_bert,
+                                                            do_lower_case='uncased' in opt.pretrained_bert)
+         self.max_seq_len = max_seq_len
+
+     def text_to_sequence(self, text, reverse=False, padding='post', truncating='post'):
+
+         return self.tokenizer.encode(text, truncation=True, padding='max_length', max_length=self.max_seq_len,
+                                      return_tensors='pt')
+
+
+ class BERTTADDataset(Dataset):
+
+     def __init__(self, tokenizer, opt):
+         self.bert_baseline_input_colses = {
+             'bert': ['text_bert_indices']
+         }
+
+         self.tokenizer = tokenizer
+         self.opt = opt
+         self.all_data = []
+
+     def parse_sample(self, text):
+         return [text]
+
+     def prepare_infer_sample(self, text: str, ignore_error):
+         self.process_data(self.parse_sample(text), ignore_error=ignore_error)
+
+     def process_data(self, samples, ignore_error=True):
+         all_data = []
+         if len(samples) > 100:
+             it = tqdm.tqdm(samples, postfix='preparing text classification inference dataloader...')
+         else:
+             it = samples
+         for text in it:
+             try:
+                 # handle empty lines in inference datasets
+                 if text is None or '' == text.strip():
+                     raise RuntimeError('Invalid Input!')
+
+                 if '!ref!' in text:
+                     text, _, labels = text.strip().partition('!ref!')
+                     text = text.strip()
+                     if labels.count(',') == 2:
+                         label, is_adv, adv_train_label = labels.strip().split(',')
+                         label, is_adv, adv_train_label = label.strip(), is_adv.strip(), adv_train_label.strip()
+                     elif labels.count(',') == 1:
+                         label, is_adv = labels.strip().split(',')
+                         label, is_adv = label.strip(), is_adv.strip()
+                         adv_train_label = '-100'
+                     elif labels.count(',') == 0:
+                         label = labels.strip()
+                         adv_train_label = '-100'
+                         is_adv = '-100'
+                     else:
+                         label = '-100'
+                         adv_train_label = '-100'
+                         is_adv = '-100'
+
+                     label = int(label)
+                     adv_train_label = int(adv_train_label)
+                     is_adv = int(is_adv)
+
+                 else:
+                     text = text.strip()
+                     label = -100
+                     adv_train_label = -100
+                     is_adv = -100
+
+                 text_indices = self.tokenizer.text_to_sequence('{}'.format(text))
+
+                 data = {
+                     'text_bert_indices': text_indices[0],
+
+                     'text_raw': text,
+
+                     'label': label,
+
+                     'adv_train_label': adv_train_label,
+
+                     'is_adv': is_adv,
+
+                     # 'label': self.opt.label_to_index.get(label, -100) if isinstance(label, str) else label,
+                     #
+                     # 'adv_train_label': self.opt.adv_train_label_to_index.get(adv_train_label, -100) if isinstance(adv_train_label, str) else adv_train_label,
+                     #
+                     # 'is_adv': self.opt.is_adv_to_index.get(is_adv, -100) if isinstance(is_adv, str) else is_adv,
+                 }
+
+                 all_data.append(data)
+
+             except Exception as e:
+                 if ignore_error:
+                     print('Ignore error while processing:', text)
+                 else:
+                     raise e
+
+         self.all_data = all_data
+         return self.all_data
+
+     def __getitem__(self, index):
+         return self.all_data[index]
+
+     def __len__(self):
+         return len(self.all_data)
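A brief illustration (not part of the commit) of the `!ref!` annotation format that `process_data()` above parses; the sentences and labels below are made-up examples, and the commented-out usage assumes a loaded TAD config object:

```python
# Each inference line is "<text>!ref!<labels>", where <labels> is one of:
#   "<label>"                                   e.g. "... !ref!1"
#   "<label>,<is_adv>"                          e.g. "... !ref!1,0"
#   "<label>,<is_adv>,<adv_train_label>"        e.g. "... !ref!0,1,0"
# Missing fields default to -100 and are ignored during evaluation.
samples = [
    'a gripping and well-acted thriller !ref!1,0',
    'the plot is thin and the pacing drags !ref!0,1,0',
]

# Hypothetical usage, assuming `opt` carries `pretrained_bert` and `max_seq_len`:
# tokenizer = Tokenizer4Pretraining(max_seq_len=opt.max_seq_len, opt=opt)
# dataset = BERTTADDataset(tokenizer=tokenizer, opt=opt)
# dataset.process_data(samples)  # -> dicts with 'text_bert_indices', 'label', 'is_adv', ...
```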
anonymous_demo/core/tad/classic/__bert__/models/__init__.py ADDED
@@ -0,0 +1 @@
+ from .tad_bert import TADBERT
anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ import torch.nn as nn
+ from transformers.models.bert.modeling_bert import BertPooler
+
+ from anonymous_demo.network.sa_encoder import Encoder
+
+
+ class TADBERT(nn.Module):
+     inputs = ['text_bert_indices']
+
+     def __init__(self, bert, opt):
+         super(TADBERT, self).__init__()
+         self.opt = opt
+         self.bert = bert
+         self.pooler = BertPooler(bert.config)
+         self.dense1 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)
+         self.dense2 = nn.Linear(self.opt.hidden_dim, self.opt.adv_det_dim)
+         self.dense3 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)
+
+         self.encoder1 = Encoder(self.bert.config, opt=opt)
+         self.encoder2 = Encoder(self.bert.config, opt=opt)
+         self.encoder3 = Encoder(self.bert.config, opt=opt)
+
+     def forward(self, inputs):
+         text_raw_indices = inputs[0]
+         last_hidden_state = self.bert(text_raw_indices)['last_hidden_state']
+
+         sent_logits = self.dense1(self.pooler(last_hidden_state))
+         advdet_logits = self.dense2(self.pooler(last_hidden_state))
+         adv_tr_logits = self.dense3(self.pooler(last_hidden_state))
+
+         att_score = torch.nn.functional.normalize(
+             last_hidden_state.abs().sum(dim=1, keepdim=False) - last_hidden_state.abs().min(dim=1, keepdim=True)[0],
+             p=1, dim=1)
+
+         outputs = {
+             'sent_logits': sent_logits,
+             'advdet_logits': advdet_logits,
+             'adv_tr_logits': adv_tr_logits,
+             'last_hidden_state': last_hidden_state,
+             'att_score': att_score
+         }
+         return outputs
anonymous_demo/core/tad/classic/__init__.py ADDED
File without changes
anonymous_demo/core/tad/models/__init__.py ADDED
@@ -0,0 +1,9 @@
+ import anonymous_demo.core.tad.classic.__bert__.models
+
+
+ class BERTTADModelList(list):
+     TADBERT = anonymous_demo.core.tad.classic.__bert__.TADBERT
+
+     def __init__(self):
+         model_list = [self.TADBERT]
+         super().__init__(model_list)
anonymous_demo/core/tad/prediction/__init__.py ADDED
File without changes
anonymous_demo/core/tad/prediction/tad_classifier.py ADDED
@@ -0,0 +1,390 @@
1
+ import json
2
+ import os
3
+ import pickle
4
+ import time
5
+
6
+ import torch
7
+ import tqdm
8
+ from findfile import find_file, find_cwd_dir
9
+ from termcolor import colored
10
+
11
+ from torch.utils.data import DataLoader
12
+ from transformers import AutoTokenizer, AutoModel, AutoConfig, DebertaV2ForMaskedLM, RobertaForMaskedLM, BertForMaskedLM
13
+
14
+ from ....functional.dataset.dataset_manager import detect_infer_dataset
15
+
16
+ from ..models import BERTTADModelList
17
+ from ..classic.__bert__.dataset_utils.data_utils_for_inference import BERTTADDataset, Tokenizer4Pretraining
18
+
19
+ from ....utils.demo_utils import print_args, TransformerConnectionError, get_device, build_embedding_matrix
20
+
21
+
22
+ def init_attacker(tad_classifier, defense):
23
+ try:
24
+ from textattack import Attacker
25
+ from textattack.attack_recipes import BAEGarg2019, PWWSRen2019, TextFoolerJin2019, PSOZang2020, IGAWang2019, \
26
+ GeneticAlgorithmAlzantot2018, DeepWordBugGao2018
27
+ from textattack.datasets import Dataset
28
+ from textattack.models.wrappers import HuggingFaceModelWrapper
29
+
30
+ class DemoModelWrapper(HuggingFaceModelWrapper):
31
+ def __init__(self, model):
32
+ self.model = model # pipeline = pipeline
33
+
34
+ def __call__(self, text_inputs, **kwargs):
35
+ outputs = []
36
+ for text_input in text_inputs:
37
+ raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
38
+ outputs.append(raw_outputs['probs'])
39
+ return outputs
40
+
41
+ class SentAttacker:
42
+
43
+ def __init__(self, model, recipe_class=BAEGarg2019):
44
+ model = model
45
+ model_wrapper = DemoModelWrapper(model)
46
+
47
+ recipe = recipe_class.build(model_wrapper)
48
+
49
+ _dataset = [('', 0)]
50
+ _dataset = Dataset(_dataset)
51
+
52
+ self.attacker = Attacker(recipe, _dataset)
53
+
54
+ attackers = {
55
+ 'bae': BAEGarg2019,
56
+ 'pwws': PWWSRen2019,
57
+ 'textfooler': TextFoolerJin2019,
58
+ 'pso': PSOZang2020,
59
+ 'iga': IGAWang2019,
60
+ 'ga': GeneticAlgorithmAlzantot2018,
61
+ 'wordbugger': DeepWordBugGao2018,
62
+ }
63
+ return SentAttacker(tad_classifier, attackers[defense])
64
+ except Exception as e:
65
+
66
+ print('Original error:', e)
67
+
68
+
69
+ def get_mlm_and_tokenizer(text_classifier, config):
70
+ if isinstance(text_classifier, TADTextClassifier):
71
+ base_model = text_classifier.model.bert.base_model
72
+ else:
73
+ base_model = text_classifier.bert.base_model
74
+ pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)
75
+ if 'deberta-v3' in config.pretrained_bert:
76
+ MLM = DebertaV2ForMaskedLM(pretrained_config)
77
+ MLM.deberta = base_model
78
+ elif 'roberta' in config.pretrained_bert:
79
+ MLM = RobertaForMaskedLM(pretrained_config)
80
+ MLM.roberta = base_model
81
+ else:
82
+ MLM = BertForMaskedLM(pretrained_config)
83
+ MLM.bert = base_model
84
+ return MLM, AutoTokenizer.from_pretrained(config.pretrained_bert)
85
+
86
+
87
+ class TADTextClassifier:
88
+ def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
89
+ '''
90
+ from_train_model: load inference model from trained model
91
+ '''
92
+ self.cal_perplexity = cal_perplexity
93
+ # load from a training
94
+ if not isinstance(model_arg, str):
95
+ print('Load text classifier from training')
96
+ self.model = model_arg[0]
97
+ self.opt = model_arg[1]
98
+ self.tokenizer = model_arg[2]
99
+ else:
100
+ try:
101
+ if 'fine-tuned' in model_arg:
102
+ raise ValueError(
103
+ 'Do not support to directly load a fine-tuned model, please load a .state_dict or .model instead!')
104
+ print('Load text classifier from', model_arg)
105
+ state_dict_path = find_file(model_arg, key='.state_dict', exclude_key=['__MACOSX'])
106
+ model_path = find_file(model_arg, key='.model', exclude_key=['__MACOSX'])
107
+ tokenizer_path = find_file(model_arg, key='.tokenizer', exclude_key=['__MACOSX'])
108
+ config_path = find_file(model_arg, key='.config', exclude_key=['__MACOSX'])
109
+
110
+ print('config: {}'.format(config_path))
111
+ print('state_dict: {}'.format(state_dict_path))
112
+ print('model: {}'.format(model_path))
113
+ print('tokenizer: {}'.format(tokenizer_path))
114
+
115
+ with open(config_path, mode='rb') as f:
116
+ self.opt = pickle.load(f)
117
+ self.opt.device = get_device(kwargs.pop('auto_device', True))[0]
118
+
119
+ if state_dict_path or model_path:
120
+ if hasattr(BERTTADModelList, self.opt.model.__name__):
121
+ if state_dict_path:
122
+ if kwargs.pop('offline', False):
123
+ self.bert = AutoModel.from_pretrained(
124
+ find_cwd_dir(self.opt.pretrained_bert.split('/')[-1]))
125
+ else:
126
+ self.bert = AutoModel.from_pretrained(self.opt.pretrained_bert)
127
+ self.model = self.opt.model(self.bert, self.opt)
128
+ self.model.load_state_dict(torch.load(state_dict_path, map_location='cpu'))
129
+ elif model_path:
130
+ self.model = torch.load(model_path, map_location='cpu')
131
+
132
+ try:
133
+ self.tokenizer = Tokenizer4Pretraining(max_seq_len=self.opt.max_seq_len, opt=self.opt,
134
+ **kwargs)
135
+ except ValueError:
136
+ if tokenizer_path:
137
+ with open(tokenizer_path, mode='rb') as f:
138
+ self.tokenizer = pickle.load(f)
139
+ else:
140
+ raise TransformerConnectionError()
141
+
142
+ except Exception as e:
143
+ raise RuntimeError('Exception: {} Fail to load the model from {}! '.format(e, model_arg))
144
+
145
+ self.infer_dataloader = None
146
+ self.opt.eval_batch_size = kwargs.pop('eval_batch_size', 128)
147
+
148
+ self.opt.initializer = self.opt.initializer
149
+
150
+ if self.cal_perplexity:
151
+ try:
152
+ self.MLM, self.MLM_tokenizer = get_mlm_and_tokenizer(self, self.opt)
153
+ except Exception as e:
154
+ self.MLM, self.MLM_tokenizer = None, None
155
+
156
+ self.to(self.opt.device)
157
+
158
+ def to(self, device=None):
159
+ self.opt.device = device
160
+ self.model.to(device)
161
+ if hasattr(self, 'MLM'):
162
+ self.MLM.to(self.opt.device)
163
+
164
+ def cpu(self):
165
+ self.opt.device = 'cpu'
166
+ self.model.to('cpu')
167
+ if hasattr(self, 'MLM'):
168
+ self.MLM.to('cpu')
169
+
170
+ def cuda(self, device='cuda:0'):
171
+ self.opt.device = device
172
+ self.model.to(device)
173
+ if hasattr(self, 'MLM'):
174
+ self.MLM.to(device)
175
+
176
+ def _log_write_args(self):
177
+ n_trainable_params, n_nontrainable_params = 0, 0
178
+ for p in self.model.parameters():
179
+ n_params = torch.prod(torch.tensor(p.shape))
180
+ if p.requires_grad:
181
+ n_trainable_params += n_params
182
+ else:
183
+ n_nontrainable_params += n_params
184
+ print(
185
+ 'n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
186
+ for arg in vars(self.opt):
187
+ if getattr(self.opt, arg) is not None:
188
+ print('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))
189
+
190
+ def batch_infer(self,
191
+ target_file=None,
192
+ print_result=True,
193
+ save_result=False,
194
+ ignore_error=True,
195
+ defense: str = None
196
+ ):
197
+
198
+ save_path = os.path.join(os.getcwd(), 'tad_text_classification.result.json')
199
+
200
+ target_file = detect_infer_dataset(target_file, task='text_defense')
201
+ if not target_file:
202
+ raise FileNotFoundError('Can not find inference datasets!')
203
+
204
+ if hasattr(BERTTADModelList, self.opt.model.__name__):
205
+ dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
206
+
207
+ dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)
208
+ self.infer_dataloader = DataLoader(dataset=dataset, batch_size=self.opt.eval_batch_size, pin_memory=True,
209
+ shuffle=False)
210
+ return self._infer(save_path=save_path if save_result else None, print_result=print_result, defense=defense)
211
+
212
+ def infer(self,
213
+ text: str = None,
214
+ print_result=True,
215
+ ignore_error=True,
216
+ defense: str = None
217
+ ):
218
+
219
+ if hasattr(BERTTADModelList, self.opt.model.__name__):
220
+ dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
221
+
222
+ if text:
223
+ dataset.prepare_infer_sample(text, ignore_error=ignore_error)
224
+ else:
225
+ raise RuntimeError('Please specify your datasets path!')
226
+ self.infer_dataloader = DataLoader(dataset=dataset, batch_size=self.opt.eval_batch_size, shuffle=False)
227
+ return self._infer(print_result=print_result, defense=defense)[0]
228
+
229
+ def _infer(self, save_path=None, print_result=True, defense=None):
230
+
231
+ _params = filter(lambda p: p.requires_grad, self.model.parameters())
232
+
233
+ correct = {True: 'Correct', False: 'Wrong'}
234
+ results = []
235
+
236
+ with torch.no_grad():
237
+ self.model.eval()
238
+ n_correct = 0
239
+ n_labeled = 0
240
+
241
+ n_advdet_correct = 0
242
+ n_advdet_labeled = 0
243
+ if len(self.infer_dataloader.dataset) >= 100:
244
+ it = tqdm.tqdm(self.infer_dataloader, postfix='inferring...')
245
+ else:
246
+ it = self.infer_dataloader
247
+ for _, sample in enumerate(it):
248
+ inputs = [sample[col].to(self.opt.device) for col in self.opt.inputs_cols]
249
+ outputs = self.model(inputs)
250
+ logits, advdet_logits, adv_tr_logits = outputs['sent_logits'], outputs['advdet_logits'], outputs[
251
+ 'adv_tr_logits']
252
+ probs, advdet_probs, adv_tr_probs = torch.softmax(logits, dim=-1), torch.softmax(advdet_logits,
253
+ dim=-1), torch.softmax(
254
+ adv_tr_logits, dim=-1)
255
+
256
+ for i, (prob, advdet_prob, adv_tr_prob) in enumerate(zip(probs, advdet_probs, adv_tr_probs)):
257
+ text_raw = sample['text_raw'][i]
258
+
259
+ pred_label = int(prob.argmax(axis=-1))
260
+ pred_is_adv_label = int(advdet_prob.argmax(axis=-1))
261
+ pred_adv_tr_label = int(adv_tr_prob.argmax(axis=-1))
262
+ ref_label = int(sample['label'][i]) if int(sample['label'][i]) in self.opt.index_to_label else ''
263
+ ref_is_adv_label = int(sample['is_adv'][i]) if int(
264
+ sample['is_adv'][i]) in self.opt.index_to_is_adv else ''
265
+ ref_adv_tr_label = int(sample['adv_train_label'][i]) if int(
266
+ sample['adv_train_label'][i]) in self.opt.index_to_adv_train_label else ''
267
+
268
+ if self.cal_perplexity:
269
+ ids = self.MLM_tokenizer(text_raw, return_tensors="pt")
270
+ ids['labels'] = ids['input_ids'].clone()
271
+ ids = ids.to(self.opt.device)
272
+ loss = self.MLM(**ids)['loss']
273
+ perplexity = float(torch.exp(loss / ids['input_ids'].size(1)))
274
+ else:
275
+ perplexity = 'N.A.'
276
+
277
+ result = {
278
+ 'text': text_raw,
279
+
280
+ 'label': self.opt.index_to_label[pred_label],
281
+ 'probs': prob.cpu().numpy(),
282
+ 'confidence': float(max(prob)),
283
+ 'ref_label': self.opt.index_to_label[ref_label] if isinstance(ref_label, int) else ref_label,
284
+ 'ref_label_check': correct[pred_label == ref_label] if ref_label != -100 else '',
285
+ 'is_fixed': False,
286
+
287
+ 'is_adv_label': self.opt.index_to_is_adv[pred_is_adv_label],
288
+ 'is_adv_probs': advdet_prob.cpu().numpy(),
289
+ 'is_adv_confidence': float(max(advdet_prob)),
290
+ 'ref_is_adv_label': self.opt.index_to_is_adv[ref_is_adv_label] if isinstance(ref_is_adv_label, int) else ref_is_adv_label,
291
+ 'ref_is_adv_check': correct[pred_is_adv_label == ref_is_adv_label] if ref_is_adv_label != -100 and isinstance(ref_is_adv_label, int) else '',
292
+
293
+ 'pred_adv_tr_label': self.opt.index_to_label[pred_adv_tr_label],
294
+ 'ref_adv_tr_label': self.opt.index_to_label[ref_adv_tr_label],
295
+
296
+ 'perplexity': perplexity,
297
+ }
298
+ if defense:
299
+ try:
300
+ if not hasattr(self, 'sent_attacker'):
301
+ self.sent_attacker = init_attacker(self, defense.lower())
302
+ if result['is_adv_label'] == '1':
303
+ res = self.sent_attacker.attacker.simple_attack(text_raw, int(result['label']))
304
+ new_infer_res = self.infer(res.perturbed_result.attacked_text.text, print_result=False)
305
+ result['perturbed_label'] = result['label']
306
+ result['label'] = new_infer_res['label']
307
+ result['probs'] = new_infer_res['probs']
308
+ result['ref_label_check'] = correct[int(result['label']) == ref_label] if ref_label != -100 else ''
309
+ result['restored_text'] = res.perturbed_result.attacked_text.text
310
+ result['is_fixed'] = True
311
+ else:
312
+ result['restored_text'] = ''
313
+ result['is_fixed'] = False
314
+
315
+ except Exception as e:
316
+ print('Error:{}, try install TextAttack and tensorflow_text after 10 seconds...'.format(e))
317
+ time.sleep(10)
318
+ raise RuntimeError('Installation done, please run again...')
319
+
320
+ if ref_label != -100:
321
+ n_labeled += 1
322
+
323
+ if result['label'] == result['ref_label']:
324
+ n_correct += 1
325
+
326
+ if ref_is_adv_label != -100:
327
+ n_advdet_labeled += 1
328
+ if ref_is_adv_label == pred_is_adv_label:
329
+ n_advdet_correct += 1
330
+
331
+ results.append(result)
332
+
333
+ try:
334
+ if print_result:
335
+ for ex_id, result in enumerate(results):
336
+ text_printing = result['text'][:]
337
+ text_info = ''
338
+ if result['label'] != '-100':
339
+ if not result['ref_label']:
340
+ text_info += ' -> <CLS:{}(ref:{} confidence:{})>'.format(result['label'],
341
+ result['ref_label'],
342
+ result['confidence'])
343
+ elif result['label'] == result['ref_label']:
344
+ text_info += colored(
345
+ ' -> <CLS:{}(ref:{} confidence:{})>'.format(result['label'], result['ref_label'],
346
+ result['confidence']), 'green')
347
+ else:
348
+ text_info += colored(
349
+ ' -> <CLS:{}(ref:{} confidence:{})>'.format(result['label'], result['ref_label'],
350
+ result['confidence']), 'red')
351
+
352
+ # AdvDet
353
+ if result['is_adv_label'] != '-100':
354
+ if not result['ref_is_adv_label']:
355
+ text_info += ' -> <AdvDet:{}(ref:{} confidence:{})>'.format(result['is_adv_label'],
356
+ result['ref_is_adv_check'],
357
+ result['is_adv_confidence'])
358
+ elif result['is_adv_label'] == result['ref_is_adv_label']:
359
+ text_info += colored(' -> <AdvDet:{}(ref:{} confidence:{})>'.format(result['is_adv_label'],
360
+ result[
361
+ 'ref_is_adv_label'],
362
+ result[
363
+ 'is_adv_confidence']),
364
+ 'green')
365
+ else:
366
+ text_info += colored(' -> <AdvDet:{}(ref:{} confidence:{})>'.format(result['is_adv_label'],
367
+ result[
368
+ 'ref_is_adv_label'],
369
+ result[
370
+ 'is_adv_confidence']),
371
+ 'red')
372
+ text_printing += text_info
373
+ if self.cal_perplexity:
374
+ text_printing += colored(' --> <perplexity:{}>'.format(result['perplexity']), 'yellow')
375
+ print('Example {}: {}'.format(ex_id, text_printing))
376
+ if save_path:
377
+ with open(save_path, 'w', encoding='utf8') as fout:
378
+ json.dump(str(results), fout, ensure_ascii=False)
379
+ print('inference result saved in: {}'.format(save_path))
380
+ except Exception as e:
381
+ print('Can not save result: {}, Exception: {}'.format(text_raw, e))
382
+
383
+ if len(results) > 1:
384
+ print('CLS Acc:{}%'.format(100 * n_correct / n_labeled if n_labeled else ''))
385
+ print('AdvDet Acc:{}%'.format(100 * n_advdet_correct / n_advdet_labeled if n_advdet_labeled else ''))
386
+
387
+ return results
388
+
389
+ def clear_input_samples(self):
390
+ self.dataset.all_data = []
anonymous_demo/functional/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from anonymous_demo.functional.checkpoint.checkpoint_manager import TADCheckpointManager
+
+ from anonymous_demo.functional.config import TADConfigManager
anonymous_demo/functional/checkpoint/__init__.py ADDED
@@ -0,0 +1 @@
+ from .checkpoint_manager import TADCheckpointManager
anonymous_demo/functional/checkpoint/checkpoint_manager.py ADDED
@@ -0,0 +1,20 @@
+ import os
+ from findfile import find_file
+
+ from anonymous_demo.core.tad.prediction.tad_classifier import TADTextClassifier
+ from anonymous_demo.utils.demo_utils import retry
+
+
+ class CheckpointManager:
+     pass
+
+
+ class TADCheckpointManager(CheckpointManager):
+     @staticmethod
+     @retry
+     def get_tad_text_classifier(checkpoint: str = None,
+                                 eval_batch_size=128,
+                                 **kwargs):
+
+         tad_text_classifier = TADTextClassifier(checkpoint, eval_batch_size=eval_batch_size, **kwargs)
+         return tad_text_classifier
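A minimal usage sketch (not part of the commit) of how the checkpoint manager and classifier above are typically wired together; the checkpoint name 'TAD-SST2' is only an illustrative placeholder, and it assumes the bundled checkpoints.zip has been extracted so findfile can locate the .config/.state_dict files:

```python
from anonymous_demo import TADCheckpointManager

# Load a TAD text classifier from an extracted checkpoint directory.
classifier = TADCheckpointManager.get_tad_text_classifier('TAD-SST2', eval_batch_size=32)

# Single-sentence inference; passing `defense` (e.g. 'pwws') additionally runs the
# attacker-based input repair set up in tad_classifier.init_attacker().
result = classifier.infer('It is a fantastic movie!', print_result=True)
print(result['label'], result['is_adv_label'], result['confidence'])
```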
anonymous_demo/functional/config/__init__.py ADDED
@@ -0,0 +1 @@
+ from .tad_config_manager import TADConfigManager
anonymous_demo/functional/config/config_manager.py ADDED
@@ -0,0 +1,66 @@
+ from argparse import Namespace
+
+ import torch
+
+ one_shot_messages = set()
+
+
+ def config_check(args):
+     pass
+
+
+ class ConfigManager(Namespace):
+
+     def __init__(self, args=None, **kwargs):
+         """
+         The ConfigManager is a subclass of argparse.Namespace; it is built from a parameter dict and counts the call frequency of each parameter.
+         :param args: A parameter dict
+         :param kwargs: Same params as Namespace
+         """
+         if not args:
+             args = {}
+         super().__init__(**kwargs)
+
+         if isinstance(args, Namespace):
+             self.args = vars(args)
+             self.args_call_count = {arg: 0 for arg in vars(args)}
+         else:
+             self.args = args
+             self.args_call_count = {arg: 0 for arg in args}
+
+     def __getattribute__(self, arg_name):
+         if arg_name == 'args' or arg_name == 'args_call_count':
+             return super().__getattribute__(arg_name)
+         try:
+             value = super().__getattribute__('args')[arg_name]
+             args_call_count = super().__getattribute__('args_call_count')
+             args_call_count[arg_name] += 1
+             super().__setattr__('args_call_count', args_call_count)
+             return value
+
+         except Exception as e:
+
+             return super().__getattribute__(arg_name)
+
+     def __setattr__(self, arg_name, value):
+         if arg_name == 'args' or arg_name == 'args_call_count':
+             super().__setattr__(arg_name, value)
+             return
+         try:
+             args = super().__getattribute__('args')
+             args[arg_name] = value
+             super().__setattr__('args', args)
+             args_call_count = super().__getattribute__('args_call_count')
+
+             if arg_name in args_call_count:
+                 # args_call_count[arg_name] += 1
+                 super().__setattr__('args_call_count', args_call_count)
+
+             else:
+                 args_call_count[arg_name] = 0
+                 super().__setattr__('args_call_count', args_call_count)
+
+         except Exception as e:
+             super().__setattr__(arg_name, value)
+
+         config_check(args)
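A small sketch (not from the commit) of how the call-count bookkeeping above behaves: reads go through `__getattribute__` and bump the per-parameter counter, writes go through `__setattr__` and update the backing `args` dict. The parameter values are illustrative only:

```python
from anonymous_demo.functional.config.config_manager import ConfigManager

cfg = ConfigManager({'max_seq_len': 80, 'batch_size': 16})
_ = cfg.max_seq_len          # each attribute read increments its counter
_ = cfg.max_seq_len
cfg.batch_size = 32          # writes update `args` without touching the counter
print(cfg.args_call_count)   # -> {'max_seq_len': 2, 'batch_size': 0}
```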
anonymous_demo/functional/config/tad_config_manager.py ADDED
@@ -0,0 +1,221 @@
1
+ import copy
2
+
3
+ from anonymous_demo.functional.config.config_manager import ConfigManager
4
+ from anonymous_demo.core.tad.classic.__bert__.models import TADBERT
5
+
6
+ _tad_config_template = {'model': TADBERT,
7
+ 'optimizer': "adamw",
8
+ 'learning_rate': 0.00002,
9
+ 'patience': 99999,
10
+ 'pretrained_bert': "microsoft/mdeberta-v3-base",
11
+ 'cache_dataset': True,
12
+ 'warmup_step': -1,
13
+ 'show_metric': False,
14
+ 'max_seq_len': 80,
15
+ 'dropout': 0,
16
+ 'l2reg': 0.000001,
17
+ 'num_epoch': 10,
18
+ 'batch_size': 16,
19
+ 'initializer': 'xavier_uniform_',
20
+ 'seed': 52,
21
+ 'polarities_dim': 3,
22
+ 'log_step': 10,
23
+ 'evaluate_begin': 0,
24
+ 'cross_validate_fold': -1,
25
+ 'use_amp': False,
26
+ # split train and test datasets into 5 folds and repeat 3 training
27
+ }
28
+
29
+ _tad_config_base = {'model': TADBERT,
30
+ 'optimizer': "adamw",
31
+ 'learning_rate': 0.00002,
32
+ 'pretrained_bert': "microsoft/deberta-v3-base",
33
+ 'cache_dataset': True,
34
+ 'warmup_step': -1,
35
+ 'show_metric': False,
36
+ 'max_seq_len': 80,
37
+ 'patience': 99999,
38
+ 'dropout': 0,
39
+ 'l2reg': 0.000001,
40
+ 'num_epoch': 10,
41
+ 'batch_size': 16,
42
+ 'initializer': 'xavier_uniform_',
43
+ 'seed': 52,
44
+ 'polarities_dim': 3,
45
+ 'log_step': 10,
46
+ 'evaluate_begin': 0,
47
+ 'cross_validate_fold': -1
48
+ # split train and test datasets into 5 folds and repeat 3 training
49
+ }
50
+
51
+ _tad_config_english = {'model': TADBERT,
52
+ 'optimizer': "adamw",
53
+ 'learning_rate': 0.00002,
54
+ 'patience': 99999,
55
+ 'pretrained_bert': "microsoft/deberta-v3-base",
56
+ 'cache_dataset': True,
57
+ 'warmup_step': -1,
58
+ 'show_metric': False,
59
+ 'max_seq_len': 80,
60
+ 'dropout': 0,
61
+ 'l2reg': 0.000001,
62
+ 'num_epoch': 10,
63
+ 'batch_size': 16,
64
+ 'initializer': 'xavier_uniform_',
65
+ 'seed': 52,
66
+ 'polarities_dim': 3,
67
+ 'log_step': 10,
68
+ 'evaluate_begin': 0,
69
+ 'cross_validate_fold': -1
70
+ # split train and test datasets into 5 folds and repeat 3 training
71
+ }
72
+
73
+ _tad_config_multilingual = {'model': TADBERT,
74
+ 'optimizer': "adamw",
75
+ 'learning_rate': 0.00002,
76
+ 'patience': 99999,
77
+ 'pretrained_bert': "microsoft/mdeberta-v3-base",
78
+ 'cache_dataset': True,
79
+ 'warmup_step': -1,
80
+ 'show_metric': False,
81
+ 'max_seq_len': 80,
82
+ 'dropout': 0,
83
+ 'l2reg': 0.000001,
84
+ 'num_epoch': 10,
85
+ 'batch_size': 16,
86
+ 'initializer': 'xavier_uniform_',
87
+ 'seed': 52,
88
+ 'polarities_dim': 3,
89
+ 'log_step': 10,
90
+ 'evaluate_begin': 0,
91
+ 'cross_validate_fold': -1
92
+ # split train and test datasets into 5 folds and repeat 3 training
93
+ }
94
+
95
+ _tad_config_chinese = {'model': TADBERT,
96
+ 'optimizer': "adamw",
97
+ 'learning_rate': 0.00002,
98
+ 'patience': 99999,
99
+ 'cache_dataset': True,
100
+ 'warmup_step': -1,
101
+ 'show_metric': False,
102
+ 'pretrained_bert': "bert-base-chinese",
103
+ 'max_seq_len': 80,
104
+ 'dropout': 0,
105
+ 'l2reg': 0.000001,
106
+ 'num_epoch': 10,
107
+ 'batch_size': 16,
108
+ 'initializer': 'xavier_uniform_',
109
+ 'seed': 52,
110
+ 'polarities_dim': 3,
111
+ 'log_step': 10,
112
+ 'evaluate_begin': 0,
113
+ 'cross_validate_fold': -1
114
+ # split train and test datasets into 5 folds and repeat 3 training
115
+ }
116
+
117
+
118
+ class TADConfigManager(ConfigManager):
119
+ def __init__(self, args, **kwargs):
120
+ """
121
+ Available Params: {'model': BERT,
122
+ 'optimizer': "adamw",
123
+ 'learning_rate': 0.00002,
124
+ 'pretrained_bert': "roberta-base",
125
+ 'cache_dataset': True,
126
+ 'warmup_step': -1,
127
+ 'show_metric': False,
128
+ 'max_seq_len': 80,
129
+ 'patience': 99999,
130
+ 'dropout': 0,
131
+ 'l2reg': 0.000001,
132
+ 'num_epoch': 10,
133
+ 'batch_size': 16,
134
+ 'initializer': 'xavier_uniform_',
135
+ 'seed': {52, 25}
136
+ 'embed_dim': 768,
137
+ 'hidden_dim': 768,
138
+ 'polarities_dim': 3,
139
+ 'log_step': 10,
140
+ 'evaluate_begin': 0,
141
+ 'cross_validate_fold': -1 # split train and test datasets into 5 folds and repeat 3 training
142
+ }
143
+ :param args:
144
+ :param kwargs:
145
+ """
146
+ super().__init__(args, **kwargs)
147
+
148
+ @staticmethod
149
+ def set_tad_config(configType: str, newitem: dict):
150
+ if isinstance(newitem, dict):
151
+ if configType == 'template':
152
+ _tad_config_template.update(newitem)
153
+ elif configType == 'base':
154
+ _tad_config_base.update(newitem)
155
+ elif configType == 'english':
156
+ _tad_config_english.update(newitem)
157
+ elif configType == 'chinese':
158
+ _tad_config_chinese.update(newitem)
159
+ elif configType == 'multilingual':
160
+ _tad_config_multilingual.update(newitem)
161
+ elif configType == 'glove':
162
+ _tad_config_glove.update(newitem)
163
+ else:
164
+ raise ValueError(
165
+ "Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove")
166
+ else:
167
+ raise TypeError("Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}")
168
+
169
+ @staticmethod
170
+ def set_tad_config_template(newitem):
171
+ TADConfigManager.set_tad_config('template', newitem)
172
+
173
+ @staticmethod
174
+ def set_tad_config_base(newitem):
175
+ TADConfigManager.set_tad_config('base', newitem)
176
+
177
+ @staticmethod
178
+ def set_tad_config_english(newitem):
179
+ TADConfigManager.set_tad_config('english', newitem)
180
+
181
+ @staticmethod
182
+ def set_tad_config_chinese(newitem):
183
+ TADConfigManager.set_tad_config('chinese', newitem)
184
+
185
+ @staticmethod
186
+ def set_tad_config_multilingual(newitem):
187
+ TADConfigManager.set_tad_config('multilingual', newitem)
188
+
189
+ @staticmethod
190
+ def set_tad_config_glove(newitem):
191
+ TADConfigManager.set_tad_config('glove', newitem)
192
+
193
+ @staticmethod
194
+ def get_tad_config_template() -> ConfigManager:
195
+ _tad_config_template.update(_tad_config_template)
196
+ return TADConfigManager(copy.deepcopy(_tad_config_template))
197
+
198
+ @staticmethod
199
+ def get_tad_config_base() -> ConfigManager:
200
+ _tad_config_template.update(_tad_config_base)
201
+ return TADConfigManager(copy.deepcopy(_tad_config_template))
202
+
203
+ @staticmethod
204
+ def get_tad_config_english() -> ConfigManager:
205
+ _tad_config_template.update(_tad_config_english)
206
+ return TADConfigManager(copy.deepcopy(_tad_config_template))
207
+
208
+ @staticmethod
209
+ def get_tad_config_chinese() -> ConfigManager:
210
+ _tad_config_template.update(_tad_config_chinese)
211
+ return TADConfigManager(copy.deepcopy(_tad_config_template))
212
+
213
+ @staticmethod
214
+ def get_tad_config_multilingual() -> ConfigManager:
215
+ _tad_config_template.update(_tad_config_multilingual)
216
+ return TADConfigManager(copy.deepcopy(_tad_config_template))
217
+
218
+ @staticmethod
219
+ def get_tad_config_glove() -> ConfigManager:
220
+ _tad_config_template.update(_tad_config_glove)
221
+ return TADConfigManager(copy.deepcopy(_tad_config_template))
anonymous_demo/functional/dataset/__init__.py ADDED
@@ -0,0 +1 @@
+ from anonymous_demo.functional.dataset.dataset_manager import detect_infer_dataset
anonymous_demo/functional/dataset/dataset_manager.py ADDED
@@ -0,0 +1,21 @@
+ import os
+ from findfile import find_files, find_dir
+
+ filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip',
+                     '.state_dict', '.model', '.png', 'acc_', 'f1_', '.backup', '.bak']
+
+
+ def detect_infer_dataset(dataset_path, task='apc'):
+     dataset_file = []
+     if isinstance(dataset_path, str) and os.path.isfile(dataset_path):
+         dataset_file.append(dataset_path)
+         return dataset_file
+
+     for d in dataset_path:
+         if not os.path.exists(d):
+             search_path = find_dir(os.getcwd(), [d, task, 'dataset'], exclude_key=filter_key_words, disable_alert=False)
+             dataset_file += find_files(search_path, ['.inference', d], exclude_key=['train.'] + filter_key_words)
+         else:
+             dataset_file += find_files(d, ['.inference', task], exclude_key=['train.'] + filter_key_words)
+
+     return dataset_file
anonymous_demo/network/__init__.py ADDED
File without changes
anonymous_demo/network/lcf_pooler.py ADDED
@@ -0,0 +1,26 @@
+ import numpy
+ import torch
+ import torch.nn as nn
+
+
+ class LCF_Pooler(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.activation = nn.Tanh()
+
+     def forward(self, hidden_states, lcf_vec):
+         device = hidden_states.device
+         lcf_vec = lcf_vec.detach().cpu().numpy()
+
+         pooled_output = numpy.zeros((hidden_states.shape[0], hidden_states.shape[2]), dtype=numpy.float32)
+         hidden_states = hidden_states.detach().cpu().numpy()
+         for i, vec in enumerate(lcf_vec):
+             lcf_ids = [j for j in range(len(vec)) if sum(vec[j] - 1.) == 0]
+             pooled_output[i] = hidden_states[i][lcf_ids[len(lcf_ids) // 2]]
+
+         pooled_output = torch.Tensor(pooled_output).to(device)
+         pooled_output = self.dense(pooled_output)
+         pooled_output = self.activation(pooled_output)
+         return pooled_output
anonymous_demo/network/lsa.py ADDED
@@ -0,0 +1,52 @@
+ import torch
+ from anonymous_demo.network.sa_encoder import Encoder
+ from torch import nn
+
+
+ class LSA(nn.Module):
+     def __init__(self, bert, opt):
+         super(LSA, self).__init__()
+         self.opt = opt
+
+         self.encoder = Encoder(bert.config, opt)
+         self.encoder_left = Encoder(bert.config, opt)
+         self.encoder_right = Encoder(bert.config, opt)
+         self.linear_window_3h = nn.Linear(opt.embed_dim * 3, opt.embed_dim)
+         self.linear_window_2h = nn.Linear(opt.embed_dim * 2, opt.embed_dim)
+         self.eta1 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
+         self.eta2 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
+
+     def forward(self, global_context_features, spc_mask_vec, lcf_matrix, left_lcf_matrix, right_lcf_matrix):
+         masked_global_context_features = torch.mul(spc_mask_vec, global_context_features)
+
+         # # --------------------------------------------------- #
+         lcf_features = torch.mul(global_context_features, lcf_matrix)
+         lcf_features = self.encoder(lcf_features)
+         # # --------------------------------------------------- #
+         left_lcf_features = torch.mul(masked_global_context_features, left_lcf_matrix)
+         left_lcf_features = self.encoder_left(left_lcf_features)
+         # # --------------------------------------------------- #
+         right_lcf_features = torch.mul(masked_global_context_features, right_lcf_matrix)
+         right_lcf_features = self.encoder_right(right_lcf_features)
+         # # --------------------------------------------------- #
+         if 'lr' == self.opt.window or 'rl' == self.opt.window:
+             if self.eta1 <= 0 and self.opt.eta != -1:
+                 torch.nn.init.uniform_(self.eta1)
+                 print('reset eta1 to: {}'.format(self.eta1.item()))
+             if self.eta2 <= 0 and self.opt.eta != -1:
+                 torch.nn.init.uniform_(self.eta2)
+                 print('reset eta2 to: {}'.format(self.eta2.item()))
+             if self.opt.eta >= 0:
+                 cat_features = torch.cat((lcf_features, self.eta1 * left_lcf_features, self.eta2 * right_lcf_features),
+                                          -1)
+             else:
+                 cat_features = torch.cat((lcf_features, left_lcf_features, right_lcf_features), -1)
+             sent_out = self.linear_window_3h(cat_features)
+         elif 'l' == self.opt.window:
+             sent_out = self.linear_window_2h(torch.cat((lcf_features, self.eta1 * left_lcf_features), -1))
+         elif 'r' == self.opt.window:
+             sent_out = self.linear_window_2h(torch.cat((lcf_features, self.eta2 * right_lcf_features), -1))
+         else:
+             raise KeyError('Invalid parameter:', self.opt.window)
+
+         return sent_out
anonymous_demo/network/sa_encoder.py ADDED
@@ -0,0 +1,159 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ class BertSelfAttention(nn.Module):
9
+ def __init__(self, config):
10
+ super().__init__()
11
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
12
+ raise ValueError(
13
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
14
+ f"heads ({config.num_attention_heads})"
15
+ )
16
+
17
+ self.num_attention_heads = config.num_attention_heads
18
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
19
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
20
+
21
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
22
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
23
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
24
+
25
+ self.dropout = nn.Dropout(
26
+ config.attention_probs_dropout_prob if hasattr(config, 'attention_probs_dropout_prob') else 0)
27
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
28
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
29
+ self.max_position_embeddings = config.max_position_embeddings
30
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
31
+
32
+ self.is_decoder = config.is_decoder
33
+
34
+ def transpose_for_scores(self, x):
35
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
36
+ x = x.view(*new_x_shape)
37
+ return x.permute(0, 2, 1, 3)
38
+
39
+ def forward(
40
+ self,
41
+ hidden_states,
42
+ attention_mask=None,
43
+ head_mask=None,
44
+ encoder_hidden_states=None,
45
+ encoder_attention_mask=None,
46
+ past_key_value=None,
47
+ output_attentions=False,
48
+ ):
49
+ mixed_query_layer = self.query(hidden_states)
50
+
51
+ # If this is instantiated as a cross-attention module, the keys
52
+ # and values come from an encoder; the attention mask needs to be
53
+ # such that the encoder's padding tokens are not attended to.
54
+ is_cross_attention = encoder_hidden_states is not None
55
+
56
+ if is_cross_attention and past_key_value is not None:
57
+ # reuse k,v, cross_attentions
58
+ key_layer = past_key_value[0]
59
+ value_layer = past_key_value[1]
60
+ attention_mask = encoder_attention_mask
61
+ elif is_cross_attention:
62
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
63
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
64
+ attention_mask = encoder_attention_mask
65
+ elif past_key_value is not None:
66
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
67
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
68
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
69
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
70
+ else:
71
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
72
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
73
+
74
+ query_layer = self.transpose_for_scores(mixed_query_layer)
75
+
76
+ if self.is_decoder:
77
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
78
+ # Further calls to cross_attention layer can then reuse all cross-attention
79
+ # key/value_states (first "if" case)
80
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
81
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
82
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
83
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
84
+ past_key_value = (key_layer, value_layer)
85
+
86
+ # Take the dot product between "query" and "key" to get the raw attention scores.
87
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
88
+
89
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
90
+ seq_length = hidden_states.size()[1]
91
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
92
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
93
+ distance = position_ids_l - position_ids_r
94
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
95
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
96
+
97
+ if self.position_embedding_type == "relative_key":
98
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
99
+ attention_scores = attention_scores + relative_position_scores
100
+ elif self.position_embedding_type == "relative_key_query":
101
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
102
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
103
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
104
+
105
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
106
+ if attention_mask is not None:
107
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
108
+ attention_scores = attention_scores + attention_mask
109
+
110
+ # Normalize the attention scores to probabilities.
111
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
112
+
113
+ # This is actually dropping out entire tokens to attend to, which might
114
+ # seem a bit unusual, but is taken from the original Transformer paper.
115
+ attention_probs = self.dropout(attention_probs)
116
+
117
+ # Mask heads if we want to
118
+ if head_mask is not None:
119
+ attention_probs = attention_probs * head_mask
120
+
121
+ context_layer = torch.matmul(attention_probs, value_layer)
122
+
123
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
124
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
125
+ context_layer = context_layer.view(*new_context_layer_shape)
126
+
127
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
128
+
129
+ if self.is_decoder:
130
+ outputs = outputs + (past_key_value,)
131
+ return outputs
132
+
133
+
134
+ class Encoder(nn.Module):
135
+ def __init__(self, config, opt, layer_num=1):
136
+ super(Encoder, self).__init__()
137
+ self.opt = opt
138
+ self.config = config
139
+ self.encoder = nn.ModuleList([SelfAttention(config, opt) for _ in range(layer_num)])
140
+ self.tanh = torch.nn.Tanh()
141
+
142
+ def forward(self, x):
143
+ for i, enc in enumerate(self.encoder):
144
+ x = self.tanh(enc(x)[0])
145
+ return x
146
+
147
+
148
+ class SelfAttention(nn.Module):
149
+ def __init__(self, config, opt):
150
+ super(SelfAttention, self).__init__()
151
+ self.opt = opt
152
+ self.config = config
153
+ self.SA = BertSelfAttention(config)
154
+
155
+ def forward(self, inputs):
156
+ zero_vec = np.zeros((inputs.size(0), 1, 1, self.opt.max_seq_len))
157
+ zero_tensor = torch.tensor(zero_vec).float().to(inputs.device)
158
+ SA_out = self.SA(inputs, zero_tensor)
159
+ return SA_out
anonymous_demo/utils/__init__.py ADDED
File without changes
anonymous_demo/utils/demo_utils.py ADDED
@@ -0,0 +1,209 @@
1
+ import json
2
+ import os
3
+ import pickle
4
+ import signal
5
+ import threading
6
+ import time
7
+ import zipfile
8
+
9
+ import gdown
10
+ import numpy as np
11
+ import requests
12
+ import torch
13
+ import tqdm
14
+ from autocuda import auto_cuda, auto_cuda_name
15
+ from findfile import find_files, find_cwd_file, find_file
16
+ from termcolor import colored
17
+ from functools import wraps
18
+
19
+ from update_checker import parse_version
20
+
21
+ from anonymous_demo import __version__
22
+
23
+
24
+ def save_args(config, save_path):
25
+ f = open(os.path.join(save_path), mode='w', encoding='utf8')
26
+ for arg in config.args:
27
+ if config.args_call_count[arg]:
28
+ f.write('{}: {}\n'.format(arg, config.args[arg]))
29
+ f.close()
30
+
31
+
32
+ def print_args(config, logger=None, mode=0):
33
+ args = [key for key in sorted(config.args.keys())]
34
+ for arg in args:
35
+ if logger:
36
+ logger.info('{0}:{1}\t-->\tCalling Count:{2}'.format(arg, config.args[arg], config.args_call_count[arg]))
37
+ else:
38
+ print('{0}:{1}\t-->\tCalling Count:{2}'.format(arg, config.args[arg], config.args_call_count[arg]))
39
+
40
+
41
+ def check_and_fix_labels(label_set: set, label_name, all_data, opt):
42
+ if '-100' in label_set:
43
+
44
+ label_to_index = {origin_label: int(idx) - 1 if origin_label != '-100' else -100 for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
45
+ index_to_label = {int(idx) - 1 if origin_label != '-100' else -100: origin_label for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
46
+ else:
47
+ label_to_index = {origin_label: int(idx) for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
48
+ index_to_label = {int(idx): origin_label for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
49
+ if 'index_to_label' not in opt.args:
50
+ opt.index_to_label = index_to_label
51
+ opt.label_to_index = label_to_index
52
+
53
+ if opt.index_to_label != index_to_label:
54
+ opt.index_to_label.update(index_to_label)
55
+ opt.label_to_index.update(label_to_index)
56
+ num_label = {l: 0 for l in label_set}
57
+ num_label['Sum'] = len(all_data)
58
+ for item in all_data:
59
+ try:
60
+ num_label[item[label_name]] += 1
61
+ item[label_name] = label_to_index[item[label_name]]
62
+ except Exception as e:
63
+ # print(e)
64
+ num_label[item.polarity] += 1
65
+ item.polarity = label_to_index[item.polarity]
66
+ print('Dataset Label Details: {}'.format(num_label))
67
+
68
+
69
+ def check_and_fix_IOB_labels(label_map, opt):
70
+ index_to_IOB_label = {int(label_map[origin_label]): origin_label for origin_label in label_map}
71
+ opt.index_to_IOB_label = index_to_IOB_label
72
+
73
+
74
+ def get_device(auto_device):
75
+ if isinstance(auto_device, str) and auto_device == 'allcuda':
76
+ device = 'cuda'
77
+ elif isinstance(auto_device, str):
78
+ device = auto_device
79
+ elif isinstance(auto_device, bool):
80
+ device = auto_cuda() if auto_device else 'cpu'
81
+ else:
82
+ device = auto_cuda()
83
+ try:
84
+ torch.device(device)
85
+ except RuntimeError as e:
86
+ print(colored('Device assignment error: {}, redirect to CPU'.format(e), 'red'))
87
+ device = 'cpu'
88
+ device_name = auto_cuda_name()
89
+ return device, device_name
90
+
91
+
92
+ def _load_word_vec(path, word2idx=None, embed_dim=300):
93
+ fin = open(path, 'r', encoding='utf-8', newline='\n', errors='ignore')
94
+ word_vec = {}
95
+ for line in tqdm.tqdm(fin.readlines(), postfix='Loading embedding file...'):
96
+ tokens = line.rstrip().split()
97
+ word, vec = ' '.join(tokens[:-embed_dim]), tokens[-embed_dim:]
98
+ if word in word2idx.keys():
99
+ word_vec[word] = np.asarray(vec, dtype='float32')
100
+ return word_vec
101
+
102
+
103
+ def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
104
+ if not os.path.exists('run'):
105
+ os.makedirs('run')
106
+ embed_matrix_path = 'run/{}'.format(os.path.join(opt.dataset_name, dat_fname))
107
+ if os.path.exists(embed_matrix_path):
108
+ print(colored('Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)'.format(embed_matrix_path), 'green'))
109
+ embedding_matrix = pickle.load(open(embed_matrix_path, 'rb'))
110
+ else:
111
+ glove_path = prepare_glove840_embedding(embed_matrix_path)
112
+ embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
113
+
114
+ word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)
115
+
116
+ for word, i in tqdm.tqdm(word2idx.items(), postfix=colored('Building embedding_matrix {}'.format(dat_fname), 'yellow')):
117
+ vec = word_vec.get(word)
118
+ if vec is not None:
119
+ embedding_matrix[i] = vec
120
+ pickle.dump(embedding_matrix, open(embed_matrix_path, 'wb'))
121
+ return embedding_matrix
122
+
123
+
124
+ def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
125
+ x = (np.ones(maxlen) * value).astype(dtype)
126
+ if truncating == 'pre':
127
+ trunc = sequence[-maxlen:]
128
+ else:
129
+ trunc = sequence[:maxlen]
130
+ trunc = np.asarray(trunc, dtype=dtype)
131
+ if padding == 'post':
132
+ x[:len(trunc)] = trunc
133
+ else:
134
+ x[-len(trunc):] = trunc
135
+ return x
136
+
137
+
138
+ class TransformerConnectionError(ValueError):
139
+ def __init__(self):
140
+ pass
141
+
142
+
143
+ def retry(f):
144
+ @wraps(f)
145
+ def decorated(*args, **kwargs):
146
+ count = 5
147
+ while count:
148
+
149
+ try:
150
+ return f(*args, **kwargs)
151
+ except (
152
+ TransformerConnectionError,
153
+ requests.exceptions.RequestException,
154
+ requests.exceptions.ConnectionError,
155
+ requests.exceptions.HTTPError,
156
+ requests.exceptions.ConnectTimeout,
157
+ requests.exceptions.ProxyError,
158
+ requests.exceptions.SSLError,
159
+ requests.exceptions.BaseHTTPError,
160
+ ) as e:
161
+ print(colored('Training Exception: {}, will retry later'.format(e)))
162
+ time.sleep(60)
163
+ count -= 1
164
+
165
+ return decorated
166
+
167
+
168
+ def save_json(dic, save_path):
169
+ if isinstance(dic, str):
170
+ dic = eval(dic)
171
+ with open(save_path, 'w', encoding='utf-8') as f:
172
+ # f.write(str(dict))
173
+ str_ = json.dumps(dic, ensure_ascii=False)
174
+ f.write(str_)
175
+
176
+
177
+ def load_json(save_path):
178
+ with open(save_path, 'r', encoding='utf-8') as f:
179
+ data = f.readline().strip()
180
+ print(type(data), data)
181
+ dic = json.loads(data)
182
+ return dic
183
+
184
+
185
+ def init_optimizer(optimizer):
186
+ optimizers = {
187
+ 'adadelta': torch.optim.Adadelta, # default lr=1.0
188
+ 'adagrad': torch.optim.Adagrad, # default lr=0.01
189
+ 'adam': torch.optim.Adam, # default lr=0.001
190
+ 'adamax': torch.optim.Adamax, # default lr=0.002
191
+ 'asgd': torch.optim.ASGD, # default lr=0.01
192
+ 'rmsprop': torch.optim.RMSprop, # default lr=0.01
193
+ 'sgd': torch.optim.SGD,
194
+ 'adamw': torch.optim.AdamW,
195
+ torch.optim.Adadelta: torch.optim.Adadelta, # default lr=1.0
196
+ torch.optim.Adagrad: torch.optim.Adagrad, # default lr=0.01
197
+ torch.optim.Adam: torch.optim.Adam, # default lr=0.001
198
+ torch.optim.Adamax: torch.optim.Adamax, # default lr=0.002
199
+ torch.optim.ASGD: torch.optim.ASGD, # default lr=0.01
200
+ torch.optim.RMSprop: torch.optim.RMSprop, # default lr=0.01
201
+ torch.optim.SGD: torch.optim.SGD,
202
+ torch.optim.AdamW: torch.optim.AdamW,
203
+ }
204
+ if optimizer in optimizers:
205
+ return optimizers[optimizer]
206
+ elif hasattr(torch.optim, optimizer.__name__):
207
+ return optimizer
208
+ else:
209
+ raise KeyError('Unsupported optimizer: {}. Please use string or the optimizer objects in torch.optim as your optimizer'.format(optimizer))
anonymous_demo/utils/logger.py ADDED
@@ -0,0 +1,38 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ import time
5
+
6
+ import termcolor
7
+
8
+ today = time.strftime('%Y%m%d %H%M%S', time.localtime(time.time()))
9
+
10
+
11
+ def get_logger(log_path, log_name='', log_type='training_log'):
12
+ if log_path:
13
+ log_dir = os.path.join(log_path, "logs")
14
+ else:
15
+ log_dir = os.path.join('.', "logs")
16
+
17
+ full_path = os.path.join(log_dir, log_name + '_' + today)
18
+ if not os.path.exists(full_path):
19
+ os.makedirs(full_path)
20
+ log_path = os.path.join(full_path, "{}.log".format(log_type))
21
+ logger = logging.getLogger(log_name)
22
+ if not logger.handlers:
23
+ formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
24
+
25
+ file_handler = logging.FileHandler(log_path, encoding="utf8")
26
+ file_handler.setFormatter(formatter)
27
+ file_handler.setLevel(logging.INFO)
28
+
29
+ console_handler = logging.StreamHandler(sys.stdout)
30
+ console_handler.formatter = formatter
31
+ console_handler.setLevel(logging.INFO)
32
+
33
+ logger.addHandler(file_handler)
34
+ logger.addHandler(console_handler)
35
+
36
+ logger.setLevel(logging.INFO)
37
+
38
+ return logger
app.py ADDED
@@ -0,0 +1,271 @@
1
+ import os
2
+ import random
3
+ import zipfile
4
+ from difflib import Differ
5
+
6
+ import gradio as gr
7
+ import nltk
8
+ import pandas as pd
9
+ from findfile import find_files
10
+
11
+ from anonymous_demo import TADCheckpointManager
12
+ from textattack import Attacker
13
+ from textattack.attack_recipes import BAEGarg2019, PWWSRen2019, TextFoolerJin2019, PSOZang2020, IGAWang2019, GeneticAlgorithmAlzantot2018, DeepWordBugGao2018
14
+ from textattack.attack_results import SuccessfulAttackResult
15
+ from textattack.datasets import Dataset
16
+ from textattack.models.wrappers import HuggingFaceModelWrapper
17
+
18
+ z = zipfile.ZipFile('checkpoints.zip', 'r')
19
+ z.extractall(os.getcwd())
20
+
21
+ class ModelWrapper(HuggingFaceModelWrapper):
22
+ def __init__(self, model):
23
+ self.model = model # pipeline = pipeline
24
+
25
+ def __call__(self, text_inputs, **kwargs):
26
+ outputs = []
27
+ for text_input in text_inputs:
28
+ raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
29
+ outputs.append(raw_outputs['probs'])
30
+ return outputs
31
+
32
+
33
+ class SentAttacker:
34
+
35
+ def __init__(self, model, recipe_class=BAEGarg2019):
36
+ model = model
37
+ model_wrapper = ModelWrapper(model)
38
+
39
+ recipe = recipe_class.build(model_wrapper)
40
+ # WordNet defaults to english. Set the default language to French ('fra')
41
+
42
+ # recipe.transformation.language = "en"
43
+
44
+ _dataset = [('', 0)]
45
+ _dataset = Dataset(_dataset)
46
+
47
+ self.attacker = Attacker(recipe, _dataset)
48
+
49
+
50
+ def diff_texts(text1, text2):
51
+ d = Differ()
52
+ return [
53
+ (token[2:], token[0] if token[0] != " " else None)
54
+ for token in d.compare(text1, text2)
55
+ ]
56
+
57
+
58
+ def get_ensembled_tad_results(results):
59
+ target_dict = {}
60
+ for r in results:
61
+ target_dict[r['label']] = target_dict.get(r['label']) + 1 if r['label'] in target_dict else 1
62
+
63
+ return dict(zip(target_dict.values(), target_dict.keys()))[max(target_dict.values())]
64
+
65
+
66
+ nltk.download('omw-1.4')
67
+
68
+ sent_attackers = {}
69
+ tad_classifiers = {}
70
+
71
+ attack_recipes = {
72
+ 'bae': BAEGarg2019,
73
+ 'pwws': PWWSRen2019,
74
+ 'textfooler': TextFoolerJin2019,
75
+ 'pso': PSOZang2020,
76
+ 'iga': IGAWang2019,
77
+ 'GA': GeneticAlgorithmAlzantot2018,
78
+ 'wordbugger': DeepWordBugGao2018,
79
+ }
80
+
81
+ for attacker in [
82
+ 'pwws',
83
+ 'bae',
84
+ 'textfooler'
85
+ ]:
86
+ for dataset in [
87
+ 'agnews10k',
88
+ 'amazon',
89
+ 'sst2',
90
+ ]:
91
+ if 'tad-{}'.format(dataset) not in tad_classifiers:
92
+ tad_classifiers['tad-{}'.format(dataset)] = TADCheckpointManager.get_tad_text_classifier('tad-{}'.format(dataset).upper())
93
+
94
+ sent_attackers['tad-{}{}'.format(dataset, attacker)] = SentAttacker(tad_classifiers['tad-{}'.format(dataset)], attack_recipes[attacker])
95
+ tad_classifiers['tad-{}'.format(dataset)].sent_attacker = sent_attackers['tad-{}pwws'.format(dataset)]
96
+
97
+
98
+ def get_a_sst2_example():
99
+ filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']
100
+
101
+ dataset_file = {'train': [], 'test': [], 'valid': []}
102
+ dataset = 'sst2'
103
+ search_path = './'
104
+ task = 'text_defense'
105
+ dataset_file['test'] += find_files(search_path, [dataset, 'test', task], exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + filter_key_words)
106
+
107
+ for dat_type in [
108
+ 'test'
109
+ ]:
110
+ data = []
111
+ label_set = set()
112
+ for data_file in dataset_file[dat_type]:
113
+
114
+ with open(data_file, mode='r', encoding='utf8') as fin:
115
+ lines = fin.readlines()
116
+ for line in lines:
117
+ text, label = line.split('$LABEL$')
118
+ text = text.strip()
119
+ label = int(label.strip())
120
+ data.append((text, label))
121
+ label_set.add(label)
122
+ return data[random.randint(0, len(data) - 1)]
123
+
124
+
125
+ def get_a_agnews_example():
126
+ filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']
127
+
128
+ dataset_file = {'train': [], 'test': [], 'valid': []}
129
+ dataset = 'agnews'
130
+ search_path = './'
131
+ task = 'text_defense'
132
+ dataset_file['test'] += find_files(search_path, [dataset, 'test', task], exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + filter_key_words)
133
+ for dat_type in [
134
+ 'test'
135
+ ]:
136
+ data = []
137
+ label_set = set()
138
+ for data_file in dataset_file[dat_type]:
139
+
140
+ with open(data_file, mode='r', encoding='utf8') as fin:
141
+ lines = fin.readlines()
142
+ for line in lines:
143
+ text, label = line.split('$LABEL$')
144
+ text = text.strip()
145
+ label = int(label.strip())
146
+ data.append((text, label))
147
+ label_set.add(label)
148
+ return data[random.randint(0, len(data) - 1)]
149
+
150
+
151
+ def get_a_amazon_example():
152
+ filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']
153
+
154
+ dataset_file = {'train': [], 'test': [], 'valid': []}
155
+ dataset = 'amazon'
156
+ search_path = './'
157
+ task = 'text_defense'
158
+ dataset_file['test'] += find_files(search_path, [dataset, 'test', task], exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + filter_key_words)
159
+
160
+ for dat_type in [
161
+ 'test'
162
+ ]:
163
+ data = []
164
+ label_set = set()
165
+ for data_file in dataset_file[dat_type]:
166
+
167
+ with open(data_file, mode='r', encoding='utf8') as fin:
168
+ lines = fin.readlines()
169
+ for line in lines:
170
+ text, label = line.split('$LABEL$')
171
+ text = text.strip()
172
+ label = int(label.strip())
173
+ data.append((text, label))
174
+ label_set.add(label)
175
+ return data[random.randint(0, len(data) - 1)]
176
+
177
+
178
+ def generate_adversarial_example(dataset, attacker, text=None, label=None):
179
+ if not text:
180
+ if 'agnews' in dataset.lower():
181
+ text, label = get_a_agnews_example()
182
+ elif 'sst2' in dataset.lower():
183
+ text, label = get_a_sst2_example()
184
+ elif 'amazon' in dataset.lower():
185
+ text, label = get_a_amazon_example()
186
+
187
+ result = None
188
+ attack_result = sent_attackers['tad-{}{}'.format(dataset.lower(), attacker.lower())].attacker.simple_attack(text, int(label))
189
+ if isinstance(attack_result, SuccessfulAttackResult):
190
+
191
+ if (attack_result.perturbed_result.output != attack_result.original_result.ground_truth_output) and (attack_result.original_result.output == attack_result.original_result.ground_truth_output):
192
+ # with defense
193
+ result = tad_classifiers['tad-{}'.format(dataset.lower())].infer(
194
+ attack_result.perturbed_result.attacked_text.text + '!ref!{},{},{}'.format(attack_result.original_result.ground_truth_output, 1, attack_result.perturbed_result.output),
195
+ print_result=True,
196
+ defense='pwws',
197
+ )
198
+
199
+ if result:
200
+ classification_df = {}
201
+ classification_df['pred_label'] = result['label']
202
+ classification_df['confidence'] = round(result['confidence'], 3)
203
+ classification_df['is_correct'] = result['ref_label_check']
204
+ classification_df['is_repaired'] = result['is_fixed']
205
+
206
+ advdetection_df = {}
207
+ if result['is_adv_label'] != '0':
208
+ advdetection_df['is_adversary'] = result['is_adv_label']
209
+ advdetection_df['perturbed_label'] = result['perturbed_label']
210
+ advdetection_df['confidence'] = round(result['is_adv_confidence'], 3)
211
+ # advdetection_df['ref_is_attack'] = result['ref_is_adv_label']
212
+ # advdetection_df['is_correct'] = result['ref_is_adv_check']
213
+
214
+ else:
215
+ return generate_adversarial_example(dataset, attacker)
216
+
217
+ return (text,
218
+ label,
219
+ attack_result.perturbed_result.attacked_text.text,
220
+ diff_texts(text, attack_result.perturbed_result.attacked_text.text),
221
+ diff_texts(text, result['restored_text']),
222
+ attack_result.perturbed_result.output,
223
+ pd.DataFrame(classification_df, index=[0]),
224
+ pd.DataFrame(advdetection_df, index=[0])
225
+ )
226
+
227
+
228
+ demo = gr.Blocks()
229
+
230
+ with demo:
231
+ with gr.Row():
232
+ with gr.Column():
233
+ input_dataset = gr.Radio(choices=['SST2', 'AGNews10K', 'Amazon'], value='Amazon', label="Dataset")
234
+ input_attacker = gr.Radio(choices=['BAE', 'PWWS', 'TextFooler'], value='TextFooler', label="Attacker")
235
+ input_sentence = gr.Textbox(placeholder='Randomly choose an example from the test set if this box is blank', label="Sentence")
236
+ input_label = gr.Textbox(placeholder='original label ... ', label="Original Label")
237
+
238
+ gr.Markdown("Original Example")
239
+
240
+ output_origin_example = gr.Textbox(label="Original Example")
241
+ output_original_label = gr.Textbox(label="Original Label")
242
+
243
+ gr.Markdown("Adversarial Example")
244
+ output_adv_example = gr.Textbox(label="Adversarial Example")
245
+ output_adv_label = gr.Textbox(label="Perturbed Label")
246
+
247
+ gr.Markdown('This demo is deployed on a CPU device so it may take a long time to execute. Please be patient.')
248
+ button_gen = gr.Button("Click Here to Generate an Adversary and Run Adversary Detection & Repair")
249
+
250
+ # Right column (outputs)
251
+ with gr.Column():
252
+ gr.Markdown("Example Difference")
253
+ adv_text_diff = gr.HighlightedText(label="Adversarial Example Difference", combine_adjacent=True)
254
+ restored_text_diff = gr.HighlightedText(label="Restored Example Difference", combine_adjacent=True)
255
+
256
+ output_is_adv_df = gr.DataFrame(label="Adversary Prediction")
257
+ output_df = gr.DataFrame(label="Standard Classification Prediction")
258
+
259
+ # Bind functions to buttons
260
+ button_gen.click(fn=generate_adversarial_example,
261
+ inputs=[input_dataset, input_attacker, input_sentence, input_label],
262
+ outputs=[output_origin_example,
263
+ output_original_label,
264
+ output_adv_example,
265
+ adv_text_diff,
266
+ restored_text_diff,
267
+ output_adv_label,
268
+ output_df,
269
+ output_is_adv_df])
270
+
271
+ demo.launch()
checkpoints.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a5452cd89dcd3132d616cc81e2a1b063efa7d11e5798719b0779715b1c6edeb
3
+ size 1846862527
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ findfile>=1.7.9.8
2
+ autocuda>=0.11
3
+ metric-visualizer>=0.5.5
4
+ boostaug>=2.2.3
5
+ spacy
6
+ networkx
7
+ seqeval
8
+ update-checker
9
+ typing_extensions
10
+ tqdm
11
+ pytorch_warmup
12
+ termcolor
13
+ gitpython
14
+ gdown>=4.4.0
15
+ transformers>4.20.0
16
+ torch>1.0.0
17
+ sentencepiece
18
+ tensorflow_text
19
+ textattack
text_defense/201.SST2/stsa.binary.dev.dat ADDED
The diff for this file is too large to render. See raw diff
text_defense/201.SST2/stsa.binary.test.dat ADDED
The diff for this file is too large to render. See raw diff
text_defense/201.SST2/stsa.binary.train.dat ADDED
The diff for this file is too large to render. See raw diff
text_defense/204.AGNews10K/AGNews10K.test.dat ADDED
The diff for this file is too large to render. See raw diff
text_defense/204.AGNews10K/AGNews10K.train.dat ADDED
The diff for this file is too large to render. See raw diff
text_defense/204.AGNews10K/AGNews10K.valid.dat ADDED
The diff for this file is too large to render. See raw diff
text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat ADDED
The diff for this file is too large to render. See raw diff
text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat ADDED
The diff for this file is too large to render. See raw diff
textattack/__init__.py ADDED
@@ -0,0 +1,39 @@
1
+ """Welcome to the API references for TextAttack!
2
+
3
+ What is TextAttack?
4
+
5
+ `TextAttack <https://github.com/QData/TextAttack>`__ is a Python framework for adversarial attacks, adversarial training, and data augmentation in NLP.
6
+
7
+ TextAttack makes experimenting with the robustness of NLP models seamless, fast, and easy. It's also useful for NLP model training, adversarial training, and data augmentation.
8
+
9
+ TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
10
+ """
11
+ from .attack_args import AttackArgs, CommandLineAttackArgs
12
+ from .augment_args import AugmenterArgs
13
+ from .dataset_args import DatasetArgs
14
+ from .model_args import ModelArgs
15
+ from .training_args import TrainingArgs, CommandLineTrainingArgs
16
+ from .attack import Attack
17
+ from .attacker import Attacker
18
+ from .trainer import Trainer
19
+ from .metrics import Metric
20
+
21
+ from . import (
22
+ attack_recipes,
23
+ attack_results,
24
+ augmentation,
25
+ commands,
26
+ constraints,
27
+ datasets,
28
+ goal_function_results,
29
+ goal_functions,
30
+ loggers,
31
+ metrics,
32
+ models,
33
+ search_methods,
34
+ shared,
35
+ transformations,
36
+ )
37
+
38
+
39
+ name = "textattack"
textattack/__main__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/env python
2
+
3
+ if __name__ == "__main__":
4
+ import textattack
5
+
6
+ textattack.commands.textattack_cli.main()
textattack/attack.py ADDED
@@ -0,0 +1,492 @@
1
+ """
2
+ Attack Class
3
+ ============
4
+ """
5
+
6
+ from collections import OrderedDict
7
+ from typing import List, Union
8
+
9
+ import lru
10
+ import torch
11
+
12
+ import textattack
13
+ from textattack.attack_results import (
14
+ FailedAttackResult,
15
+ MaximizedAttackResult,
16
+ SkippedAttackResult,
17
+ SuccessfulAttackResult,
18
+ )
19
+ from textattack.constraints import Constraint, PreTransformationConstraint
20
+ from textattack.goal_function_results import GoalFunctionResultStatus
21
+ from textattack.goal_functions import GoalFunction
22
+ from textattack.models.wrappers import ModelWrapper
23
+ from textattack.search_methods import SearchMethod
24
+ from textattack.shared import AttackedText, utils
25
+ from textattack.transformations import CompositeTransformation, Transformation
26
+
27
+
28
+ class Attack:
29
+ """An attack generates adversarial examples on text.
30
+
31
+ An attack is comprised of a goal function, constraints, transformation, and a search method. Use :meth:`attack` method to attack one sample at a time.
32
+
33
+ Args:
34
+ goal_function (:class:`~textattack.goal_functions.GoalFunction`):
35
+ A function for determining how well a perturbation is doing at achieving the attack's goal.
36
+ constraints (list of :class:`~textattack.constraints.Constraint` or :class:`~textattack.constraints.PreTransformationConstraint`):
37
+ A list of constraints to add to the attack, defining which perturbations are valid.
38
+ transformation (:class:`~textattack.transformations.Transformation`):
39
+ The transformation applied at each step of the attack.
40
+ search_method (:class:`~textattack.search_methods.SearchMethod`):
41
+ The method for exploring the search space of possible perturbations
42
+ transformation_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
43
+ The number of items to keep in the transformations cache
44
+ constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
45
+ The number of items to keep in the constraints cache
46
+
47
+ Example::
48
+
49
+ >>> import textattack
50
+ >>> import transformers
51
+
52
+ >>> # Load model, tokenizer, and model_wrapper
53
+ >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
54
+ >>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
55
+ >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
56
+
57
+ >>> # Construct our four components for `Attack`
58
+ >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
59
+ >>> from textattack.constraints.semantics import WordEmbeddingDistance
60
+
61
+ >>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
62
+ >>> constraints = [
63
+ ... RepeatModification(),
64
+ ... StopwordModification(),
65
+ ... WordEmbeddingDistance(min_cos_sim=0.9)
66
+ ... ]
67
+ >>> transformation = WordSwapEmbedding(max_candidates=50)
68
+ >>> search_method = GreedyWordSwapWIR(wir_method="delete")
69
+
70
+ >>> # Construct the actual attack
71
+ >>> attack = Attack(goal_function, constraints, transformation, search_method)
72
+
73
+ >>> input_text = "I really enjoyed the new movie that came out last month."
74
+ >>> label = 1 #Positive
75
+ >>> attack_result = attack.attack(input_text, label)
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ goal_function: GoalFunction,
81
+ constraints: List[Union[Constraint, PreTransformationConstraint]],
82
+ transformation: Transformation,
83
+ search_method: SearchMethod,
84
+ transformation_cache_size=2**15,
85
+ constraint_cache_size=2**15,
86
+ ):
87
+ """Initialize an attack object.
88
+
89
+ Attacks can be run multiple times.
90
+ """
91
+ assert isinstance(
92
+ goal_function, GoalFunction
93
+ ), f"`goal_function` must be of type `textattack.goal_functions.GoalFunction`, but got type `{type(goal_function)}`."
94
+ assert isinstance(
95
+ constraints, list
96
+ ), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
97
+ for c in constraints:
98
+ assert isinstance(
99
+ c, (Constraint, PreTransformationConstraint)
100
+ ), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
101
+ assert isinstance(
102
+ transformation, Transformation
103
+ ), f"`transformation` must be of type `textattack.transformations.Transformation`, but got type `{type(transformation)}`."
104
+ assert isinstance(
105
+ search_method, SearchMethod
106
+ ), f"`search_method` must be of type `textattack.search_methods.SearchMethod`, but got type `{type(search_method)}`."
107
+
108
+ self.goal_function = goal_function
109
+ self.search_method = search_method
110
+ self.transformation = transformation
111
+ self.is_black_box = (
112
+ getattr(transformation, "is_black_box", True) and search_method.is_black_box
113
+ )
114
+
115
+ if not self.search_method.check_transformation_compatibility(
116
+ self.transformation
117
+ ):
118
+ raise ValueError(
119
+ f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
120
+ )
121
+
122
+ self.constraints = []
123
+ self.pre_transformation_constraints = []
124
+ for constraint in constraints:
125
+ if isinstance(
126
+ constraint,
127
+ textattack.constraints.PreTransformationConstraint,
128
+ ):
129
+ self.pre_transformation_constraints.append(constraint)
130
+ else:
131
+ self.constraints.append(constraint)
132
+
133
+ # Check if we can use transformation cache for our transformation.
134
+ if not self.transformation.deterministic:
135
+ self.use_transformation_cache = False
136
+ elif isinstance(self.transformation, CompositeTransformation):
137
+ self.use_transformation_cache = True
138
+ for t in self.transformation.transformations:
139
+ if not t.deterministic:
140
+ self.use_transformation_cache = False
141
+ break
142
+ else:
143
+ self.use_transformation_cache = True
144
+ self.transformation_cache_size = transformation_cache_size
145
+ self.transformation_cache = lru.LRU(transformation_cache_size)
146
+
147
+ self.constraint_cache_size = constraint_cache_size
148
+ self.constraints_cache = lru.LRU(constraint_cache_size)
149
+
150
+ # Give search method access to functions for getting transformations and evaluating them
151
+ self.search_method.get_transformations = self.get_transformations
152
+ # Give search method access to self.goal_function for model query count, etc.
153
+ self.search_method.goal_function = self.goal_function
154
+ # The search method only needs access to the first argument. The second is only used
155
+ # by the attack class when checking whether to skip the sample
156
+ self.search_method.get_goal_results = self.goal_function.get_results
157
+
158
+ # Give search method access to get indices which need to be ordered / searched
159
+ self.search_method.get_indices_to_order = self.get_indices_to_order
160
+
161
+ self.search_method.filter_transformations = self.filter_transformations
162
+
163
+ def clear_cache(self, recursive=True):
164
+ self.constraints_cache.clear()
165
+ if self.use_transformation_cache:
166
+ self.transformation_cache.clear()
167
+ if recursive:
168
+ self.goal_function.clear_cache()
169
+ for constraint in self.constraints:
170
+ if hasattr(constraint, "clear_cache"):
171
+ constraint.clear_cache()
172
+
173
+ def cpu_(self):
174
+ """Move any `torch.nn.Module` models that are part of Attack to CPU."""
175
+ visited = set()
176
+
177
+ def to_cpu(obj):
178
+ visited.add(id(obj))
179
+ if isinstance(obj, torch.nn.Module):
180
+ obj.cpu()
181
+ elif isinstance(
182
+ obj,
183
+ (
184
+ Attack,
185
+ GoalFunction,
186
+ Transformation,
187
+ SearchMethod,
188
+ Constraint,
189
+ PreTransformationConstraint,
190
+ ModelWrapper,
191
+ ),
192
+ ):
193
+ for key in obj.__dict__:
194
+ s_obj = obj.__dict__[key]
195
+ if id(s_obj) not in visited:
196
+ to_cpu(s_obj)
197
+ elif isinstance(obj, (list, tuple)):
198
+ for item in obj:
199
+ if id(item) not in visited and isinstance(
200
+ item, (Transformation, Constraint, PreTransformationConstraint)
201
+ ):
202
+ to_cpu(item)
203
+
204
+ to_cpu(self)
205
+
206
+ def cuda_(self):
207
+ """Move any `torch.nn.Module` models that are part of Attack to GPU."""
208
+ visited = set()
209
+
210
+ def to_cuda(obj):
211
+ visited.add(id(obj))
212
+ if isinstance(obj, torch.nn.Module):
213
+ obj.to(textattack.shared.utils.device)
214
+ elif isinstance(
215
+ obj,
216
+ (
217
+ Attack,
218
+ GoalFunction,
219
+ Transformation,
220
+ SearchMethod,
221
+ Constraint,
222
+ PreTransformationConstraint,
223
+ ModelWrapper,
224
+ ),
225
+ ):
226
+ for key in obj.__dict__:
227
+ s_obj = obj.__dict__[key]
228
+ if id(s_obj) not in visited:
229
+ to_cuda(s_obj)
230
+ elif isinstance(obj, (list, tuple)):
231
+ for item in obj:
232
+ if id(item) not in visited and isinstance(
233
+ item, (Transformation, Constraint, PreTransformationConstraint)
234
+ ):
235
+ to_cuda(item)
236
+
237
+ to_cuda(self)
238
+
239
+ def get_indices_to_order(self, current_text, **kwargs):
240
+ """Applies ``pre_transformation_constraints`` to ``text`` to get all
241
+ the indices that can be used to search and order.
242
+
243
+ Args:
244
+ current_text: The current ``AttackedText`` for which we need to find indices are eligible to be ordered.
245
+ Returns:
246
+ The length and the filtered list of indices which search methods can use to search/order.
247
+ """
248
+
249
+ indices_to_order = self.transformation(
250
+ current_text,
251
+ pre_transformation_constraints=self.pre_transformation_constraints,
252
+ return_indices=True,
253
+ **kwargs,
254
+ )
255
+
256
+ len_text = len(indices_to_order)
257
+
258
+ # Convert indices_to_order to list for easier shuffling later
259
+ return len_text, list(indices_to_order)
260
+
261
+ def _get_transformations_uncached(self, current_text, original_text=None, **kwargs):
262
+ """Applies ``self.transformation`` to ``text``, then filters the list
263
+ of possible transformations through the applicable constraints.
264
+
265
+ Args:
266
+ current_text: The current ``AttackedText`` on which to perform the transformations.
267
+ original_text: The original ``AttackedText`` from which the attack started.
268
+ Returns:
269
+ A filtered list of transformations where each transformation matches the constraints
270
+ """
271
+ transformed_texts = self.transformation(
272
+ current_text,
273
+ pre_transformation_constraints=self.pre_transformation_constraints,
274
+ **kwargs,
275
+ )
276
+
277
+ return transformed_texts
278
+
279
+ def get_transformations(self, current_text, original_text=None, **kwargs):
280
+ """Applies ``self.transformation`` to ``text``, then filters the list
281
+ of possible transformations through the applicable constraints.
282
+
283
+ Args:
284
+ current_text: The current ``AttackedText`` on which to perform the transformations.
285
+ original_text: The original ``AttackedText`` from which the attack started.
286
+ Returns:
287
+ A filtered list of transformations where each transformation matches the constraints
288
+ """
289
+ if not self.transformation:
290
+ raise RuntimeError(
291
+ "Cannot call `get_transformations` without a transformation."
292
+ )
293
+
294
+ if self.use_transformation_cache:
295
+ cache_key = tuple([current_text] + sorted(kwargs.items()))
296
+ if utils.hashable(cache_key) and cache_key in self.transformation_cache:
297
+ # promote transformed_text to the top of the LRU cache
298
+ self.transformation_cache[cache_key] = self.transformation_cache[
299
+ cache_key
300
+ ]
301
+ transformed_texts = list(self.transformation_cache[cache_key])
302
+ else:
303
+ transformed_texts = self._get_transformations_uncached(
304
+ current_text, original_text, **kwargs
305
+ )
306
+ if utils.hashable(cache_key):
307
+ self.transformation_cache[cache_key] = tuple(transformed_texts)
308
+ else:
309
+ transformed_texts = self._get_transformations_uncached(
310
+ current_text, original_text, **kwargs
311
+ )
312
+
313
+ return self.filter_transformations(
314
+ transformed_texts, current_text, original_text
315
+ )
316
+
317
+ def _filter_transformations_uncached(
318
+ self, transformed_texts, current_text, original_text=None
319
+ ):
320
+ """Filters a list of potential transformed texts based on
321
+ ``self.constraints``
322
+
323
+ Args:
324
+ transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
325
+ current_text: The current ``AttackedText`` on which the transformation was applied.
326
+ original_text: The original ``AttackedText`` from which the attack started.
327
+ """
328
+ filtered_texts = transformed_texts[:]
329
+ for C in self.constraints:
330
+ if len(filtered_texts) == 0:
331
+ break
332
+ if C.compare_against_original:
333
+ if not original_text:
334
+ raise ValueError(
335
+ f"Missing `original_text` argument when constraint {type(C)} is set to compare against `original_text`"
336
+ )
337
+
338
+ filtered_texts = C.call_many(filtered_texts, original_text)
339
+ else:
340
+ filtered_texts = C.call_many(filtered_texts, current_text)
341
+ # Default to false for all original transformations.
342
+ for original_transformed_text in transformed_texts:
343
+ self.constraints_cache[(current_text, original_transformed_text)] = False
344
+ # Set unfiltered transformations to True in the cache.
345
+ for filtered_text in filtered_texts:
346
+ self.constraints_cache[(current_text, filtered_text)] = True
347
+ return filtered_texts
348
+
349
+ def filter_transformations(
350
+ self, transformed_texts, current_text, original_text=None
351
+ ):
352
+ """Filters a list of potential transformed texts based on
353
+ ``self.constraints`` Utilizes an LRU cache to attempt to avoid
354
+ recomputing common transformations.
355
+
356
+ Args:
357
+ transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
358
+ current_text: The current ``AttackedText`` on which the transformation was applied.
359
+ original_text: The original ``AttackedText`` from which the attack started.
360
+ """
361
+ # Remove any occurrences of current_text in transformed_texts
362
+ transformed_texts = [
363
+ t for t in transformed_texts if t.text != current_text.text
364
+ ]
365
+ # Populate cache with transformed_texts
366
+ uncached_texts = []
367
+ filtered_texts = []
368
+ for transformed_text in transformed_texts:
369
+ if (current_text, transformed_text) not in self.constraints_cache:
370
+ uncached_texts.append(transformed_text)
371
+ else:
372
+ # promote transformed_text to the top of the LRU cache
373
+ self.constraints_cache[
374
+ (current_text, transformed_text)
375
+ ] = self.constraints_cache[(current_text, transformed_text)]
376
+ if self.constraints_cache[(current_text, transformed_text)]:
377
+ filtered_texts.append(transformed_text)
378
+ filtered_texts += self._filter_transformations_uncached(
379
+ uncached_texts, current_text, original_text=original_text
380
+ )
381
+ # Sort transformations to ensure order is preserved between runs
382
+ filtered_texts.sort(key=lambda t: t.text)
383
+ return filtered_texts
384
+
385
+ def _attack(self, initial_result):
386
+ """Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
387
+ ``initial_result``.
388
+
389
+ Args:
390
+ initial_result: The initial ``GoalFunctionResult`` from which to perturb.
391
+
392
+ Returns:
393
+ A ``SuccessfulAttackResult``, ``FailedAttackResult``,
394
+ or ``MaximizedAttackResult``.
395
+ """
396
+ final_result = self.search_method(initial_result)
397
+ self.clear_cache()
398
+ if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
399
+ result = SuccessfulAttackResult(
400
+ initial_result,
401
+ final_result,
402
+ )
403
+ elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
404
+ result = FailedAttackResult(
405
+ initial_result,
406
+ final_result,
407
+ )
408
+ elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
409
+ result = MaximizedAttackResult(
410
+ initial_result,
411
+ final_result,
412
+ )
413
+ else:
414
+ raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
415
+ return result
416
+
417
+ def attack(self, example, ground_truth_output):
418
+ """Attack a single example.
419
+
420
+ Args:
421
+ example (:obj:`str`, :obj:`OrderedDict[str, str]` or :class:`~textattack.shared.AttackedText`):
422
+ Example to attack. It can be a single string or an `OrderedDict` where
423
+ keys represent the input fields (e.g. "premise", "hypothesis") and the values are the actual input texts.
424
+ Also accepts :class:`~textattack.shared.AttackedText` that wraps around the input.
425
+ ground_truth_output(:obj:`int`, :obj:`float` or :obj:`str`):
426
+ Ground truth output of `example`.
427
+ For classification tasks, it should be an integer representing the ground truth label.
428
+ For regression tasks (e.g. STS), it should be the target value.
429
+ For seq2seq tasks (e.g. translation), it should be the target string.
430
+ Returns:
431
+ :class:`~textattack.attack_results.AttackResult` that represents the result of the attack.
432
+ """
433
+ assert isinstance(
434
+ example, (str, OrderedDict, AttackedText)
435
+ ), "`example` must be either `str`, `collections.OrderedDict`, or `textattack.shared.AttackedText`."
436
+ if isinstance(example, (str, OrderedDict)):
437
+ example = AttackedText(example)
438
+
439
+ assert isinstance(
440
+ ground_truth_output, (int, str)
441
+ ), "`ground_truth_output` must either be `str` or `int`."
442
+ goal_function_result, _ = self.goal_function.init_attack_example(
443
+ example, ground_truth_output
444
+ )
445
+ if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
446
+ return SkippedAttackResult(goal_function_result)
447
+ else:
448
+ result = self._attack(goal_function_result)
449
+ return result
450
+
451
+ def __repr__(self):
452
+ """Prints attack parameters in a human-readable string.
453
+
454
+ Inspired by the readability of printing PyTorch nn.Modules:
455
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
456
+ """
457
+ main_str = "Attack" + "("
458
+ lines = []
459
+
460
+ lines.append(utils.add_indent(f"(search_method): {self.search_method}", 2))
461
+ # self.goal_function
462
+ lines.append(utils.add_indent(f"(goal_function): {self.goal_function}", 2))
463
+ # self.transformation
464
+ lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2))
465
+ # self.constraints
466
+ constraints_lines = []
467
+ constraints = self.constraints + self.pre_transformation_constraints
468
+ if len(constraints):
469
+ for i, constraint in enumerate(constraints):
470
+ constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2))
471
+ constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
472
+ else:
473
+ constraints_str = "None"
474
+ lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
475
+ # self.is_black_box
476
+ lines.append(utils.add_indent(f"(is_black_box): {self.is_black_box}", 2))
477
+ main_str += "\n " + "\n ".join(lines) + "\n"
478
+ main_str += ")"
479
+ return main_str
480
+
481
+ def __getstate__(self):
482
+ state = self.__dict__.copy()
483
+ state["transformation_cache"] = None
484
+ state["constraints_cache"] = None
485
+ return state
486
+
487
+ def __setstate__(self, state):
488
+ self.__dict__ = state
489
+ self.transformation_cache = lru.LRU(self.transformation_cache_size)
490
+ self.constraints_cache = lru.LRU(self.constraint_cache_size)
491
+
492
+ __str__ = __repr__
textattack/attack_args.py ADDED
@@ -0,0 +1,763 @@
1
+ """
2
+ AttackArgs Class
3
+ ================
4
+ """
5
+
6
+ from dataclasses import dataclass, field
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+ from typing import Dict, Optional
12
+
13
+ import textattack
14
+ from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file
15
+
16
+ from .attack import Attack
17
+ from .dataset_args import DatasetArgs
18
+ from .model_args import ModelArgs
19
+
20
+ ATTACK_RECIPE_NAMES = {
21
+ "alzantot": "textattack.attack_recipes.GeneticAlgorithmAlzantot2018",
22
+ "bae": "textattack.attack_recipes.BAEGarg2019",
23
+ "bert-attack": "textattack.attack_recipes.BERTAttackLi2020",
24
+ "faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019",
25
+ "deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
26
+ "hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
27
+ "input-reduction": "textattack.attack_recipes.InputReductionFeng2018",
28
+ "kuleshov": "textattack.attack_recipes.Kuleshov2017",
29
+ "morpheus": "textattack.attack_recipes.MorpheusTan2020",
30
+ "seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox",
31
+ "textbugger": "textattack.attack_recipes.TextBuggerLi2018",
32
+ "textfooler": "textattack.attack_recipes.TextFoolerJin2019",
33
+ "pwws": "textattack.attack_recipes.PWWSRen2019",
34
+ "iga": "textattack.attack_recipes.IGAWang2019",
35
+ "pruthi": "textattack.attack_recipes.Pruthi2019",
36
+ "pso": "textattack.attack_recipes.PSOZang2020",
37
+ "checklist": "textattack.attack_recipes.CheckList2020",
38
+ "clare": "textattack.attack_recipes.CLARE2020",
39
+ "a2t": "textattack.attack_recipes.A2TYoo2021",
40
+ }
41
+
42
+
43
+ BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
44
+ "random-synonym-insertion": "textattack.transformations.RandomSynonymInsertion",
45
+ "word-deletion": "textattack.transformations.WordDeletion",
46
+ "word-swap-embedding": "textattack.transformations.WordSwapEmbedding",
47
+ "word-swap-homoglyph": "textattack.transformations.WordSwapHomoglyphSwap",
48
+ "word-swap-inflections": "textattack.transformations.WordSwapInflections",
49
+ "word-swap-neighboring-char-swap": "textattack.transformations.WordSwapNeighboringCharacterSwap",
50
+ "word-swap-random-char-deletion": "textattack.transformations.WordSwapRandomCharacterDeletion",
51
+ "word-swap-random-char-insertion": "textattack.transformations.WordSwapRandomCharacterInsertion",
52
+ "word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution",
53
+ "word-swap-wordnet": "textattack.transformations.WordSwapWordNet",
54
+ "word-swap-masked-lm": "textattack.transformations.WordSwapMaskedLM",
55
+ "word-swap-hownet": "textattack.transformations.WordSwapHowNet",
56
+ "word-swap-qwerty": "textattack.transformations.WordSwapQWERTY",
57
+ }
58
+
59
+
60
+ WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
61
+ "word-swap-gradient": "textattack.transformations.WordSwapGradientBased"
62
+ }
63
+
64
+
65
+ CONSTRAINT_CLASS_NAMES = {
66
+ #
67
+ # Semantics constraints
68
+ #
69
+ "embedding": "textattack.constraints.semantics.WordEmbeddingDistance",
70
+ "bert": "textattack.constraints.semantics.sentence_encoders.BERT",
71
+ "infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
72
+ "thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
73
+ "use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
74
+ "muse": "textattack.constraints.semantics.sentence_encoders.MultilingualUniversalSentenceEncoder",
75
+ "bert-score": "textattack.constraints.semantics.BERTScore",
76
+ #
77
+ # Grammaticality constraints
78
+ #
79
+ "lang-tool": "textattack.constraints.grammaticality.LanguageTool",
80
+ "part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
81
+ "goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
82
+ "gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
83
+ "learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
84
+ "cola": "textattack.constraints.grammaticality.COLA",
85
+ #
86
+ # Overlap constraints
87
+ #
88
+ "bleu": "textattack.constraints.overlap.BLEU",
89
+ "chrf": "textattack.constraints.overlap.chrF",
90
+ "edit-distance": "textattack.constraints.overlap.LevenshteinEditDistance",
91
+ "meteor": "textattack.constraints.overlap.METEOR",
92
+ "max-words-perturbed": "textattack.constraints.overlap.MaxWordsPerturbed",
93
+ #
94
+ # Pre-transformation constraints
95
+ #
96
+ "repeat": "textattack.constraints.pre_transformation.RepeatModification",
97
+ "stopword": "textattack.constraints.pre_transformation.StopwordModification",
98
+ "max-word-index": "textattack.constraints.pre_transformation.MaxWordIndexModification",
99
+ }
100
+
101
+
102
+ SEARCH_METHOD_CLASS_NAMES = {
103
+ "beam-search": "textattack.search_methods.BeamSearch",
104
+ "greedy": "textattack.search_methods.GreedySearch",
105
+ "ga-word": "textattack.search_methods.GeneticAlgorithm",
106
+ "greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR",
107
+ "pso": "textattack.search_methods.ParticleSwarmOptimization",
108
+ }
109
+
110
+
111
+ GOAL_FUNCTION_CLASS_NAMES = {
112
+ #
113
+ # Classification goal functions
114
+ #
115
+ "targeted-classification": "textattack.goal_functions.classification.TargetedClassification",
116
+ "untargeted-classification": "textattack.goal_functions.classification.UntargetedClassification",
117
+ "input-reduction": "textattack.goal_functions.classification.InputReduction",
118
+ #
119
+ # Text goal functions
120
+ #
121
+ "minimize-bleu": "textattack.goal_functions.text.MinimizeBleu",
122
+ "non-overlapping-output": "textattack.goal_functions.text.NonOverlappingOutput",
123
+ "text-to-text": "textattack.goal_functions.text.TextToTextGoalFunction",
124
+ }
125
+
126
+
127
+ @dataclass
128
+ class AttackArgs:
129
+ """Attack arguments to be passed to :class:`~textattack.Attacker`.
130
+
131
+ Args:
132
+ num_examples (:obj:`int`, `optional`, defaults to :obj:`10`):
133
+ The number of examples to attack. :obj:`-1` for entire dataset.
134
+ num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
135
+ The number of successful adversarial examples we want. This is different from :obj:`num_examples`
136
+ as :obj:`num_examples` only cares about attacking `N` samples while :obj:`num_successful_examples` aims to keep attacking
137
+ until we have `N` successful cases.
138
+
139
+ .. note::
140
+ If set, this argument overrides `num_examples` argument.
141
+ num_examples_offset (:obj: `int`, `optional`, defaults to :obj:`0`):
142
+ The offset index to start at in the dataset.
143
+ attack_n (:obj:`bool`, `optional`, defaults to :obj:`False`):
144
+ Whether to run attack until total of `N` examples have been attacked (and not skipped).
145
+ shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
146
+ If :obj:`True`, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling
147
+ the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means
148
+ :obj:`shuffle` can now be used with checkpoint saving.
149
+ query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
150
+ The maximum number of model queries allowed per example attacked.
151
+ If not set, we use the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object (which by default is :obj:`float("inf")`).
152
+
153
+ .. note::
154
+ Setting this overwrites the query budget set in :class:`~textattack.goal_functions.GoalFunction` object.
155
+ checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
156
+ If set, checkpoint will be saved after attacking every `N` examples. If :obj:`None` is passed, no checkpoints will be saved.
157
+ checkpoint_dir (:obj:`str`, `optional`, defaults to :obj:`"checkpoints"`):
158
+ The directory to save checkpoint files.
159
+ random_seed (:obj:`int`, `optional`, defaults to :obj:`765`):
160
+ Random seed for reproducibility.
161
+ parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
162
+ If :obj:`True`, run attack using multiple CPUs/GPUs.
163
+ num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
164
+ Number of worker processes to run per device in parallel mode (i.e. :obj:`parallel=True`). For example, if you are using GPUs and :obj:`num_workers_per_device=2`,
165
+ then 2 processes will be running in each GPU.
166
+ log_to_txt (:obj:`str`, `optional`, defaults to :obj:`None`):
167
+ If set, save attack logs as a `.txt` file to the directory specified by this argument.
168
+ If the last part of the provided path ends with the `.txt` extension, it is assumed to be the desired path of the log file.
169
+ log_to_csv (:obj:`str`, `optional`, defaults to :obj:`None`):
170
+ If set, save attack logs as a CSV file to the directory specified by this argument.
171
+ If the last part of the provided path ends with the `.csv` extension, it is assumed to be the desired path of the log file.
172
+ csv_coloring_style (:obj:`str`, `optional`, defaults to :obj:`"file"`):
173
+ Method for choosing how to mark perturbed parts of the text. Options are :obj:`"file"`, :obj:`"plain"`, and :obj:`"html"`.
174
+ :obj:`"file"` wraps perturbed parts with double brackets :obj:`[[ <text> ]]` while :obj:`"plain"` does not mark the text in any way.
175
+ log_to_visdom (:obj:`dict`, `optional`, defaults to :obj:`None`):
176
+ If set, Visdom logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.VisdomLogger`.
177
+ Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
178
+ three keys and their corresponding values: :obj:`"env", "port", "hostname"`.
179
+ log_to_wandb(:obj:`dict`, `optional`, defaults to :obj:`None`):
180
+ If set, WandB logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.WeightsAndBiasesLogger`.
181
+ Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
182
+ key and its corresponding value: :obj:`"project"`.
183
+ disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
184
+ Disable displaying individual attack results to stdout.
185
+ silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
186
+ Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
187
+ enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
188
+ Enable calculation and display of optional advance post-hoc metrics like perplexity, grammar errors, etc.
189
+ """
190
+
191
+ num_examples: int = 10
192
+ num_successful_examples: int = None
193
+ num_examples_offset: int = 0
194
+ attack_n: bool = False
195
+ shuffle: bool = False
196
+ query_budget: int = None
197
+ checkpoint_interval: int = None
198
+ checkpoint_dir: str = "checkpoints"
199
+ random_seed: int = 765 # equivalent to sum((ord(c) for c in "TEXTATTACK"))
200
+ parallel: bool = False
201
+ num_workers_per_device: int = 1
202
+ log_to_txt: str = None
203
+ log_to_csv: str = None
204
+ log_summary_to_json: str = None
205
+ csv_coloring_style: str = "file"
206
+ log_to_visdom: dict = None
207
+ log_to_wandb: dict = None
208
+ disable_stdout: bool = False
209
+ silent: bool = False
210
+ enable_advance_metrics: bool = False
211
+ metrics: Optional[Dict] = None
212
+
213
+ def __post_init__(self):
214
+ if self.num_successful_examples:
215
+ self.num_examples = None
216
+ if self.num_examples:
217
+ assert (
218
+ self.num_examples >= 0 or self.num_examples == -1
219
+ ), "`num_examples` must be greater than or equal to 0 or equal to -1."
220
+ if self.num_successful_examples:
221
+ assert (
222
+ self.num_successful_examples >= 0
223
+ ), "`num_examples` must be greater than or equal to 0."
224
+
225
+ if self.query_budget:
226
+ assert self.query_budget > 0, "`query_budget` must be greater than 0."
227
+
228
+ if self.checkpoint_interval:
229
+ assert (
230
+ self.checkpoint_interval > 0
231
+ ), "`checkpoint_interval` must be greater than 0."
232
+
233
+ assert (
234
+ self.num_workers_per_device > 0
235
+ ), "`num_workers_per_device` must be greater than 0."
236
+
237
+ @classmethod
238
+ def _add_parser_args(cls, parser):
239
+ """Add listed args to command line parser."""
240
+ default_obj = cls()
241
+ num_ex_group = parser.add_mutually_exclusive_group(required=False)
242
+ num_ex_group.add_argument(
243
+ "--num-examples",
244
+ "-n",
245
+ type=int,
246
+ default=default_obj.num_examples,
247
+ help="The number of examples to process, -1 for entire dataset.",
248
+ )
249
+ num_ex_group.add_argument(
250
+ "--num-successful-examples",
251
+ type=int,
252
+ default=default_obj.num_successful_examples,
253
+ help="The number of successful adversarial examples we want.",
254
+ )
255
+ parser.add_argument(
256
+ "--num-examples-offset",
257
+ "-o",
258
+ type=int,
259
+ required=False,
260
+ default=default_obj.num_examples_offset,
261
+ help="The offset to start at in the dataset.",
262
+ )
263
+ parser.add_argument(
264
+ "--query-budget",
265
+ "-q",
266
+ type=int,
267
+ default=default_obj.query_budget,
268
+ help="The maximum number of model queries allowed per example attacked. Setting this overwrites the query budget set in `GoalFunction` object.",
269
+ )
270
+ parser.add_argument(
271
+ "--shuffle",
272
+ action="store_true",
273
+ default=default_obj.shuffle,
274
+ help="If `True`, shuffle the samples before we attack the dataset. Default is False.",
275
+ )
276
+ parser.add_argument(
277
+ "--attack-n",
278
+ action="store_true",
279
+ default=default_obj.attack_n,
280
+ help="Whether to run attack until `n` examples have been attacked (not skipped).",
281
+ )
282
+ parser.add_argument(
283
+ "--checkpoint-dir",
284
+ required=False,
285
+ type=str,
286
+ default=default_obj.checkpoint_dir,
287
+ help="The directory to save checkpoint files.",
288
+ )
289
+ parser.add_argument(
290
+ "--checkpoint-interval",
291
+ required=False,
292
+ type=int,
293
+ default=default_obj.checkpoint_interval,
294
+ help="If set, a checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
295
+ )
296
+ parser.add_argument(
297
+ "--random-seed",
298
+ default=default_obj.random_seed,
299
+ type=int,
300
+ help="Random seed for reproducibility.",
301
+ )
302
+ parser.add_argument(
303
+ "--parallel",
304
+ action="store_true",
305
+ default=default_obj.parallel,
306
+ help="Run attack using multiple GPUs.",
307
+ )
308
+ parser.add_argument(
309
+ "--num-workers-per-device",
310
+ default=default_obj.num_workers_per_device,
311
+ type=int,
312
+ help="Number of worker processes to run per device.",
313
+ )
314
+ parser.add_argument(
315
+ "--log-to-txt",
316
+ nargs="?",
317
+ default=default_obj.log_to_txt,
318
+ const="",
319
+ type=str,
320
+ help="Path to which to save attack logs as a text file. Set this argument if you want to save text logs. "
321
+ "If the last part of the path ends with the `.txt` extension, the path is assumed to be the output file path.",
322
+ )
323
+ parser.add_argument(
324
+ "--log-to-csv",
325
+ nargs="?",
326
+ default=default_obj.log_to_csv,
327
+ const="",
328
+ type=str,
329
+ help="Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs. "
330
+ "If the last part of the path ends with the `.csv` extension, the path is assumed to be the output file path.",
331
+ )
332
+ parser.add_argument(
333
+ "--log-summary-to-json",
334
+ nargs="?",
335
+ default=default_obj.log_summary_to_json,
336
+ const="",
337
+ type=str,
338
+ help="Path to which to save attack summary as a JSON file. Set this argument if you want to save attack results summary in a JSON. "
339
+ "If the last part of the path ends with the `.json` extension, the path is assumed to be the output file path.",
340
+ )
341
+ parser.add_argument(
342
+ "--csv-coloring-style",
343
+ default=default_obj.csv_coloring_style,
344
+ type=str,
345
+ help='Method for choosing how to mark perturbed parts of the text in CSV logs. Options are "file" and "plain". '
346
+ '"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".',
347
+ )
348
+ parser.add_argument(
349
+ "--log-to-visdom",
350
+ nargs="?",
351
+ default=None,
352
+ const='{"env": "main", "port": 8097, "hostname": "localhost"}',
353
+ type=json.loads,
354
+ help="Set this argument if you want to log attacks to Visdom. The dictionary should have the following "
355
+ 'three keys and their corresponding values: `"env", "port", "hostname"`. '
356
+ 'Example for command line use: `--log-to-visdom {"env": "main", "port": 8097, "hostname": "localhost"}`.',
357
+ )
358
+ parser.add_argument(
359
+ "--log-to-wandb",
360
+ nargs="?",
361
+ default=None,
362
+ const='{"project": "textattack"}',
363
+ type=json.loads,
364
+ help="Set this argument if you want to log attacks to WandB. The dictionary should have the following "
365
+ 'key and its corresponding value: `"project"`. '
366
+ 'Example for command line use: `--log-to-wandb {"project": "textattack"}`.',
367
+ )
368
+ parser.add_argument(
369
+ "--disable-stdout",
370
+ action="store_true",
371
+ default=default_obj.disable_stdout,
372
+ help="Disable logging attack results to stdout",
373
+ )
374
+ parser.add_argument(
375
+ "--silent",
376
+ action="store_true",
377
+ default=default_obj.silent,
378
+ help="Disable all logging",
379
+ )
380
+ parser.add_argument(
381
+ "--enable-advance-metrics",
382
+ action="store_true",
383
+ default=default_obj.enable_advance_metrics,
384
+ help="Enable calculation and display of optional advanced post-hoc metrics like perplexity, USE distance, etc.",
385
+ )
386
+
387
+ return parser
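For illustration only, the flags registered above compose on the command line roughly as ``textattack attack --recipe textfooler --model bert-base-uncased-sst2 --num-examples 20 --random-seed 765 --log-to-csv results.csv --enable-advance-metrics``; the recipe and model names here are placeholders, and ``--recipe``/``--model`` are contributed by the other argument classes merged into ``CommandLineAttackArgs`` further below.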
388
+
389
+ @classmethod
390
+ def create_loggers_from_args(cls, args):
391
+ """Creates AttackLogManager from an AttackArgs object."""
392
+ assert isinstance(
393
+ args, cls
394
+ ), f"Expected args to be of type `{cls}`, but got type `{type(args)}`."
395
+
396
+ # Create logger
397
+ attack_log_manager = textattack.loggers.AttackLogManager(args.metrics)
398
+
399
+ # Get current time for file naming
400
+ timestamp = time.strftime("%Y-%m-%d-%H-%M")
401
+
402
+ # if '--log-to-txt' specified with arguments
403
+ if args.log_to_txt is not None:
404
+ if args.log_to_txt.lower().endswith(".txt"):
405
+ txt_file_path = args.log_to_txt
406
+ else:
407
+ txt_file_path = os.path.join(args.log_to_txt, f"{timestamp}-log.txt")
408
+
409
+ dir_path = os.path.dirname(txt_file_path)
410
+ dir_path = dir_path if dir_path else "."
411
+ if not os.path.exists(dir_path):
412
+ os.makedirs(dir_path)
413
+
414
+ color_method = "file"
415
+ attack_log_manager.add_output_file(txt_file_path, color_method)
416
+
417
+ # if '--log-to-csv' specified with arguments
418
+ if args.log_to_csv is not None:
419
+ if args.log_to_csv.lower().endswith(".csv"):
420
+ csv_file_path = args.log_to_csv
421
+ else:
422
+ csv_file_path = os.path.join(args.log_to_csv, f"{timestamp}-log.csv")
423
+
424
+ dir_path = os.path.dirname(csv_file_path)
425
+ dir_path = dir_path if dir_path else "."
426
+ if not os.path.exists(dir_path):
427
+ os.makedirs(dir_path)
428
+
429
+ color_method = (
430
+ None if args.csv_coloring_style == "plain" else args.csv_coloring_style
431
+ )
432
+ attack_log_manager.add_output_csv(csv_file_path, color_method)
433
+
434
+ # if '--log-summary-to-json' specified with arguments
435
+ if args.log_summary_to_json is not None:
436
+ if args.log_summary_to_json.lower().endswith(".json"):
437
+ summary_json_file_path = args.log_summary_to_json
438
+ else:
439
+ summary_json_file_path = os.path.join(
440
+ args.log_summary_to_json, f"{timestamp}-attack_summary_log.json"
441
+ )
442
+
443
+ dir_path = os.path.dirname(summary_json_file_path)
444
+ dir_path = dir_path if dir_path else "."
445
+ if not os.path.exists(dir_path):
446
+ os.makedirs(dir_path)
447
+
448
+ attack_log_manager.add_output_summary_json(summary_json_file_path)
449
+
450
+ # Visdom
451
+ if args.log_to_visdom is not None:
452
+ attack_log_manager.enable_visdom(**args.log_to_visdom)
453
+
454
+ # Weights & Biases
455
+ if args.log_to_wandb is not None:
456
+ attack_log_manager.enable_wandb(**args.log_to_wandb)
457
+
458
+ # Stdout
459
+ if not args.disable_stdout and not sys.stdout.isatty():
460
+ attack_log_manager.disable_color()
461
+ elif not args.disable_stdout:
462
+ attack_log_manager.enable_stdout()
463
+
464
+ return attack_log_manager
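A minimal sketch of how these arguments are typically consumed programmatically, assuming the top-level ``textattack.AttackArgs``/``textattack.Attacker`` API; the victim model and dataset names are placeholders:

    import textattack
    import transformers

    # Placeholder victim model; any sequence classifier can be wrapped the same way.
    name = "textattack/bert-base-uncased-SST-2"
    model = transformers.AutoModelForSequenceClassification.from_pretrained(name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    attack = textattack.attack_recipes.TextFoolerJin2019.build(wrapper)
    dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="validation")
    args = textattack.AttackArgs(num_examples=20, log_to_csv="results.csv", random_seed=765)

    # The Attacker is expected to build its loggers via create_loggers_from_args (above).
    textattack.Attacker(attack, dataset, args).attack_dataset()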
465
+
466
+
467
+ @dataclass
468
+ class _CommandLineAttackArgs:
469
+ """Attack args for command line execution. These require additional arguments to
470
+ create the ``Attack`` object as specified.
471
+
472
+ Args:
473
+ transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
474
+ Name of transformation to use.
475
+ constraints (:obj:`list[str]`, `optional`, defaults to :obj:`["repeat", "stopword"]`):
476
+ List of names of constraints to use.
477
+ goal_function (:obj:`str`, `optional`, defaults to :obj:`"untargeted-classification"`):
478
+ Name of goal function to use.
479
+ search_method (:obj:`str`, `optional`, defaults to :obj:`"greedy-word-wir"`):
480
+ Name of search method to use.
481
+ attack_recipe (:obj:`str`, `optional`, defaults to :obj:`None`):
482
+ Name of attack recipe to use.
483
+ .. note::
484
+ Setting this overrides any previous selection of transformation, constraints, goal function, and search method.
485
+ attack_from_file (:obj:`str`, `optional`, defaults to :obj:`None`):
486
+ Path of the `.py` file from which to load the attack. Use `<path>^<variable_name>` to specify which variable to import from the file.
487
+ .. note::
488
+ If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method.
489
+ interactive (:obj:`bool`, `optional`, defaults to :obj:`False`):
490
+ If `True`, run the attack in interactive mode.
491
+ parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
492
+ If `True`, attack in parallel.
493
+ model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
494
+ The batch size for making queries to the victim model.
495
+ model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
496
+ The maximum number of items to keep in the model results cache at once.
497
+ constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
498
+ The maximum number of items to keep in the constraints cache at once.
499
+ """
500
+
501
+ transformation: str = "word-swap-embedding"
502
+ constraints: list = field(default_factory=lambda: ["repeat", "stopword"])
503
+ goal_function: str = "untargeted-classification"
504
+ search_method: str = "greedy-word-wir"
505
+ attack_recipe: str = None
506
+ attack_from_file: str = None
507
+ interactive: bool = False
508
+ parallel: bool = False
509
+ model_batch_size: int = 32
510
+ model_cache_size: int = 2**18
511
+ constraint_cache_size: int = 2**18
512
+
513
+ @classmethod
514
+ def _add_parser_args(cls, parser):
515
+ """Add listed args to command line parser."""
516
+ default_obj = cls()
517
+ transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
518
+ WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
519
+ )
520
+ parser.add_argument(
521
+ "--transformation",
522
+ type=str,
523
+ required=False,
524
+ default=default_obj.transformation,
525
+ help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
526
+ + str(transformation_names),
527
+ )
528
+ parser.add_argument(
529
+ "--constraints",
530
+ type=str,
531
+ required=False,
532
+ nargs="*",
533
+ default=default_obj.constraints,
534
+ help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
535
+ + str(CONSTRAINT_CLASS_NAMES.keys()),
536
+ )
537
+ goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
538
+ parser.add_argument(
539
+ "--goal-function",
540
+ "-g",
541
+ default=default_obj.goal_function,
542
+ help=f"The goal function to use. Choices: {goal_function_choices}",
543
+ )
544
+ attack_group = parser.add_mutually_exclusive_group(required=False)
545
+ search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
546
+ attack_group.add_argument(
547
+ "--search-method",
548
+ "--search",
549
+ "-s",
550
+ type=str,
551
+ required=False,
552
+ default=default_obj.search_method,
553
+ help=f"The search method to use. Choices: {search_choices}",
554
+ )
555
+ attack_group.add_argument(
556
+ "--attack-recipe",
557
+ "--recipe",
558
+ "-r",
559
+ type=str,
560
+ required=False,
561
+ default=default_obj.attack_recipe,
562
+ help="Full attack recipe (overrides the provided goal function, transformation, constraints, and search method).",
563
+ choices=ATTACK_RECIPE_NAMES.keys(),
564
+ )
565
+ attack_group.add_argument(
566
+ "--attack-from-file",
567
+ type=str,
568
+ required=False,
569
+ default=default_obj.attack_from_file,
570
+ help="Path of the `.py` file from which to load the attack. Use `<path>^<variable_name>` to specify which variable to import from the file.",
571
+ )
572
+ parser.add_argument(
573
+ "--interactive",
574
+ action="store_true",
575
+ default=default_obj.interactive,
576
+ help="Whether to run attacks interactively.",
577
+ )
578
+ parser.add_argument(
579
+ "--model-batch-size",
580
+ type=int,
581
+ default=default_obj.model_batch_size,
582
+ help="The batch size for making calls to the model.",
583
+ )
584
+ parser.add_argument(
585
+ "--model-cache-size",
586
+ type=int,
587
+ default=default_obj.model_cache_size,
588
+ help="The maximum number of items to keep in the model results cache at once.",
589
+ )
590
+ parser.add_argument(
591
+ "--constraint-cache-size",
592
+ type=int,
593
+ default=default_obj.constraint_cache_size,
594
+ help="The maximum number of items to keep in the constraints cache at once.",
595
+ )
596
+
597
+ return parser
598
+
599
+ @classmethod
600
+ def _create_transformation_from_args(cls, args, model_wrapper):
601
+ """Create `Transformation` based on provided `args` and
602
+ `model_wrapper`."""
603
+
604
+ transformation_name = args.transformation
605
+ if ARGS_SPLIT_TOKEN in transformation_name:
606
+ transformation_name, params = transformation_name.split(ARGS_SPLIT_TOKEN)
607
+
608
+ if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
609
+ transformation = eval(
610
+ f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model, {params})"
611
+ )
612
+ elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
613
+ transformation = eval(
614
+ f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}({params})"
615
+ )
616
+ else:
617
+ raise ValueError(
618
+ f"Error: unsupported transformation {transformation_name}"
619
+ )
620
+ else:
621
+ if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
622
+ transformation = eval(
623
+ f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model)"
624
+ )
625
+ elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
626
+ transformation = eval(
627
+ f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}()"
628
+ )
629
+ else:
630
+ raise ValueError(
631
+ f"Error: unsupported transformation {transformation_name}"
632
+ )
633
+ return transformation
634
+
635
+ @classmethod
636
+ def _create_goal_function_from_args(cls, args, model_wrapper):
637
+ """Create `GoalFunction` based on provided `args` and
638
+ `model_wrapper`."""
639
+
640
+ goal_function = args.goal_function
641
+ if ARGS_SPLIT_TOKEN in goal_function:
642
+ goal_function_name, params = goal_function.split(ARGS_SPLIT_TOKEN)
643
+ if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
644
+ raise ValueError(
645
+ f"Error: unsupported goal_function {goal_function_name}"
646
+ )
647
+ goal_function = eval(
648
+ f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model_wrapper, {params})"
649
+ )
650
+ elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
651
+ goal_function = eval(
652
+ f"{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model_wrapper)"
653
+ )
654
+ else:
655
+ raise ValueError(f"Error: unsupported goal_function {goal_function}")
656
+ if args.query_budget:
657
+ goal_function.query_budget = args.query_budget
658
+ goal_function.model_cache_size = args.model_cache_size
659
+ goal_function.batch_size = args.model_batch_size
660
+ return goal_function
661
+
662
+ @classmethod
663
+ def _create_constraints_from_args(cls, args):
664
+ """Create list of `Constraints` based on provided `args`."""
665
+
666
+ if not args.constraints:
667
+ return []
668
+
669
+ _constraints = []
670
+ for constraint in args.constraints:
671
+ if ARGS_SPLIT_TOKEN in constraint:
672
+ constraint_name, params = constraint.split(ARGS_SPLIT_TOKEN)
673
+ if constraint_name not in CONSTRAINT_CLASS_NAMES:
674
+ raise ValueError(f"Error: unsupported constraint {constraint_name}")
675
+ _constraints.append(
676
+ eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
677
+ )
678
+ elif constraint in CONSTRAINT_CLASS_NAMES:
679
+ _constraints.append(eval(f"{CONSTRAINT_CLASS_NAMES[constraint]}()"))
680
+ else:
681
+ raise ValueError(f"Error: unsupported constraint {constraint}")
682
+
683
+ return _constraints
684
+
685
+ @classmethod
686
+ def _create_attack_from_args(cls, args, model_wrapper):
687
+ """Given ``CommandLineArgs`` and ``ModelWrapper``, return specified
688
+ ``Attack`` object."""
689
+
690
+ assert isinstance(
691
+ args, cls
692
+ ), f"Expected args to be of type `{cls}`, but got type `{type(args)}`."
693
+
694
+ if args.attack_recipe:
695
+ if ARGS_SPLIT_TOKEN in args.attack_recipe:
696
+ recipe_name, params = args.attack_recipe.split(ARGS_SPLIT_TOKEN)
697
+ if recipe_name not in ATTACK_RECIPE_NAMES:
698
+ raise ValueError(f"Error: unsupported recipe {recipe_name}")
699
+ recipe = eval(
700
+ f"{ATTACK_RECIPE_NAMES[recipe_name]}.build(model_wrapper, {params})"
701
+ )
702
+ elif args.attack_recipe in ATTACK_RECIPE_NAMES:
703
+ recipe = eval(
704
+ f"{ATTACK_RECIPE_NAMES[args.attack_recipe]}.build(model_wrapper)"
705
+ )
706
+ else:
707
+ raise ValueError(f"Invalid recipe {args.attack_recipe}")
708
+ if args.query_budget:
709
+ recipe.goal_function.query_budget = args.query_budget
710
+ recipe.goal_function.model_cache_size = args.model_cache_size
711
+ recipe.constraint_cache_size = args.constraint_cache_size
712
+ return recipe
713
+ elif args.attack_from_file:
714
+ if ARGS_SPLIT_TOKEN in args.attack_from_file:
715
+ attack_file, attack_name = args.attack_from_file.split(ARGS_SPLIT_TOKEN)
716
+ else:
717
+ attack_file, attack_name = args.attack_from_file, "attack"
718
+ attack_module = load_module_from_file(attack_file)
719
+ if not hasattr(attack_module, attack_name):
720
+ raise ValueError(
721
+ f"Loaded `{attack_file}` but could not find `{attack_name}`."
722
+ )
723
+ attack_func = getattr(attack_module, attack_name)
724
+ return attack_func(model_wrapper)
725
+ else:
726
+ goal_function = cls._create_goal_function_from_args(args, model_wrapper)
727
+ transformation = cls._create_transformation_from_args(args, model_wrapper)
728
+ constraints = cls._create_constraints_from_args(args)
729
+ if ARGS_SPLIT_TOKEN in args.search_method:
730
+ search_name, params = args.search_method.split(ARGS_SPLIT_TOKEN)
731
+ if search_name not in SEARCH_METHOD_CLASS_NAMES:
732
+ raise ValueError(f"Error: unsupported search {search_name}")
733
+ search_method = eval(
734
+ f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})"
735
+ )
736
+ elif args.search_method in SEARCH_METHOD_CLASS_NAMES:
737
+ search_method = eval(
738
+ f"{SEARCH_METHOD_CLASS_NAMES[args.search_method]}()"
739
+ )
740
+ else:
741
+ raise ValueError(f"Error: unsupported search method {args.search_method}")
742
+
743
+ return Attack(
744
+ goal_function,
745
+ constraints,
746
+ transformation,
747
+ search_method,
748
+ constraint_cache_size=args.constraint_cache_size,
749
+ )
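Since the attribute loaded for ``--attack-from-file`` is simply called with the model wrapper (``attack_func(model_wrapper)`` above), a hypothetical file usable as ``--attack-from-file my_attack.py^build_attack`` could look like this (the default attribute name is ``attack``):

    # my_attack.py (hypothetical)
    from textattack.attack_recipes import TextFoolerJin2019

    def build_attack(model_wrapper):
        # Any callable that returns a textattack.Attack works here.
        return TextFoolerJin2019.build(model_wrapper)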
750
+
751
+
752
+ # This neat trick allows us to reorder the arguments to avoid TypeErrors commonly found when inheriting dataclasses.
753
+ # https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
754
+ @dataclass
755
+ class CommandLineAttackArgs(AttackArgs, _CommandLineAttackArgs, DatasetArgs, ModelArgs):
756
+ @classmethod
757
+ def _add_parser_args(cls, parser):
758
+ """Add listed args to command line parser."""
759
+ parser = ModelArgs._add_parser_args(parser)
760
+ parser = DatasetArgs._add_parser_args(parser)
761
+ parser = _CommandLineAttackArgs._add_parser_args(parser)
762
+ parser = AttackArgs._add_parser_args(parser)
763
+ return parser
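The base-class order above matters because dataclass fields are collected in reverse MRO order; a small self-contained sketch of the same trick, with purely illustrative names:

    from dataclasses import dataclass

    @dataclass
    class ModelOpts:
        model: str = "bert"        # defaulted field from one parent

    @dataclass
    class RunOpts:
        num_examples: int = 10     # defaulted field from another parent

    @dataclass
    class CliOpts(RunOpts, ModelOpts):
        # Fields are merged in reverse MRO order (ModelOpts first, then RunOpts),
        # which is how careful ordering avoids "non-default argument follows
        # default argument" errors when parents mix defaulted and non-defaulted fields.
        pass

    print(CliOpts())  # CliOpts(model='bert', num_examples=10)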
textattack/attack_recipes/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ """.. _attack_recipes:
2
+
3
+ Attack Recipes Package:
4
+ ========================
5
+
6
+ We provide a number of pre-built attack recipes, which correspond to attacks from the literature. To run an attack recipe from the command line, run::
7
+
8
+ textattack attack --recipe [recipe_name]
9
+
10
+ To initialize an attack in a Python script, use::
11
+
12
+ <recipe name>.build(model_wrapper)
13
+
14
+ For example, ``attack = InputReductionFeng2018.build(model)`` creates `attack`, an object of type ``Attack`` with the goal function, transformation, constraints, and search method specified in that paper. This object can then be used just like any other attack; for example, by calling ``attack.attack_dataset``.
15
+
16
+ TextAttack supports the following attack recipes (each recipe's documentation contains a link to the corresponding paper):
17
+
18
+ .. contents:: :local:
19
+ """
20
+
21
+ from .attack_recipe import AttackRecipe
22
+
23
+ from .a2t_yoo_2021 import A2TYoo2021
24
+ from .bae_garg_2019 import BAEGarg2019
25
+ from .bert_attack_li_2020 import BERTAttackLi2020
26
+ from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018
27
+ from .faster_genetic_algorithm_jia_2019 import FasterGeneticAlgorithmJia2019
28
+ from .deepwordbug_gao_2018 import DeepWordBugGao2018
29
+ from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
30
+ from .input_reduction_feng_2018 import InputReductionFeng2018
31
+ from .kuleshov_2017 import Kuleshov2017
32
+ from .morpheus_tan_2020 import MorpheusTan2020
33
+ from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
34
+ from .textbugger_li_2018 import TextBuggerLi2018
35
+ from .textfooler_jin_2019 import TextFoolerJin2019
36
+ from .pwws_ren_2019 import PWWSRen2019
37
+ from .iga_wang_2019 import IGAWang2019
38
+ from .pruthi_2019 import Pruthi2019
39
+ from .pso_zang_2020 import PSOZang2020
40
+ from .checklist_ribeiro_2020 import CheckList2020
41
+ from .clare_li_2020 import CLARE2020
42
+ from .french_recipe import FrenchRecipe
43
+ from .spanish_recipe import SpanishRecipe
textattack/attack_recipes/a2t_yoo_2021.py ADDED
@@ -0,0 +1,74 @@
1
+ """
2
+ A2T (Attack for Adversarial Training) Recipe
3
+ ==================================================
4
+
5
+ """
6
+
7
+ from textattack import Attack
8
+ from textattack.constraints.grammaticality import PartOfSpeech
9
+ from textattack.constraints.pre_transformation import (
10
+ InputColumnModification,
11
+ MaxModificationRate,
12
+ RepeatModification,
13
+ StopwordModification,
14
+ )
15
+ from textattack.constraints.semantics import WordEmbeddingDistance
16
+ from textattack.constraints.semantics.sentence_encoders import BERT
17
+ from textattack.goal_functions import UntargetedClassification
18
+ from textattack.search_methods import GreedyWordSwapWIR
19
+ from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
20
+
21
+ from .attack_recipe import AttackRecipe
22
+
23
+
24
+ class A2TYoo2021(AttackRecipe):
25
+ """Towards Improving Adversarial Training of NLP Models.
26
+
27
+ (Yoo et al., 2021)
28
+
29
+ https://arxiv.org/abs/2109.00544
30
+ """
31
+
32
+ @staticmethod
33
+ def build(model_wrapper, mlm=False):
34
+ """Build attack recipe.
35
+
36
+ Args:
37
+ model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
38
+ Model wrapper containing both the model and the tokenizer.
39
+ mlm (:obj:`bool`, `optional`, defaults to :obj:`False`):
40
+ If :obj:`True`, load `A2T-MLM` attack. Otherwise, load regular `A2T` attack.
41
+
42
+ Returns:
43
+ :class:`~textattack.Attack`: A2T attack.
44
+ """
45
+ constraints = [RepeatModification(), StopwordModification()]
46
+ input_column_modification = InputColumnModification(
47
+ ["premise", "hypothesis"], {"premise"}
48
+ )
49
+ constraints.append(input_column_modification)
50
+ constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
51
+ constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
52
+ sent_encoder = BERT(
53
+ model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
54
+ )
55
+ constraints.append(sent_encoder)
56
+
57
+ if mlm:
58
+ transformation = WordSwapMaskedLM(
59
+ method="bae", max_candidates=20, min_confidence=0.0, batch_size=16
60
+ )
61
+ else:
62
+ transformation = WordSwapEmbedding(max_candidates=20)
63
+ constraints.append(WordEmbeddingDistance(min_cos_sim=0.8))
64
+
65
+ #
66
+ # Goal is untargeted classification
67
+ #
68
+ goal_function = UntargetedClassification(model_wrapper, model_batch_size=32)
69
+ #
70
+ # Greedily swap words with "Word Importance Ranking".
71
+ #
72
+ search_method = GreedyWordSwapWIR(wir_method="gradient")
73
+
74
+ return Attack(goal_function, constraints, transformation, search_method)
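A hedged usage sketch for this recipe (the model name is a placeholder; for single-input tasks the premise/hypothesis column constraint above should simply have no effect):

    import transformers
    import textattack
    from textattack.attack_recipes import A2TYoo2021

    name = "textattack/bert-base-uncased-SST-2"  # placeholder victim model
    model = transformers.AutoModelForSequenceClassification.from_pretrained(name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    attack = A2TYoo2021.build(wrapper)                # embedding-swap variant
    attack_mlm = A2TYoo2021.build(wrapper, mlm=True)  # A2T-MLM variant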
textattack/attack_recipes/attack_recipe.py ADDED
@@ -0,0 +1,30 @@
1
+ """
2
+ Attack Recipe Class
3
+ ========================
4
+
5
+ """
6
+
7
+ from abc import ABC, abstractmethod
8
+
9
+ from textattack import Attack
10
+
11
+
12
+ class AttackRecipe(Attack, ABC):
13
+ """A recipe for building an NLP adversarial attack from the literature."""
14
+
15
+ @staticmethod
16
+ @abstractmethod
17
+ def build(model_wrapper, **kwargs):
18
+ """Creates a pre-built :class:`~textattack.Attack` that corresponds to an
19
+ attack from the literature.
20
+
21
+ Args:
22
+ model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
23
+ :class:`~textattack.models.wrappers.ModelWrapper` that contains the victim model and tokenizer.
24
+ This is passed to :class:`~textattack.goal_functions.GoalFunction` when constructing the attack.
25
+ kwargs:
26
+ Additional keyword arguments.
27
+ Returns:
28
+ :class:`~textattack.Attack`
29
+ """
30
+ raise NotImplementedError()
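As a rough sketch, a subclass only needs to implement ``build`` and return an ``Attack``; the component choices here are illustrative, not prescribed by this base class:

    from textattack import Attack
    from textattack.attack_recipes import AttackRecipe
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.goal_functions import UntargetedClassification
    from textattack.search_methods import GreedyWordSwapWIR
    from textattack.transformations import WordSwapEmbedding

    class MyRecipe(AttackRecipe):
        """Hypothetical minimal recipe: greedy embedding swaps, untargeted goal."""

        @staticmethod
        def build(model_wrapper, **kwargs):
            goal_function = UntargetedClassification(model_wrapper)
            constraints = [RepeatModification(), StopwordModification()]
            transformation = WordSwapEmbedding(max_candidates=20)
            search_method = GreedyWordSwapWIR(wir_method="delete")
            return Attack(goal_function, constraints, transformation, search_method)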
textattack/attack_recipes/bae_garg_2019.py ADDED
@@ -0,0 +1,123 @@
1
+ """
2
+ BAE (BAE: BERT-Based Adversarial Examples)
3
+ ============================================
4
+
5
+ """
6
+ from textattack.constraints.grammaticality import PartOfSpeech
7
+ from textattack.constraints.pre_transformation import (
8
+ RepeatModification,
9
+ StopwordModification,
10
+ )
11
+ from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
12
+ from textattack.goal_functions import UntargetedClassification
13
+ from textattack.search_methods import GreedyWordSwapWIR
14
+ from textattack.transformations import WordSwapMaskedLM
15
+
16
+ from .attack_recipe import AttackRecipe
17
+
18
+
19
+ class BAEGarg2019(AttackRecipe):
20
+ """Siddhant Garg and Goutham Ramakrishnan, 2019.
21
+
22
+ BAE: BERT-based Adversarial Examples for Text Classification.
23
+
24
+ https://arxiv.org/pdf/2004.01970
25
+
26
+ This is "attack mode" 1 from the paper, BAE-R, word replacement.
27
+
28
+ We present 4 attack modes for BAE based on the
29
+ R and I operations, where for each token t in S:
30
+ • BAE-R: Replace token t (See Algorithm 1)
31
+ • BAE-I: Insert a token to the left or right of t
32
+ • BAE-R/I: Either replace token t or insert a
33
+ token to the left or right of t
34
+ • BAE-R+I: First replace token t, then insert a
35
+ token to the left or right of t
36
+ """
37
+
38
+ @staticmethod
39
+ def build(model_wrapper):
40
+ # "In this paper, we present a simple yet novel technique: BAE (BERT-based
41
+ # Adversarial Examples), which uses a language model (LM) for token
42
+ # replacement to best fit the overall context. We perturb an input sentence
43
+ # by either replacing a token or inserting a new token in the sentence, by
44
+ # means of masking a part of the input and using a LM to fill in the mask."
45
+ #
46
+ # We only consider the top K=50 synonyms from the MLM predictions.
47
+ #
48
+ # [from email correspondence with the author]
49
+ # "When choosing the top-K candidates from the BERT masked LM, we filter out
50
+ # the sub-words and only retain the whole words (by checking if they are
51
+ # present in the GloVE vocabulary)"
52
+ #
53
+ transformation = WordSwapMaskedLM(
54
+ method="bae", max_candidates=50, min_confidence=0.0
55
+ )
56
+ #
57
+ # Don't modify the same word twice or stopwords.
58
+ #
59
+ constraints = [RepeatModification(), StopwordModification()]
60
+
61
+ # For the R operations we add an additional check for
62
+ # grammatical correctness of the generated adversarial example by filtering
63
+ # out predicted tokens that do not form the same part of speech (POS) as the
64
+ # original token t_i in the sentence.
65
+ constraints.append(PartOfSpeech(allow_verb_noun_swap=True))
66
+
67
+ # "To ensure semantic similarity on introducing perturbations in the input
68
+ # text, we filter the set of top-K masked tokens (K is a pre-defined
69
+ # constant) predicted by BERT-MLM using a Universal Sentence Encoder (USE)
70
+ # (Cer et al., 2018)-based sentence similarity scorer."
71
+ #
72
+ # "[We] set a threshold of 0.8 for the cosine similarity between USE-based
73
+ # embeddings of the adversarial and input text."
74
+ #
75
+ # [from email correspondence with the author]
76
+ # "For a fair comparison of the benefits of using a BERT-MLM in our paper,
77
+ # we retained the majority of TextFooler's specifications. Thus we:
78
+ # 1. Use the USE for comparison within a window of size 15 around the word
79
+ # being replaced/inserted.
80
+ # 2. Set the similarity score threshold to 0.1 for inputs shorter than the
81
+ # window size (this translates roughly to almost always accepting the new text).
82
+ # 3. Perform the USE similarity thresholding of 0.8 with respect to the text
83
+ # just before the replacement/insertion and not the original text (For
84
+ # example: at the 3rd R/I operation, we compute the USE score on a window
85
+ # of size 15 of the text obtained after the first 2 R/I operations and not
86
+ # the original text).
87
+ # ...
88
+ # To address point (3) from above, compare the USE with the original text
89
+ # at each iteration instead of the current one (While doing this change
90
+ # for the R-operation is trivial, doing it for the I-operation with the
91
+ # window based USE comparison might be more involved)."
92
+ #
93
+ # Finally, since the BAE code is based on the TextFooler code, we need to
94
+ # adjust the threshold to account for the missing / pi in the cosine
95
+ # similarity comparison. So the final threshold is 1 - (1 - 0.8) / pi
96
+ # = 1 - (0.2 / pi) = 0.936338023.
97
+ use_constraint = UniversalSentenceEncoder(
98
+ threshold=0.936338023,
99
+ metric="cosine",
100
+ compare_against_original=True,
101
+ window_size=15,
102
+ skip_text_shorter_than_window=True,
103
+ )
104
+ constraints.append(use_constraint)
105
+ #
106
+ # Goal is untargeted classification.
107
+ #
108
+ goal_function = UntargetedClassification(model_wrapper)
109
+ #
110
+ # "We estimate the token importance Ii of each token
111
+ # t_i ∈ S = [t1, . . . , tn], by deleting ti from S and computing the
112
+ # decrease in probability of predicting the correct label y, similar
113
+ # to (Jin et al., 2019).
114
+ #
115
+ # • "If there are multiple tokens can cause C to misclassify S when they
116
+ # replace the mask, we choose the token which makes Sadv most similar to
117
+ # the original S based on the USE score."
118
+ # • "If no token causes misclassification, we choose the perturbation that
119
+ # decreases the prediction probability P(C(Sadv)=y) the most."
120
+ #
121
+ search_method = GreedyWordSwapWIR(wir_method="delete")
122
+
123
+ return BAEGarg2019(goal_function, constraints, transformation, search_method)
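The adjusted USE threshold in the comments above follows from 1 - (1 - 0.8)/pi; a quick check:

    import math
    print(1 - (1 - 0.8) / math.pi)  # 0.9363380227632419, i.e. ~0.936338023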