anonymous8 commited on
Commit
d65ddc0
·
1 Parent(s): ecdc8b8
Files changed (35) hide show
  1. anonymous_demo/__init__.py +2 -2
  2. anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py +47 -42
  3. anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py +12 -9
  4. anonymous_demo/core/tad/prediction/tad_classifier.py +305 -177
  5. anonymous_demo/functional/checkpoint/checkpoint_manager.py +4 -5
  6. anonymous_demo/functional/config/config_manager.py +10 -12
  7. anonymous_demo/functional/config/tad_config_manager.py +132 -124
  8. anonymous_demo/functional/dataset/__init__.py +1 -1
  9. anonymous_demo/functional/dataset/dataset_manager.py +30 -6
  10. anonymous_demo/network/lcf_pooler.py +4 -2
  11. anonymous_demo/network/lsa.py +34 -13
  12. anonymous_demo/network/sa_encoder.py +57 -17
  13. anonymous_demo/utils/demo_utils.py +86 -48
  14. anonymous_demo/utils/logger.py +5 -5
  15. app.py +31 -20
  16. requirements.txt +1 -1
  17. textattack/attack_recipes/morpheus_tan_2020.py +0 -1
  18. textattack/attack_recipes/seq2sick_cheng_2018_blackbox.py +0 -1
  19. textattack/attacker.py +7 -5
  20. textattack/commands/augment_command.py +0 -1
  21. textattack/commands/eval_model_command.py +1 -1
  22. textattack/constraints/overlap/max_words_perturbed.py +0 -1
  23. textattack/goal_function_results/classification_goal_function_result.py +0 -1
  24. textattack/goal_function_results/text_to_text_goal_function_result.py +0 -1
  25. textattack/loggers/weights_and_biases_logger.py +0 -1
  26. textattack/metrics/quality_metrics/perplexity.py +0 -1
  27. textattack/models/wrappers/demo_model_wrapper.py +6 -6
  28. textattack/reactive_defense/reactive_defender.py +0 -1
  29. textattack/reactive_defense/tad_reactive_defender.py +12 -9
  30. textattack/search_methods/greedy_word_swap_wir.py +0 -1
  31. textattack/shared/validators.py +4 -1
  32. textattack/trainer.py +2 -1
  33. textattack/training_args.py +0 -1
  34. textattack/transformations/word_swaps/word_swap_change_name.py +0 -1
  35. textattack/transformations/word_swaps/word_swap_change_number.py +1 -1
anonymous_demo/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- __version__ = '1.0.0'
2
 
3
- __name__ = 'anonymous_demo'
4
 
5
  from anonymous_demo.functional import TADCheckpointManager
 
1
+ __version__ = "1.0.0"
2
 
3
+ __name__ = "anonymous_demo"
4
 
5
  from anonymous_demo.functional import TADCheckpointManager
anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py CHANGED
@@ -6,26 +6,30 @@ from transformers import AutoTokenizer
6
 
7
  class Tokenizer4Pretraining:
8
  def __init__(self, max_seq_len, opt, **kwargs):
9
- if kwargs.pop('offline', False):
10
- self.tokenizer = AutoTokenizer.from_pretrained(find_cwd_dir(opt.pretrained_bert.split('/')[-1]),
11
- do_lower_case='uncased' in opt.pretrained_bert)
 
 
12
  else:
13
- self.tokenizer = AutoTokenizer.from_pretrained(opt.pretrained_bert,
14
- do_lower_case='uncased' in opt.pretrained_bert)
 
15
  self.max_seq_len = max_seq_len
16
 
17
- def text_to_sequence(self, text, reverse=False, padding='post', truncating='post'):
18
-
19
- return self.tokenizer.encode(text, truncation=True, padding='max_length', max_length=self.max_seq_len,
20
- return_tensors='pt')
 
 
 
 
21
 
22
 
23
  class BERTTADDataset(Dataset):
24
-
25
  def __init__(self, tokenizer, opt):
26
- self.bert_baseline_input_colses = {
27
- 'bert': ['text_bert_indices']
28
- }
29
 
30
  self.tokenizer = tokenizer
31
  self.opt = opt
@@ -40,33 +44,39 @@ class BERTTADDataset(Dataset):
40
  def process_data(self, samples, ignore_error=True):
41
  all_data = []
42
  if len(samples) > 100:
43
- it = tqdm.tqdm(samples, postfix='preparing text classification inference dataloader...')
 
 
44
  else:
45
  it = samples
46
  for text in it:
47
  try:
48
  # handle for empty lines in inference datasets
49
- if text is None or '' == text.strip():
50
- raise RuntimeError('Invalid Input!')
51
 
52
- if '!ref!' in text:
53
- text, _, labels = text.strip().partition('!ref!')
54
  text = text.strip()
55
- if labels.count(',') == 2:
56
- label, is_adv, adv_train_label = labels.strip().split(',')
57
- label, is_adv, adv_train_label = label.strip(), is_adv.strip(), adv_train_label.strip()
58
- elif labels.count(',') == 1:
59
- label, is_adv = labels.strip().split(',')
 
 
 
 
60
  label, is_adv = label.strip(), is_adv.strip()
61
- adv_train_label = '-100'
62
- elif labels.count(',') == 0:
63
  label = labels.strip()
64
- adv_train_label = '-100'
65
- is_adv = '-100'
66
  else:
67
- label = '-100'
68
- adv_train_label = '-100'
69
- is_adv = '-100'
70
 
71
  label = int(label)
72
  adv_train_label = int(adv_train_label)
@@ -78,19 +88,14 @@ class BERTTADDataset(Dataset):
78
  adv_train_label = -100
79
  is_adv = -100
80
 
81
- text_indices = self.tokenizer.text_to_sequence('{}'.format(text))
82
 
83
  data = {
84
- 'text_bert_indices': text_indices[0],
85
-
86
- 'text_raw': text,
87
-
88
- 'label': label,
89
-
90
- 'adv_train_label': adv_train_label,
91
-
92
- 'is_adv': is_adv,
93
-
94
  # 'label': self.opt.label_to_index.get(label, -100) if isinstance(label, str) else label,
95
  #
96
  # 'adv_train_label': self.opt.adv_train_label_to_index.get(adv_train_label, -100) if isinstance(adv_train_label, str) else adv_train_label,
@@ -102,7 +107,7 @@ class BERTTADDataset(Dataset):
102
 
103
  except Exception as e:
104
  if ignore_error:
105
- print('Ignore error while processing:', text)
106
  else:
107
  raise e
108
 
 
6
 
7
  class Tokenizer4Pretraining:
8
  def __init__(self, max_seq_len, opt, **kwargs):
9
+ if kwargs.pop("offline", False):
10
+ self.tokenizer = AutoTokenizer.from_pretrained(
11
+ find_cwd_dir(opt.pretrained_bert.split("/")[-1]),
12
+ do_lower_case="uncased" in opt.pretrained_bert,
13
+ )
14
  else:
15
+ self.tokenizer = AutoTokenizer.from_pretrained(
16
+ opt.pretrained_bert, do_lower_case="uncased" in opt.pretrained_bert
17
+ )
18
  self.max_seq_len = max_seq_len
19
 
20
+ def text_to_sequence(self, text, reverse=False, padding="post", truncating="post"):
21
+ return self.tokenizer.encode(
22
+ text,
23
+ truncation=True,
24
+ padding="max_length",
25
+ max_length=self.max_seq_len,
26
+ return_tensors="pt",
27
+ )
28
 
29
 
30
  class BERTTADDataset(Dataset):
 
31
  def __init__(self, tokenizer, opt):
32
+ self.bert_baseline_input_colses = {"bert": ["text_bert_indices"]}
 
 
33
 
34
  self.tokenizer = tokenizer
35
  self.opt = opt
 
44
  def process_data(self, samples, ignore_error=True):
45
  all_data = []
46
  if len(samples) > 100:
47
+ it = tqdm.tqdm(
48
+ samples, postfix="preparing text classification inference dataloader..."
49
+ )
50
  else:
51
  it = samples
52
  for text in it:
53
  try:
54
  # handle for empty lines in inference datasets
55
+ if text is None or "" == text.strip():
56
+ raise RuntimeError("Invalid Input!")
57
 
58
+ if "!ref!" in text:
59
+ text, _, labels = text.strip().partition("!ref!")
60
  text = text.strip()
61
+ if labels.count(",") == 2:
62
+ label, is_adv, adv_train_label = labels.strip().split(",")
63
+ label, is_adv, adv_train_label = (
64
+ label.strip(),
65
+ is_adv.strip(),
66
+ adv_train_label.strip(),
67
+ )
68
+ elif labels.count(",") == 1:
69
+ label, is_adv = labels.strip().split(",")
70
  label, is_adv = label.strip(), is_adv.strip()
71
+ adv_train_label = "-100"
72
+ elif labels.count(",") == 0:
73
  label = labels.strip()
74
+ adv_train_label = "-100"
75
+ is_adv = "-100"
76
  else:
77
+ label = "-100"
78
+ adv_train_label = "-100"
79
+ is_adv = "-100"
80
 
81
  label = int(label)
82
  adv_train_label = int(adv_train_label)
 
88
  adv_train_label = -100
89
  is_adv = -100
90
 
91
+ text_indices = self.tokenizer.text_to_sequence("{}".format(text))
92
 
93
  data = {
94
+ "text_bert_indices": text_indices[0],
95
+ "text_raw": text,
96
+ "label": label,
97
+ "adv_train_label": adv_train_label,
98
+ "is_adv": is_adv,
 
 
 
 
 
99
  # 'label': self.opt.label_to_index.get(label, -100) if isinstance(label, str) else label,
100
  #
101
  # 'adv_train_label': self.opt.adv_train_label_to_index.get(adv_train_label, -100) if isinstance(adv_train_label, str) else adv_train_label,
 
107
 
108
  except Exception as e:
109
  if ignore_error:
110
+ print("Ignore error while processing:", text)
111
  else:
112
  raise e
113
 
anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py CHANGED
@@ -6,7 +6,7 @@ from anonymous_demo.network.sa_encoder import Encoder
6
 
7
 
8
  class TADBERT(nn.Module):
9
- inputs = ['text_bert_indices']
10
 
11
  def __init__(self, bert, opt):
12
  super(TADBERT, self).__init__()
@@ -23,21 +23,24 @@ class TADBERT(nn.Module):
23
 
24
  def forward(self, inputs):
25
  text_raw_indices = inputs[0]
26
- last_hidden_state = self.bert(text_raw_indices)['last_hidden_state']
27
 
28
  sent_logits = self.dense1(self.pooler(last_hidden_state))
29
  advdet_logits = self.dense2(self.pooler(last_hidden_state))
30
  adv_tr_logits = self.dense3(self.pooler(last_hidden_state))
31
 
32
  att_score = torch.nn.functional.normalize(
33
- last_hidden_state.abs().sum(dim=1, keepdim=False) - last_hidden_state.abs().min(dim=1, keepdim=True)[0],
34
- p=1, dim=1)
 
 
 
35
 
36
  outputs = {
37
- 'sent_logits': sent_logits,
38
- 'advdet_logits': advdet_logits,
39
- 'adv_tr_logits': adv_tr_logits,
40
- 'last_hidden_state': last_hidden_state,
41
- 'att_score': att_score
42
  }
43
  return outputs
 
6
 
7
 
8
  class TADBERT(nn.Module):
9
+ inputs = ["text_bert_indices"]
10
 
11
  def __init__(self, bert, opt):
12
  super(TADBERT, self).__init__()
 
23
 
24
  def forward(self, inputs):
25
  text_raw_indices = inputs[0]
26
+ last_hidden_state = self.bert(text_raw_indices)["last_hidden_state"]
27
 
28
  sent_logits = self.dense1(self.pooler(last_hidden_state))
29
  advdet_logits = self.dense2(self.pooler(last_hidden_state))
30
  adv_tr_logits = self.dense3(self.pooler(last_hidden_state))
31
 
32
  att_score = torch.nn.functional.normalize(
33
+ last_hidden_state.abs().sum(dim=1, keepdim=False)
34
+ - last_hidden_state.abs().min(dim=1, keepdim=True)[0],
35
+ p=1,
36
+ dim=1,
37
+ )
38
 
39
  outputs = {
40
+ "sent_logits": sent_logits,
41
+ "advdet_logits": advdet_logits,
42
+ "adv_tr_logits": adv_tr_logits,
43
+ "last_hidden_state": last_hidden_state,
44
+ "att_score": att_score,
45
  }
46
  return outputs
anonymous_demo/core/tad/prediction/tad_classifier.py CHANGED
@@ -9,21 +9,43 @@ from findfile import find_file, find_cwd_dir
9
  from termcolor import colored
10
 
11
  from torch.utils.data import DataLoader
12
- from transformers import AutoTokenizer, AutoModel, AutoConfig, DebertaV2ForMaskedLM, RobertaForMaskedLM, BertForMaskedLM
 
 
 
 
 
 
 
13
 
14
  from ....functional.dataset.dataset_manager import detect_infer_dataset
15
 
16
  from ..models import BERTTADModelList
17
- from ..classic.__bert__.dataset_utils.data_utils_for_inference import BERTTADDataset, Tokenizer4Pretraining
 
 
 
18
 
19
- from ....utils.demo_utils import print_args, TransformerConnectionError, get_device, build_embedding_matrix
 
 
 
 
 
20
 
21
 
22
  def init_attacker(tad_classifier, defense):
23
  try:
24
  from textattack import Attacker
25
- from textattack.attack_recipes import BAEGarg2019, PWWSRen2019, TextFoolerJin2019, PSOZang2020, IGAWang2019, \
26
- GeneticAlgorithmAlzantot2018, DeepWordBugGao2018
 
 
 
 
 
 
 
27
  from textattack.datasets import Dataset
28
  from textattack.models.wrappers import HuggingFaceModelWrapper
29
 
@@ -34,36 +56,36 @@ def init_attacker(tad_classifier, defense):
34
  def __call__(self, text_inputs, **kwargs):
35
  outputs = []
36
  for text_input in text_inputs:
37
- raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
38
- outputs.append(raw_outputs['probs'])
 
 
39
  return outputs
40
 
41
  class SentAttacker:
42
-
43
  def __init__(self, model, recipe_class=BAEGarg2019):
44
  model = model
45
  model_wrapper = DemoModelWrapper(model)
46
 
47
  recipe = recipe_class.build(model_wrapper)
48
 
49
- _dataset = [('', 0)]
50
  _dataset = Dataset(_dataset)
51
 
52
  self.attacker = Attacker(recipe, _dataset)
53
 
54
  attackers = {
55
- 'bae': BAEGarg2019,
56
- 'pwws': PWWSRen2019,
57
- 'textfooler': TextFoolerJin2019,
58
- 'pso': PSOZang2020,
59
- 'iga': IGAWang2019,
60
- 'ga': GeneticAlgorithmAlzantot2018,
61
- 'wordbugger': DeepWordBugGao2018,
62
  }
63
  return SentAttacker(tad_classifier, attackers[defense])
64
  except Exception as e:
65
-
66
- print('Original error:', e)
67
 
68
 
69
  def get_mlm_and_tokenizer(text_classifier, config):
@@ -72,10 +94,10 @@ def get_mlm_and_tokenizer(text_classifier, config):
72
  else:
73
  base_model = text_classifier.bert.base_model
74
  pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)
75
- if 'deberta-v3' in config.pretrained_bert:
76
  MLM = DebertaV2ForMaskedLM(pretrained_config)
77
  MLM.deberta = base_model
78
- elif 'roberta' in config.pretrained_bert:
79
  MLM = RobertaForMaskedLM(pretrained_config)
80
  MLM.roberta = base_model
81
  else:
@@ -86,64 +108,85 @@ def get_mlm_and_tokenizer(text_classifier, config):
86
 
87
  class TADTextClassifier:
88
  def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
89
- '''
90
- from_train_model: load inference model from trained model
91
- '''
92
  self.cal_perplexity = cal_perplexity
93
  # load from a training
94
  if not isinstance(model_arg, str):
95
- print('Load text classifier from training')
96
  self.model = model_arg[0]
97
  self.opt = model_arg[1]
98
  self.tokenizer = model_arg[2]
99
  else:
100
  try:
101
- if 'fine-tuned' in model_arg:
102
  raise ValueError(
103
- 'Do not support to directly load a fine-tuned model, please load a .state_dict or .model instead!')
104
- print('Load text classifier from', model_arg)
105
- state_dict_path = find_file(model_arg, key='.state_dict', exclude_key=['__MACOSX'])
106
- model_path = find_file(model_arg, key='.model', exclude_key=['__MACOSX'])
107
- tokenizer_path = find_file(model_arg, key='.tokenizer', exclude_key=['__MACOSX'])
108
- config_path = find_file(model_arg, key='.config', exclude_key=['__MACOSX'])
109
-
110
- print('config: {}'.format(config_path))
111
- print('state_dict: {}'.format(state_dict_path))
112
- print('model: {}'.format(model_path))
113
- print('tokenizer: {}'.format(tokenizer_path))
114
-
115
- with open(config_path, mode='rb') as f:
 
 
 
 
 
 
 
 
 
116
  self.opt = pickle.load(f)
117
- self.opt.device = get_device(kwargs.pop('auto_device', True))[0]
118
 
119
  if state_dict_path or model_path:
120
  if hasattr(BERTTADModelList, self.opt.model.__name__):
121
  if state_dict_path:
122
- if kwargs.pop('offline', False):
123
  self.bert = AutoModel.from_pretrained(
124
- find_cwd_dir(self.opt.pretrained_bert.split('/')[-1]))
 
 
 
125
  else:
126
- self.bert = AutoModel.from_pretrained(self.opt.pretrained_bert)
 
 
127
  self.model = self.opt.model(self.bert, self.opt)
128
- self.model.load_state_dict(torch.load(state_dict_path, map_location='cpu'))
 
 
129
  elif model_path:
130
- self.model = torch.load(model_path, map_location='cpu')
131
 
132
  try:
133
- self.tokenizer = Tokenizer4Pretraining(max_seq_len=self.opt.max_seq_len, opt=self.opt,
134
- **kwargs)
 
135
  except ValueError:
136
  if tokenizer_path:
137
- with open(tokenizer_path, mode='rb') as f:
138
  self.tokenizer = pickle.load(f)
139
  else:
140
  raise TransformerConnectionError()
141
 
142
  except Exception as e:
143
- raise RuntimeError('Exception: {} Fail to load the model from {}! '.format(e, model_arg))
 
 
 
 
144
 
145
  self.infer_dataloader = None
146
- self.opt.eval_batch_size = kwargs.pop('eval_batch_size', 128)
147
 
148
  self.opt.initializer = self.opt.initializer
149
 
@@ -158,19 +201,19 @@ class TADTextClassifier:
158
  def to(self, device=None):
159
  self.opt.device = device
160
  self.model.to(device)
161
- if hasattr(self, 'MLM'):
162
  self.MLM.to(self.opt.device)
163
 
164
  def cpu(self):
165
- self.opt.device = 'cpu'
166
- self.model.to('cpu')
167
- if hasattr(self, 'MLM'):
168
- self.MLM.to('cpu')
169
 
170
- def cuda(self, device='cuda:0'):
171
  self.opt.device = device
172
  self.model.to(device)
173
- if hasattr(self, 'MLM'):
174
  self.MLM.to(device)
175
 
176
  def _log_write_args(self):
@@ -182,55 +225,67 @@ class TADTextClassifier:
182
  else:
183
  n_nontrainable_params += n_params
184
  print(
185
- 'n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
 
 
 
186
  for arg in vars(self.opt):
187
  if getattr(self.opt, arg) is not None:
188
- print('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))
189
-
190
- def batch_infer(self,
191
- target_file=None,
192
- print_result=True,
193
- save_result=False,
194
- ignore_error=True,
195
- defense: str = None
196
- ):
197
-
198
- save_path = os.path.join(os.getcwd(), 'tad_text_classification.result.json')
199
-
200
- target_file = detect_infer_dataset(target_file, task='text_defense')
201
  if not target_file:
202
- raise FileNotFoundError('Can not find inference datasets!')
203
 
204
  if hasattr(BERTTADModelList, self.opt.model.__name__):
205
  dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
206
 
207
  dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)
208
- self.infer_dataloader = DataLoader(dataset=dataset, batch_size=self.opt.eval_batch_size, pin_memory=True,
209
- shuffle=False)
210
- return self._infer(save_path=save_path if save_result else None, print_result=print_result, defense=defense)
211
-
212
- def infer(self,
213
- text: str = None,
214
- print_result=True,
215
- ignore_error=True,
216
- defense: str = None
217
- ):
218
-
 
 
 
 
 
 
 
 
219
  if hasattr(BERTTADModelList, self.opt.model.__name__):
220
  dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
221
 
222
  if text:
223
  dataset.prepare_infer_sample(text, ignore_error=ignore_error)
224
  else:
225
- raise RuntimeError('Please specify your datasets path!')
226
- self.infer_dataloader = DataLoader(dataset=dataset, batch_size=self.opt.eval_batch_size, shuffle=False)
 
 
227
  return self._infer(print_result=print_result, defense=defense)[0]
228
 
229
  def _infer(self, save_path=None, print_result=True, defense=None):
230
-
231
  _params = filter(lambda p: p.requires_grad, self.model.parameters())
232
 
233
- correct = {True: 'Correct', False: 'Wrong'}
234
  results = []
235
 
236
  with torch.no_grad():
@@ -241,86 +296,130 @@ class TADTextClassifier:
241
  n_advdet_correct = 0
242
  n_advdet_labeled = 0
243
  if len(self.infer_dataloader.dataset) >= 100:
244
- it = tqdm.tqdm(self.infer_dataloader, postfix='inferring...')
245
  else:
246
  it = self.infer_dataloader
247
  for _, sample in enumerate(it):
248
- inputs = [sample[col].to(self.opt.device) for col in self.opt.inputs_cols]
 
 
249
  outputs = self.model(inputs)
250
- logits, advdet_logits, adv_tr_logits = outputs['sent_logits'], outputs['advdet_logits'], outputs[
251
- 'adv_tr_logits']
252
- probs, advdet_probs, adv_tr_probs = torch.softmax(logits, dim=-1), torch.softmax(advdet_logits,
253
- dim=-1), torch.softmax(
254
- adv_tr_logits, dim=-1)
255
-
256
- for i, (prob, advdet_prob, adv_tr_prob) in enumerate(zip(probs, advdet_probs, adv_tr_probs)):
257
- text_raw = sample['text_raw'][i]
 
 
 
 
 
 
 
258
 
259
  pred_label = int(prob.argmax(axis=-1))
260
  pred_is_adv_label = int(advdet_prob.argmax(axis=-1))
261
  pred_adv_tr_label = int(adv_tr_prob.argmax(axis=-1))
262
- ref_label = int(sample['label'][i]) if int(sample['label'][i]) in self.opt.index_to_label else ''
263
- ref_is_adv_label = int(sample['is_adv'][i]) if int(
264
- sample['is_adv'][i]) in self.opt.index_to_is_adv else ''
265
- ref_adv_tr_label = int(sample['adv_train_label'][i]) if int(
266
- sample['adv_train_label'][i]) in self.opt.index_to_adv_train_label else ''
 
 
 
 
 
 
 
 
 
 
 
267
 
268
  if self.cal_perplexity:
269
  ids = self.MLM_tokenizer(text_raw, return_tensors="pt")
270
- ids['labels'] = ids['input_ids'].clone()
271
  ids = ids.to(self.opt.device)
272
- loss = self.MLM(**ids)['loss']
273
- perplexity = float(torch.exp(loss / ids['input_ids'].size(1)))
274
  else:
275
- perplexity = 'N.A.'
276
 
277
  result = {
278
- 'text': text_raw,
279
-
280
- 'label': self.opt.index_to_label[pred_label],
281
- 'probs': prob.cpu().numpy(),
282
- 'confidence': float(max(prob)),
283
- 'ref_label': self.opt.index_to_label[ref_label] if isinstance(ref_label, int) else ref_label,
284
- 'ref_label_check': correct[pred_label == ref_label] if ref_label != -100 else '',
285
- 'is_fixed': False,
286
-
287
- 'is_adv_label': self.opt.index_to_is_adv[pred_is_adv_label],
288
- 'is_adv_probs': advdet_prob.cpu().numpy(),
289
- 'is_adv_confidence': float(max(advdet_prob)),
290
- 'ref_is_adv_label': self.opt.index_to_is_adv[ref_is_adv_label] if isinstance(ref_is_adv_label, int) else ref_is_adv_label,
291
- 'ref_is_adv_check': correct[pred_is_adv_label == ref_is_adv_label] if ref_is_adv_label != -100 and isinstance(ref_is_adv_label, int) else '',
292
-
293
- 'pred_adv_tr_label': self.opt.index_to_label[pred_adv_tr_label],
294
- 'ref_adv_tr_label': self.opt.index_to_label[ref_adv_tr_label],
295
-
296
- 'perplexity': perplexity,
 
 
 
 
 
 
 
297
  }
298
  if defense:
299
  try:
300
- if not hasattr(self, 'sent_attacker'):
301
- self.sent_attacker = init_attacker(self, defense.lower())
302
- if result['is_adv_label'] == '1':
303
- res = self.sent_attacker.attacker.simple_attack(text_raw, int(result['label']))
304
- new_infer_res = self.infer(res.perturbed_result.attacked_text.text, print_result=False)
305
- result['perturbed_label'] = result['label']
306
- result['label'] = new_infer_res['label']
307
- result['probs'] = new_infer_res['probs']
308
- result['ref_label_check'] = correct[int(result['label']) == ref_label] if ref_label != -100 else ''
309
- result['restored_text'] = res.perturbed_result.attacked_text.text
310
- result['is_fixed'] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  else:
312
- result['restored_text'] = ''
313
- result['is_fixed'] = False
314
 
315
  except Exception as e:
316
- print('Error:{}, try install TextAttack and tensorflow_text after 10 seconds...'.format(e))
 
 
 
 
317
  time.sleep(10)
318
- raise RuntimeError('Installation done, please run again...')
319
 
320
  if ref_label != -100:
321
  n_labeled += 1
322
 
323
- if result['label'] == result['ref_label']:
324
  n_correct += 1
325
 
326
  if ref_is_adv_label != -100:
@@ -333,56 +432,85 @@ class TADTextClassifier:
333
  try:
334
  if print_result:
335
  for ex_id, result in enumerate(results):
336
- text_printing = result['text'][:]
337
- text_info = ''
338
- if result['label'] != '-100':
339
- if not result['ref_label']:
340
- text_info += ' -> <CLS:{}(ref:{} confidence:{})>'.format(result['label'],
341
- result['ref_label'],
342
- result['confidence'])
343
- elif result['label'] == result['ref_label']:
 
 
344
  text_info += colored(
345
- ' -> <CLS:{}(ref:{} confidence:{})>'.format(result['label'], result['ref_label'],
346
- result['confidence']), 'green')
 
 
 
 
 
347
  else:
348
  text_info += colored(
349
- ' -> <CLS:{}(ref:{} confidence:{})>'.format(result['label'], result['ref_label'],
350
- result['confidence']), 'red')
 
 
 
 
 
351
 
352
  # AdvDet
353
- if result['is_adv_label'] != '-100':
354
- if not result['ref_is_adv_label']:
355
- text_info += ' -> <AdvDet:{}(ref:{} confidence:{})>'.format(result['is_adv_label'],
356
- result['ref_is_adv_check'],
357
- result['is_adv_confidence'])
358
- elif result['is_adv_label'] == result['ref_is_adv_label']:
359
- text_info += colored(' -> <AdvDet:{}(ref:{} confidence:{})>'.format(result['is_adv_label'],
360
- result[
361
- 'ref_is_adv_label'],
362
- result[
363
- 'is_adv_confidence']),
364
- 'green')
 
 
 
 
365
  else:
366
- text_info += colored(' -> <AdvDet:{}(ref:{} confidence:{})>'.format(result['is_adv_label'],
367
- result[
368
- 'ref_is_adv_label'],
369
- result[
370
- 'is_adv_confidence']),
371
- 'red')
 
 
372
  text_printing += text_info
373
  if self.cal_perplexity:
374
- text_printing += colored(' --> <perplexity:{}>'.format(result['perplexity']), 'yellow')
375
- print('Example {}: {}'.format(ex_id, text_printing))
 
 
 
376
  if save_path:
377
- with open(save_path, 'w', encoding='utf8') as fout:
378
  json.dump(str(results), fout, ensure_ascii=False)
379
- print('inference result saved in: {}'.format(save_path))
380
  except Exception as e:
381
- print('Can not save result: {}, Exception: {}'.format(text_raw, e))
382
 
383
  if len(results) > 1:
384
- print('CLS Acc:{}%'.format(100 * n_correct / n_labeled if n_labeled else ''))
385
- print('AdvDet Acc:{}%'.format(100 * n_advdet_correct / n_advdet_labeled if n_advdet_labeled else ''))
 
 
 
 
 
 
 
 
386
 
387
  return results
388
 
 
9
  from termcolor import colored
10
 
11
  from torch.utils.data import DataLoader
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ AutoModel,
15
+ AutoConfig,
16
+ DebertaV2ForMaskedLM,
17
+ RobertaForMaskedLM,
18
+ BertForMaskedLM,
19
+ )
20
 
21
  from ....functional.dataset.dataset_manager import detect_infer_dataset
22
 
23
  from ..models import BERTTADModelList
24
+ from ..classic.__bert__.dataset_utils.data_utils_for_inference import (
25
+ BERTTADDataset,
26
+ Tokenizer4Pretraining,
27
+ )
28
 
29
+ from ....utils.demo_utils import (
30
+ print_args,
31
+ TransformerConnectionError,
32
+ get_device,
33
+ build_embedding_matrix,
34
+ )
35
 
36
 
37
  def init_attacker(tad_classifier, defense):
38
  try:
39
  from textattack import Attacker
40
+ from textattack.attack_recipes import (
41
+ BAEGarg2019,
42
+ PWWSRen2019,
43
+ TextFoolerJin2019,
44
+ PSOZang2020,
45
+ IGAWang2019,
46
+ GeneticAlgorithmAlzantot2018,
47
+ DeepWordBugGao2018,
48
+ )
49
  from textattack.datasets import Dataset
50
  from textattack.models.wrappers import HuggingFaceModelWrapper
51
 
 
56
  def __call__(self, text_inputs, **kwargs):
57
  outputs = []
58
  for text_input in text_inputs:
59
+ raw_outputs = self.model.infer(
60
+ text_input, print_result=False, **kwargs
61
+ )
62
+ outputs.append(raw_outputs["probs"])
63
  return outputs
64
 
65
  class SentAttacker:
 
66
  def __init__(self, model, recipe_class=BAEGarg2019):
67
  model = model
68
  model_wrapper = DemoModelWrapper(model)
69
 
70
  recipe = recipe_class.build(model_wrapper)
71
 
72
+ _dataset = [("", 0)]
73
  _dataset = Dataset(_dataset)
74
 
75
  self.attacker = Attacker(recipe, _dataset)
76
 
77
  attackers = {
78
+ "bae": BAEGarg2019,
79
+ "pwws": PWWSRen2019,
80
+ "textfooler": TextFoolerJin2019,
81
+ "pso": PSOZang2020,
82
+ "iga": IGAWang2019,
83
+ "ga": GeneticAlgorithmAlzantot2018,
84
+ "wordbugger": DeepWordBugGao2018,
85
  }
86
  return SentAttacker(tad_classifier, attackers[defense])
87
  except Exception as e:
88
+ print("Original error:", e)
 
89
 
90
 
91
  def get_mlm_and_tokenizer(text_classifier, config):
 
94
  else:
95
  base_model = text_classifier.bert.base_model
96
  pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)
97
+ if "deberta-v3" in config.pretrained_bert:
98
  MLM = DebertaV2ForMaskedLM(pretrained_config)
99
  MLM.deberta = base_model
100
+ elif "roberta" in config.pretrained_bert:
101
  MLM = RobertaForMaskedLM(pretrained_config)
102
  MLM.roberta = base_model
103
  else:
 
108
 
109
  class TADTextClassifier:
110
  def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
111
+ """
112
+ from_train_model: load inference model from trained model
113
+ """
114
  self.cal_perplexity = cal_perplexity
115
  # load from a training
116
  if not isinstance(model_arg, str):
117
+ print("Load text classifier from training")
118
  self.model = model_arg[0]
119
  self.opt = model_arg[1]
120
  self.tokenizer = model_arg[2]
121
  else:
122
  try:
123
+ if "fine-tuned" in model_arg:
124
  raise ValueError(
125
+ "Do not support to directly load a fine-tuned model, please load a .state_dict or .model instead!"
126
+ )
127
+ print("Load text classifier from", model_arg)
128
+ state_dict_path = find_file(
129
+ model_arg, key=".state_dict", exclude_key=["__MACOSX"]
130
+ )
131
+ model_path = find_file(
132
+ model_arg, key=".model", exclude_key=["__MACOSX"]
133
+ )
134
+ tokenizer_path = find_file(
135
+ model_arg, key=".tokenizer", exclude_key=["__MACOSX"]
136
+ )
137
+ config_path = find_file(
138
+ model_arg, key=".config", exclude_key=["__MACOSX"]
139
+ )
140
+
141
+ print("config: {}".format(config_path))
142
+ print("state_dict: {}".format(state_dict_path))
143
+ print("model: {}".format(model_path))
144
+ print("tokenizer: {}".format(tokenizer_path))
145
+
146
+ with open(config_path, mode="rb") as f:
147
  self.opt = pickle.load(f)
148
+ self.opt.device = get_device(kwargs.pop("auto_device", True))[0]
149
 
150
  if state_dict_path or model_path:
151
  if hasattr(BERTTADModelList, self.opt.model.__name__):
152
  if state_dict_path:
153
+ if kwargs.pop("offline", False):
154
  self.bert = AutoModel.from_pretrained(
155
+ find_cwd_dir(
156
+ self.opt.pretrained_bert.split("/")[-1]
157
+ )
158
+ )
159
  else:
160
+ self.bert = AutoModel.from_pretrained(
161
+ self.opt.pretrained_bert
162
+ )
163
  self.model = self.opt.model(self.bert, self.opt)
164
+ self.model.load_state_dict(
165
+ torch.load(state_dict_path, map_location="cpu")
166
+ )
167
  elif model_path:
168
+ self.model = torch.load(model_path, map_location="cpu")
169
 
170
  try:
171
+ self.tokenizer = Tokenizer4Pretraining(
172
+ max_seq_len=self.opt.max_seq_len, opt=self.opt, **kwargs
173
+ )
174
  except ValueError:
175
  if tokenizer_path:
176
+ with open(tokenizer_path, mode="rb") as f:
177
  self.tokenizer = pickle.load(f)
178
  else:
179
  raise TransformerConnectionError()
180
 
181
  except Exception as e:
182
+ raise RuntimeError(
183
+ "Exception: {} Fail to load the model from {}! ".format(
184
+ e, model_arg
185
+ )
186
+ )
187
 
188
  self.infer_dataloader = None
189
+ self.opt.eval_batch_size = kwargs.pop("eval_batch_size", 128)
190
 
191
  self.opt.initializer = self.opt.initializer
192
 
 
201
  def to(self, device=None):
202
  self.opt.device = device
203
  self.model.to(device)
204
+ if hasattr(self, "MLM"):
205
  self.MLM.to(self.opt.device)
206
 
207
  def cpu(self):
208
+ self.opt.device = "cpu"
209
+ self.model.to("cpu")
210
+ if hasattr(self, "MLM"):
211
+ self.MLM.to("cpu")
212
 
213
+ def cuda(self, device="cuda:0"):
214
  self.opt.device = device
215
  self.model.to(device)
216
+ if hasattr(self, "MLM"):
217
  self.MLM.to(device)
218
 
219
  def _log_write_args(self):
 
225
  else:
226
  n_nontrainable_params += n_params
227
  print(
228
+ "n_trainable_params: {0}, n_nontrainable_params: {1}".format(
229
+ n_trainable_params, n_nontrainable_params
230
+ )
231
+ )
232
  for arg in vars(self.opt):
233
  if getattr(self.opt, arg) is not None:
234
+ print(">>> {0}: {1}".format(arg, getattr(self.opt, arg)))
235
+
236
+ def batch_infer(
237
+ self,
238
+ target_file=None,
239
+ print_result=True,
240
+ save_result=False,
241
+ ignore_error=True,
242
+ defense: str = None,
243
+ ):
244
+ save_path = os.path.join(os.getcwd(), "tad_text_classification.result.json")
245
+
246
+ target_file = detect_infer_dataset(target_file, task="text_defense")
247
  if not target_file:
248
+ raise FileNotFoundError("Can not find inference datasets!")
249
 
250
  if hasattr(BERTTADModelList, self.opt.model.__name__):
251
  dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
252
 
253
  dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)
254
+ self.infer_dataloader = DataLoader(
255
+ dataset=dataset,
256
+ batch_size=self.opt.eval_batch_size,
257
+ pin_memory=True,
258
+ shuffle=False,
259
+ )
260
+ return self._infer(
261
+ save_path=save_path if save_result else None,
262
+ print_result=print_result,
263
+ defense=defense,
264
+ )
265
+
266
+ def infer(
267
+ self,
268
+ text: str = None,
269
+ print_result=True,
270
+ ignore_error=True,
271
+ defense: str = None,
272
+ ):
273
  if hasattr(BERTTADModelList, self.opt.model.__name__):
274
  dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
275
 
276
  if text:
277
  dataset.prepare_infer_sample(text, ignore_error=ignore_error)
278
  else:
279
+ raise RuntimeError("Please specify your datasets path!")
280
+ self.infer_dataloader = DataLoader(
281
+ dataset=dataset, batch_size=self.opt.eval_batch_size, shuffle=False
282
+ )
283
  return self._infer(print_result=print_result, defense=defense)[0]
284
 
285
  def _infer(self, save_path=None, print_result=True, defense=None):
 
286
  _params = filter(lambda p: p.requires_grad, self.model.parameters())
287
 
288
+ correct = {True: "Correct", False: "Wrong"}
289
  results = []
290
 
291
  with torch.no_grad():
 
296
  n_advdet_correct = 0
297
  n_advdet_labeled = 0
298
  if len(self.infer_dataloader.dataset) >= 100:
299
+ it = tqdm.tqdm(self.infer_dataloader, postfix="inferring...")
300
  else:
301
  it = self.infer_dataloader
302
  for _, sample in enumerate(it):
303
+ inputs = [
304
+ sample[col].to(self.opt.device) for col in self.opt.inputs_cols
305
+ ]
306
  outputs = self.model(inputs)
307
+ logits, advdet_logits, adv_tr_logits = (
308
+ outputs["sent_logits"],
309
+ outputs["advdet_logits"],
310
+ outputs["adv_tr_logits"],
311
+ )
312
+ probs, advdet_probs, adv_tr_probs = (
313
+ torch.softmax(logits, dim=-1),
314
+ torch.softmax(advdet_logits, dim=-1),
315
+ torch.softmax(adv_tr_logits, dim=-1),
316
+ )
317
+
318
+ for i, (prob, advdet_prob, adv_tr_prob) in enumerate(
319
+ zip(probs, advdet_probs, adv_tr_probs)
320
+ ):
321
+ text_raw = sample["text_raw"][i]
322
 
323
  pred_label = int(prob.argmax(axis=-1))
324
  pred_is_adv_label = int(advdet_prob.argmax(axis=-1))
325
  pred_adv_tr_label = int(adv_tr_prob.argmax(axis=-1))
326
+ ref_label = (
327
+ int(sample["label"][i])
328
+ if int(sample["label"][i]) in self.opt.index_to_label
329
+ else ""
330
+ )
331
+ ref_is_adv_label = (
332
+ int(sample["is_adv"][i])
333
+ if int(sample["is_adv"][i]) in self.opt.index_to_is_adv
334
+ else ""
335
+ )
336
+ ref_adv_tr_label = (
337
+ int(sample["adv_train_label"][i])
338
+ if int(sample["adv_train_label"][i])
339
+ in self.opt.index_to_adv_train_label
340
+ else ""
341
+ )
342
 
343
  if self.cal_perplexity:
344
  ids = self.MLM_tokenizer(text_raw, return_tensors="pt")
345
+ ids["labels"] = ids["input_ids"].clone()
346
  ids = ids.to(self.opt.device)
347
+ loss = self.MLM(**ids)["loss"]
348
+ perplexity = float(torch.exp(loss / ids["input_ids"].size(1)))
349
  else:
350
+ perplexity = "N.A."
351
 
352
  result = {
353
+ "text": text_raw,
354
+ "label": self.opt.index_to_label[pred_label],
355
+ "probs": prob.cpu().numpy(),
356
+ "confidence": float(max(prob)),
357
+ "ref_label": self.opt.index_to_label[ref_label]
358
+ if isinstance(ref_label, int)
359
+ else ref_label,
360
+ "ref_label_check": correct[pred_label == ref_label]
361
+ if ref_label != -100
362
+ else "",
363
+ "is_fixed": False,
364
+ "is_adv_label": self.opt.index_to_is_adv[pred_is_adv_label],
365
+ "is_adv_probs": advdet_prob.cpu().numpy(),
366
+ "is_adv_confidence": float(max(advdet_prob)),
367
+ "ref_is_adv_label": self.opt.index_to_is_adv[ref_is_adv_label]
368
+ if isinstance(ref_is_adv_label, int)
369
+ else ref_is_adv_label,
370
+ "ref_is_adv_check": correct[
371
+ pred_is_adv_label == ref_is_adv_label
372
+ ]
373
+ if ref_is_adv_label != -100
374
+ and isinstance(ref_is_adv_label, int)
375
+ else "",
376
+ "pred_adv_tr_label": self.opt.index_to_label[pred_adv_tr_label],
377
+ "ref_adv_tr_label": self.opt.index_to_label[ref_adv_tr_label],
378
+ "perplexity": perplexity,
379
  }
380
  if defense:
381
  try:
382
+ if not hasattr(self, "sent_attacker"):
383
+ self.sent_attacker = init_attacker(
384
+ self, defense.lower()
385
+ )
386
+ if result["is_adv_label"] == "1":
387
+ res = self.sent_attacker.attacker.simple_attack(
388
+ text_raw, int(result["label"])
389
+ )
390
+ new_infer_res = self.infer(
391
+ res.perturbed_result.attacked_text.text,
392
+ print_result=False,
393
+ )
394
+ result["perturbed_label"] = result["label"]
395
+ result["label"] = new_infer_res["label"]
396
+ result["probs"] = new_infer_res["probs"]
397
+ result["ref_label_check"] = (
398
+ correct[int(result["label"]) == ref_label]
399
+ if ref_label != -100
400
+ else ""
401
+ )
402
+ result[
403
+ "restored_text"
404
+ ] = res.perturbed_result.attacked_text.text
405
+ result["is_fixed"] = True
406
  else:
407
+ result["restored_text"] = ""
408
+ result["is_fixed"] = False
409
 
410
  except Exception as e:
411
+ print(
412
+ "Error:{}, try install TextAttack and tensorflow_text after 10 seconds...".format(
413
+ e
414
+ )
415
+ )
416
  time.sleep(10)
417
+ raise RuntimeError("Installation done, please run again...")
418
 
419
  if ref_label != -100:
420
  n_labeled += 1
421
 
422
+ if result["label"] == result["ref_label"]:
423
  n_correct += 1
424
 
425
  if ref_is_adv_label != -100:
 
432
  try:
433
  if print_result:
434
  for ex_id, result in enumerate(results):
435
+ text_printing = result["text"][:]
436
+ text_info = ""
437
+ if result["label"] != "-100":
438
+ if not result["ref_label"]:
439
+ text_info += " -> <CLS:{}(ref:{} confidence:{})>".format(
440
+ result["label"],
441
+ result["ref_label"],
442
+ result["confidence"],
443
+ )
444
+ elif result["label"] == result["ref_label"]:
445
  text_info += colored(
446
+ " -> <CLS:{}(ref:{} confidence:{})>".format(
447
+ result["label"],
448
+ result["ref_label"],
449
+ result["confidence"],
450
+ ),
451
+ "green",
452
+ )
453
  else:
454
  text_info += colored(
455
+ " -> <CLS:{}(ref:{} confidence:{})>".format(
456
+ result["label"],
457
+ result["ref_label"],
458
+ result["confidence"],
459
+ ),
460
+ "red",
461
+ )
462
 
463
  # AdvDet
464
+ if result["is_adv_label"] != "-100":
465
+ if not result["ref_is_adv_label"]:
466
+ text_info += " -> <AdvDet:{}(ref:{} confidence:{})>".format(
467
+ result["is_adv_label"],
468
+ result["ref_is_adv_check"],
469
+ result["is_adv_confidence"],
470
+ )
471
+ elif result["is_adv_label"] == result["ref_is_adv_label"]:
472
+ text_info += colored(
473
+ " -> <AdvDet:{}(ref:{} confidence:{})>".format(
474
+ result["is_adv_label"],
475
+ result["ref_is_adv_label"],
476
+ result["is_adv_confidence"],
477
+ ),
478
+ "green",
479
+ )
480
  else:
481
+ text_info += colored(
482
+ " -> <AdvDet:{}(ref:{} confidence:{})>".format(
483
+ result["is_adv_label"],
484
+ result["ref_is_adv_label"],
485
+ result["is_adv_confidence"],
486
+ ),
487
+ "red",
488
+ )
489
  text_printing += text_info
490
  if self.cal_perplexity:
491
+ text_printing += colored(
492
+ " --> <perplexity:{}>".format(result["perplexity"]),
493
+ "yellow",
494
+ )
495
+ print("Example {}: {}".format(ex_id, text_printing))
496
  if save_path:
497
+ with open(save_path, "w", encoding="utf8") as fout:
498
  json.dump(str(results), fout, ensure_ascii=False)
499
+ print("inference result saved in: {}".format(save_path))
500
  except Exception as e:
501
+ print("Can not save result: {}, Exception: {}".format(text_raw, e))
502
 
503
  if len(results) > 1:
504
+ print(
505
+ "CLS Acc:{}%".format(100 * n_correct / n_labeled if n_labeled else "")
506
+ )
507
+ print(
508
+ "AdvDet Acc:{}%".format(
509
+ 100 * n_advdet_correct / n_advdet_labeled
510
+ if n_advdet_labeled
511
+ else ""
512
+ )
513
+ )
514
 
515
  return results
516
 
anonymous_demo/functional/checkpoint/checkpoint_manager.py CHANGED
@@ -12,9 +12,8 @@ class CheckpointManager:
12
  class TADCheckpointManager(CheckpointManager):
13
  @staticmethod
14
  @retry
15
- def get_tad_text_classifier(checkpoint: str = None,
16
- eval_batch_size=128,
17
- **kwargs):
18
-
19
- tad_text_classifier = TADTextClassifier(checkpoint, eval_batch_size=eval_batch_size, **kwargs)
20
  return tad_text_classifier
 
12
  class TADCheckpointManager(CheckpointManager):
13
  @staticmethod
14
  @retry
15
+ def get_tad_text_classifier(checkpoint: str = None, eval_batch_size=128, **kwargs):
16
+ tad_text_classifier = TADTextClassifier(
17
+ checkpoint, eval_batch_size=eval_batch_size, **kwargs
18
+ )
 
19
  return tad_text_classifier
anonymous_demo/functional/config/config_manager.py CHANGED
@@ -10,7 +10,6 @@ def config_check(args):
10
 
11
 
12
  class ConfigManager(Namespace):
13
-
14
  def __init__(self, args=None, **kwargs):
15
  """
16
  The ConfigManager is a subclass of argparse.Namespace and based on parameter dict and count the call-frequency of each parameter
@@ -29,36 +28,35 @@ class ConfigManager(Namespace):
29
  self.args_call_count = {arg: 0 for arg in args}
30
 
31
  def __getattribute__(self, arg_name):
32
- if arg_name == 'args' or arg_name == 'args_call_count':
33
  return super().__getattribute__(arg_name)
34
  try:
35
- value = super().__getattribute__('args')[arg_name]
36
- args_call_count = super().__getattribute__('args_call_count')
37
  args_call_count[arg_name] += 1
38
- super().__setattr__('args_call_count', args_call_count)
39
  return value
40
 
41
  except Exception as e:
42
-
43
  return super().__getattribute__(arg_name)
44
 
45
  def __setattr__(self, arg_name, value):
46
- if arg_name == 'args' or arg_name == 'args_call_count':
47
  super().__setattr__(arg_name, value)
48
  return
49
  try:
50
- args = super().__getattribute__('args')
51
  args[arg_name] = value
52
- super().__setattr__('args', args)
53
- args_call_count = super().__getattribute__('args_call_count')
54
 
55
  if arg_name in args_call_count:
56
  # args_call_count[arg_name] += 1
57
- super().__setattr__('args_call_count', args_call_count)
58
 
59
  else:
60
  args_call_count[arg_name] = 0
61
- super().__setattr__('args_call_count', args_call_count)
62
 
63
  except Exception as e:
64
  super().__setattr__(arg_name, value)
 
10
 
11
 
12
  class ConfigManager(Namespace):
 
13
  def __init__(self, args=None, **kwargs):
14
  """
15
  The ConfigManager is a subclass of argparse.Namespace and based on parameter dict and count the call-frequency of each parameter
 
28
  self.args_call_count = {arg: 0 for arg in args}
29
 
30
  def __getattribute__(self, arg_name):
31
+ if arg_name == "args" or arg_name == "args_call_count":
32
  return super().__getattribute__(arg_name)
33
  try:
34
+ value = super().__getattribute__("args")[arg_name]
35
+ args_call_count = super().__getattribute__("args_call_count")
36
  args_call_count[arg_name] += 1
37
+ super().__setattr__("args_call_count", args_call_count)
38
  return value
39
 
40
  except Exception as e:
 
41
  return super().__getattribute__(arg_name)
42
 
43
  def __setattr__(self, arg_name, value):
44
+ if arg_name == "args" or arg_name == "args_call_count":
45
  super().__setattr__(arg_name, value)
46
  return
47
  try:
48
+ args = super().__getattribute__("args")
49
  args[arg_name] = value
50
+ super().__setattr__("args", args)
51
+ args_call_count = super().__getattribute__("args_call_count")
52
 
53
  if arg_name in args_call_count:
54
  # args_call_count[arg_name] += 1
55
+ super().__setattr__("args_call_count", args_call_count)
56
 
57
  else:
58
  args_call_count[arg_name] = 0
59
+ super().__setattr__("args_call_count", args_call_count)
60
 
61
  except Exception as e:
62
  super().__setattr__(arg_name, value)
anonymous_demo/functional/config/tad_config_manager.py CHANGED
@@ -3,116 +3,121 @@ import copy
3
  from anonymous_demo.functional.config.config_manager import ConfigManager
4
  from anonymous_demo.core.tad.classic.__bert__.models import TADBERT
5
 
6
- _tad_config_template = {'model': TADBERT,
7
- 'optimizer': "adamw",
8
- 'learning_rate': 0.00002,
9
- 'patience': 99999,
10
- 'pretrained_bert': "microsoft/mdeberta-v3-base",
11
- 'cache_dataset': True,
12
- 'warmup_step': -1,
13
- 'show_metric': False,
14
- 'max_seq_len': 80,
15
- 'dropout': 0,
16
- 'l2reg': 0.000001,
17
- 'num_epoch': 10,
18
- 'batch_size': 16,
19
- 'initializer': 'xavier_uniform_',
20
- 'seed': 52,
21
- 'polarities_dim': 3,
22
- 'log_step': 10,
23
- 'evaluate_begin': 0,
24
- 'cross_validate_fold': -1,
25
- 'use_amp': False,
26
- # split train and test datasets into 5 folds and repeat 3 training
27
- }
28
-
29
- _tad_config_base = {'model': TADBERT,
30
- 'optimizer': "adamw",
31
- 'learning_rate': 0.00002,
32
- 'pretrained_bert': "microsoft/deberta-v3-base",
33
- 'cache_dataset': True,
34
- 'warmup_step': -1,
35
- 'show_metric': False,
36
- 'max_seq_len': 80,
37
- 'patience': 99999,
38
- 'dropout': 0,
39
- 'l2reg': 0.000001,
40
- 'num_epoch': 10,
41
- 'batch_size': 16,
42
- 'initializer': 'xavier_uniform_',
43
- 'seed': 52,
44
- 'polarities_dim': 3,
45
- 'log_step': 10,
46
- 'evaluate_begin': 0,
47
- 'cross_validate_fold': -1
48
- # split train and test datasets into 5 folds and repeat 3 training
49
- }
50
-
51
- _tad_config_english = {'model': TADBERT,
52
- 'optimizer': "adamw",
53
- 'learning_rate': 0.00002,
54
- 'patience': 99999,
55
- 'pretrained_bert': "microsoft/deberta-v3-base",
56
- 'cache_dataset': True,
57
- 'warmup_step': -1,
58
- 'show_metric': False,
59
- 'max_seq_len': 80,
60
- 'dropout': 0,
61
- 'l2reg': 0.000001,
62
- 'num_epoch': 10,
63
- 'batch_size': 16,
64
- 'initializer': 'xavier_uniform_',
65
- 'seed': 52,
66
- 'polarities_dim': 3,
67
- 'log_step': 10,
68
- 'evaluate_begin': 0,
69
- 'cross_validate_fold': -1
70
- # split train and test datasets into 5 folds and repeat 3 training
71
- }
72
-
73
- _tad_config_multilingual = {'model': TADBERT,
74
- 'optimizer': "adamw",
75
- 'learning_rate': 0.00002,
76
- 'patience': 99999,
77
- 'pretrained_bert': "microsoft/mdeberta-v3-base",
78
- 'cache_dataset': True,
79
- 'warmup_step': -1,
80
- 'show_metric': False,
81
- 'max_seq_len': 80,
82
- 'dropout': 0,
83
- 'l2reg': 0.000001,
84
- 'num_epoch': 10,
85
- 'batch_size': 16,
86
- 'initializer': 'xavier_uniform_',
87
- 'seed': 52,
88
- 'polarities_dim': 3,
89
- 'log_step': 10,
90
- 'evaluate_begin': 0,
91
- 'cross_validate_fold': -1
92
- # split train and test datasets into 5 folds and repeat 3 training
93
- }
94
-
95
- _tad_config_chinese = {'model': TADBERT,
96
- 'optimizer': "adamw",
97
- 'learning_rate': 0.00002,
98
- 'patience': 99999,
99
- 'cache_dataset': True,
100
- 'warmup_step': -1,
101
- 'show_metric': False,
102
- 'pretrained_bert': "bert-base-chinese",
103
- 'max_seq_len': 80,
104
- 'dropout': 0,
105
- 'l2reg': 0.000001,
106
- 'num_epoch': 10,
107
- 'batch_size': 16,
108
- 'initializer': 'xavier_uniform_',
109
- 'seed': 52,
110
- 'polarities_dim': 3,
111
- 'log_step': 10,
112
- 'evaluate_begin': 0,
113
- 'cross_validate_fold': -1
114
- # split train and test datasets into 5 folds and repeat 3 training
115
- }
 
 
 
 
 
116
 
117
 
118
  class TADConfigManager(ConfigManager):
@@ -148,47 +153,50 @@ class TADConfigManager(ConfigManager):
148
  @staticmethod
149
  def set_tad_config(configType: str, newitem: dict):
150
  if isinstance(newitem, dict):
151
- if configType == 'template':
152
  _tad_config_template.update(newitem)
153
- elif configType == 'base':
154
  _tad_config_base.update(newitem)
155
- elif configType == 'english':
156
  _tad_config_english.update(newitem)
157
- elif configType == 'chinese':
158
  _tad_config_chinese.update(newitem)
159
- elif configType == 'multilingual':
160
  _tad_config_multilingual.update(newitem)
161
- elif configType == 'glove':
162
  _tad_config_glove.update(newitem)
163
  else:
164
  raise ValueError(
165
- "Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove")
 
166
  else:
167
- raise TypeError("Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}")
 
 
168
 
169
  @staticmethod
170
  def set_tad_config_template(newitem):
171
- TADConfigManager.set_tad_config('template', newitem)
172
 
173
  @staticmethod
174
  def set_tad_config_base(newitem):
175
- TADConfigManager.set_tad_config('base', newitem)
176
 
177
  @staticmethod
178
  def set_tad_config_english(newitem):
179
- TADConfigManager.set_tad_config('english', newitem)
180
 
181
  @staticmethod
182
  def set_tad_config_chinese(newitem):
183
- TADConfigManager.set_tad_config('chinese', newitem)
184
 
185
  @staticmethod
186
  def set_tad_config_multilingual(newitem):
187
- TADConfigManager.set_tad_config('multilingual', newitem)
188
 
189
  @staticmethod
190
  def set_tad_config_glove(newitem):
191
- TADConfigManager.set_tad_config('glove', newitem)
192
 
193
  @staticmethod
194
  def get_tad_config_template() -> ConfigManager:
 
3
  from anonymous_demo.functional.config.config_manager import ConfigManager
4
  from anonymous_demo.core.tad.classic.__bert__.models import TADBERT
5
 
6
+ _tad_config_template = {
7
+ "model": TADBERT,
8
+ "optimizer": "adamw",
9
+ "learning_rate": 0.00002,
10
+ "patience": 99999,
11
+ "pretrained_bert": "microsoft/mdeberta-v3-base",
12
+ "cache_dataset": True,
13
+ "warmup_step": -1,
14
+ "show_metric": False,
15
+ "max_seq_len": 80,
16
+ "dropout": 0,
17
+ "l2reg": 0.000001,
18
+ "num_epoch": 10,
19
+ "batch_size": 16,
20
+ "initializer": "xavier_uniform_",
21
+ "seed": 52,
22
+ "polarities_dim": 3,
23
+ "log_step": 10,
24
+ "evaluate_begin": 0,
25
+ "cross_validate_fold": -1,
26
+ "use_amp": False,
27
+ # split train and test datasets into 5 folds and repeat 3 training
28
+ }
29
+
30
+ _tad_config_base = {
31
+ "model": TADBERT,
32
+ "optimizer": "adamw",
33
+ "learning_rate": 0.00002,
34
+ "pretrained_bert": "microsoft/deberta-v3-base",
35
+ "cache_dataset": True,
36
+ "warmup_step": -1,
37
+ "show_metric": False,
38
+ "max_seq_len": 80,
39
+ "patience": 99999,
40
+ "dropout": 0,
41
+ "l2reg": 0.000001,
42
+ "num_epoch": 10,
43
+ "batch_size": 16,
44
+ "initializer": "xavier_uniform_",
45
+ "seed": 52,
46
+ "polarities_dim": 3,
47
+ "log_step": 10,
48
+ "evaluate_begin": 0,
49
+ "cross_validate_fold": -1
50
+ # split train and test datasets into 5 folds and repeat 3 training
51
+ }
52
+
53
+ _tad_config_english = {
54
+ "model": TADBERT,
55
+ "optimizer": "adamw",
56
+ "learning_rate": 0.00002,
57
+ "patience": 99999,
58
+ "pretrained_bert": "microsoft/deberta-v3-base",
59
+ "cache_dataset": True,
60
+ "warmup_step": -1,
61
+ "show_metric": False,
62
+ "max_seq_len": 80,
63
+ "dropout": 0,
64
+ "l2reg": 0.000001,
65
+ "num_epoch": 10,
66
+ "batch_size": 16,
67
+ "initializer": "xavier_uniform_",
68
+ "seed": 52,
69
+ "polarities_dim": 3,
70
+ "log_step": 10,
71
+ "evaluate_begin": 0,
72
+ "cross_validate_fold": -1
73
+ # split train and test datasets into 5 folds and repeat 3 training
74
+ }
75
+
76
+ _tad_config_multilingual = {
77
+ "model": TADBERT,
78
+ "optimizer": "adamw",
79
+ "learning_rate": 0.00002,
80
+ "patience": 99999,
81
+ "pretrained_bert": "microsoft/mdeberta-v3-base",
82
+ "cache_dataset": True,
83
+ "warmup_step": -1,
84
+ "show_metric": False,
85
+ "max_seq_len": 80,
86
+ "dropout": 0,
87
+ "l2reg": 0.000001,
88
+ "num_epoch": 10,
89
+ "batch_size": 16,
90
+ "initializer": "xavier_uniform_",
91
+ "seed": 52,
92
+ "polarities_dim": 3,
93
+ "log_step": 10,
94
+ "evaluate_begin": 0,
95
+ "cross_validate_fold": -1
96
+ # split train and test datasets into 5 folds and repeat 3 training
97
+ }
98
+
99
+ _tad_config_chinese = {
100
+ "model": TADBERT,
101
+ "optimizer": "adamw",
102
+ "learning_rate": 0.00002,
103
+ "patience": 99999,
104
+ "cache_dataset": True,
105
+ "warmup_step": -1,
106
+ "show_metric": False,
107
+ "pretrained_bert": "bert-base-chinese",
108
+ "max_seq_len": 80,
109
+ "dropout": 0,
110
+ "l2reg": 0.000001,
111
+ "num_epoch": 10,
112
+ "batch_size": 16,
113
+ "initializer": "xavier_uniform_",
114
+ "seed": 52,
115
+ "polarities_dim": 3,
116
+ "log_step": 10,
117
+ "evaluate_begin": 0,
118
+ "cross_validate_fold": -1
119
+ # split train and test datasets into 5 folds and repeat 3 training
120
+ }
121
 
122
 
123
  class TADConfigManager(ConfigManager):
 
153
  @staticmethod
154
  def set_tad_config(configType: str, newitem: dict):
155
  if isinstance(newitem, dict):
156
+ if configType == "template":
157
  _tad_config_template.update(newitem)
158
+ elif configType == "base":
159
  _tad_config_base.update(newitem)
160
+ elif configType == "english":
161
  _tad_config_english.update(newitem)
162
+ elif configType == "chinese":
163
  _tad_config_chinese.update(newitem)
164
+ elif configType == "multilingual":
165
  _tad_config_multilingual.update(newitem)
166
+ elif configType == "glove":
167
  _tad_config_glove.update(newitem)
168
  else:
169
  raise ValueError(
170
+ "Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove"
171
+ )
172
  else:
173
+ raise TypeError(
174
+ "Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}"
175
+ )
176
 
177
  @staticmethod
178
  def set_tad_config_template(newitem):
179
+ TADConfigManager.set_tad_config("template", newitem)
180
 
181
  @staticmethod
182
  def set_tad_config_base(newitem):
183
+ TADConfigManager.set_tad_config("base", newitem)
184
 
185
  @staticmethod
186
  def set_tad_config_english(newitem):
187
+ TADConfigManager.set_tad_config("english", newitem)
188
 
189
  @staticmethod
190
  def set_tad_config_chinese(newitem):
191
+ TADConfigManager.set_tad_config("chinese", newitem)
192
 
193
  @staticmethod
194
  def set_tad_config_multilingual(newitem):
195
+ TADConfigManager.set_tad_config("multilingual", newitem)
196
 
197
  @staticmethod
198
  def set_tad_config_glove(newitem):
199
+ TADConfigManager.set_tad_config("glove", newitem)
200
 
201
  @staticmethod
202
  def get_tad_config_template() -> ConfigManager:
anonymous_demo/functional/dataset/__init__.py CHANGED
@@ -1 +1 @@
1
- from anonymous_demo.functional.dataset.dataset_manager import (detect_infer_dataset)
 
1
+ from anonymous_demo.functional.dataset.dataset_manager import detect_infer_dataset
anonymous_demo/functional/dataset/dataset_manager.py CHANGED
@@ -1,11 +1,24 @@
1
  import os
2
  from findfile import find_files, find_dir
3
 
4
- filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip',
5
- '.state_dict', '.model', '.png', 'acc_', 'f1_', '.backup', '.bak']
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
- def detect_infer_dataset(dataset_path, task='apc'):
9
  dataset_file = []
10
  if isinstance(dataset_path, str) and os.path.isfile(dataset_path):
11
  dataset_file.append(dataset_path)
@@ -13,9 +26,20 @@ def detect_infer_dataset(dataset_path, task='apc'):
13
 
14
  for d in dataset_path:
15
  if not os.path.exists(d):
16
- search_path = find_dir(os.getcwd(), [d, task, 'dataset'], exclude_key=filter_key_words, disable_alert=False)
17
- dataset_file += find_files(search_path, ['.inference', d], exclude_key=['train.'] + filter_key_words)
 
 
 
 
 
 
 
 
 
18
  else:
19
- dataset_file += find_files(d, ['.inference', task], exclude_key=['train.'] + filter_key_words)
 
 
20
 
21
  return dataset_file
 
1
  import os
2
  from findfile import find_files, find_dir
3
 
4
+ filter_key_words = [
5
+ ".py",
6
+ ".md",
7
+ "readme",
8
+ "log",
9
+ "result",
10
+ "zip",
11
+ ".state_dict",
12
+ ".model",
13
+ ".png",
14
+ "acc_",
15
+ "f1_",
16
+ ".backup",
17
+ ".bak",
18
+ ]
19
 
20
 
21
+ def detect_infer_dataset(dataset_path, task="apc"):
22
  dataset_file = []
23
  if isinstance(dataset_path, str) and os.path.isfile(dataset_path):
24
  dataset_file.append(dataset_path)
 
26
 
27
  for d in dataset_path:
28
  if not os.path.exists(d):
29
+ search_path = find_dir(
30
+ os.getcwd(),
31
+ [d, task, "dataset"],
32
+ exclude_key=filter_key_words,
33
+ disable_alert=False,
34
+ )
35
+ dataset_file += find_files(
36
+ search_path,
37
+ [".inference", d],
38
+ exclude_key=["train."] + filter_key_words,
39
+ )
40
  else:
41
+ dataset_file += find_files(
42
+ d, [".inference", task], exclude_key=["train."] + filter_key_words
43
+ )
44
 
45
  return dataset_file
anonymous_demo/network/lcf_pooler.py CHANGED
@@ -14,10 +14,12 @@ class LCF_Pooler(nn.Module):
14
  device = hidden_states.device
15
  lcf_vec = lcf_vec.detach().cpu().numpy()
16
 
17
- pooled_output = numpy.zeros((hidden_states.shape[0], hidden_states.shape[2]), dtype=numpy.float32)
 
 
18
  hidden_states = hidden_states.detach().cpu().numpy()
19
  for i, vec in enumerate(lcf_vec):
20
- lcf_ids = [j for j in range(len(vec)) if sum(vec[j] - 1.) == 0]
21
  pooled_output[i] = hidden_states[i][lcf_ids[len(lcf_ids) // 2]]
22
 
23
  pooled_output = torch.Tensor(pooled_output).to(device)
 
14
  device = hidden_states.device
15
  lcf_vec = lcf_vec.detach().cpu().numpy()
16
 
17
+ pooled_output = numpy.zeros(
18
+ (hidden_states.shape[0], hidden_states.shape[2]), dtype=numpy.float32
19
+ )
20
  hidden_states = hidden_states.detach().cpu().numpy()
21
  for i, vec in enumerate(lcf_vec):
22
+ lcf_ids = [j for j in range(len(vec)) if sum(vec[j] - 1.0) == 0]
23
  pooled_output[i] = hidden_states[i][lcf_ids[len(lcf_ids) // 2]]
24
 
25
  pooled_output = torch.Tensor(pooled_output).to(device)
anonymous_demo/network/lsa.py CHANGED
@@ -16,8 +16,17 @@ class LSA(nn.Module):
16
  self.eta1 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
17
  self.eta2 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
18
 
19
- def forward(self, global_context_features, spc_mask_vec, lcf_matrix, left_lcf_matrix, right_lcf_matrix):
20
- masked_global_context_features = torch.mul(spc_mask_vec, global_context_features)
 
 
 
 
 
 
 
 
 
21
 
22
  # # --------------------------------------------------- #
23
  lcf_features = torch.mul(global_context_features, lcf_matrix)
@@ -29,24 +38,36 @@ class LSA(nn.Module):
29
  right_lcf_features = torch.mul(masked_global_context_features, right_lcf_matrix)
30
  right_lcf_features = self.encoder_right(right_lcf_features)
31
  # # --------------------------------------------------- #
32
- if 'lr' == self.opt.window or 'rl' == self.opt.window:
33
  if self.eta1 <= 0 and self.opt.eta != -1:
34
  torch.nn.init.uniform_(self.eta1)
35
- print('reset eta1 to: {}'.format(self.eta1.item()))
36
  if self.eta2 <= 0 and self.opt.eta != -1:
37
  torch.nn.init.uniform_(self.eta2)
38
- print('reset eta2 to: {}'.format(self.eta2.item()))
39
  if self.opt.eta >= 0:
40
- cat_features = torch.cat((lcf_features, self.eta1 * left_lcf_features, self.eta2 * right_lcf_features),
41
- -1)
 
 
 
 
 
 
42
  else:
43
- cat_features = torch.cat((lcf_features, left_lcf_features, right_lcf_features), -1)
 
 
44
  sent_out = self.linear_window_3h(cat_features)
45
- elif 'l' == self.opt.window:
46
- sent_out = self.linear_window_2h(torch.cat((lcf_features, self.eta1 * left_lcf_features), -1))
47
- elif 'r' == self.opt.window:
48
- sent_out = self.linear_window_2h(torch.cat((lcf_features, self.eta2 * right_lcf_features), -1))
 
 
 
 
49
  else:
50
- raise KeyError('Invalid parameter:', self.opt.window)
51
 
52
  return sent_out
 
16
  self.eta1 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
17
  self.eta2 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
18
 
19
+ def forward(
20
+ self,
21
+ global_context_features,
22
+ spc_mask_vec,
23
+ lcf_matrix,
24
+ left_lcf_matrix,
25
+ right_lcf_matrix,
26
+ ):
27
+ masked_global_context_features = torch.mul(
28
+ spc_mask_vec, global_context_features
29
+ )
30
 
31
  # # --------------------------------------------------- #
32
  lcf_features = torch.mul(global_context_features, lcf_matrix)
 
38
  right_lcf_features = torch.mul(masked_global_context_features, right_lcf_matrix)
39
  right_lcf_features = self.encoder_right(right_lcf_features)
40
  # # --------------------------------------------------- #
41
+ if "lr" == self.opt.window or "rl" == self.opt.window:
42
  if self.eta1 <= 0 and self.opt.eta != -1:
43
  torch.nn.init.uniform_(self.eta1)
44
+ print("reset eta1 to: {}".format(self.eta1.item()))
45
  if self.eta2 <= 0 and self.opt.eta != -1:
46
  torch.nn.init.uniform_(self.eta2)
47
+ print("reset eta2 to: {}".format(self.eta2.item()))
48
  if self.opt.eta >= 0:
49
+ cat_features = torch.cat(
50
+ (
51
+ lcf_features,
52
+ self.eta1 * left_lcf_features,
53
+ self.eta2 * right_lcf_features,
54
+ ),
55
+ -1,
56
+ )
57
  else:
58
+ cat_features = torch.cat(
59
+ (lcf_features, left_lcf_features, right_lcf_features), -1
60
+ )
61
  sent_out = self.linear_window_3h(cat_features)
62
+ elif "l" == self.opt.window:
63
+ sent_out = self.linear_window_2h(
64
+ torch.cat((lcf_features, self.eta1 * left_lcf_features), -1)
65
+ )
66
+ elif "r" == self.opt.window:
67
+ sent_out = self.linear_window_2h(
68
+ torch.cat((lcf_features, self.eta2 * right_lcf_features), -1)
69
+ )
70
  else:
71
+ raise KeyError("Invalid parameter:", self.opt.window)
72
 
73
  return sent_out
anonymous_demo/network/sa_encoder.py CHANGED
@@ -8,7 +8,9 @@ import torch.nn as nn
8
  class BertSelfAttention(nn.Module):
9
  def __init__(self, config):
10
  super().__init__()
11
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
 
 
12
  raise ValueError(
13
  f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
14
  f"heads ({config.num_attention_heads})"
@@ -23,16 +25,29 @@ class BertSelfAttention(nn.Module):
23
  self.value = nn.Linear(config.hidden_size, self.all_head_size)
24
 
25
  self.dropout = nn.Dropout(
26
- config.attention_probs_dropout_prob if hasattr(config, 'attention_probs_dropout_prob') else 0)
27
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
28
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
 
 
 
 
 
 
 
 
29
  self.max_position_embeddings = config.max_position_embeddings
30
- self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
 
 
31
 
32
  self.is_decoder = config.is_decoder
33
 
34
  def transpose_for_scores(self, x):
35
- new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
 
 
 
36
  x = x.view(*new_x_shape)
37
  return x.permute(0, 2, 1, 3)
38
 
@@ -86,21 +101,42 @@ class BertSelfAttention(nn.Module):
86
  # Take the dot product between "query" and "key" to get the raw attention scores.
87
  attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
88
 
89
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
 
 
 
90
  seq_length = hidden_states.size()[1]
91
- position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
92
- position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
 
 
 
 
93
  distance = position_ids_l - position_ids_r
94
- positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
95
- positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
 
 
 
 
96
 
97
  if self.position_embedding_type == "relative_key":
98
- relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
 
 
99
  attention_scores = attention_scores + relative_position_scores
100
  elif self.position_embedding_type == "relative_key_query":
101
- relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
102
- relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
103
- attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
 
 
 
 
 
 
 
 
104
 
105
  attention_scores = attention_scores / math.sqrt(self.attention_head_size)
106
  if attention_mask is not None:
@@ -124,7 +160,9 @@ class BertSelfAttention(nn.Module):
124
  new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
125
  context_layer = context_layer.view(*new_context_layer_shape)
126
 
127
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
 
 
128
 
129
  if self.is_decoder:
130
  outputs = outputs + (past_key_value,)
@@ -136,7 +174,9 @@ class Encoder(nn.Module):
136
  super(Encoder, self).__init__()
137
  self.opt = opt
138
  self.config = config
139
- self.encoder = nn.ModuleList([SelfAttention(config, opt) for _ in range(layer_num)])
 
 
140
  self.tanh = torch.nn.Tanh()
141
 
142
  def forward(self, x):
 
8
  class BertSelfAttention(nn.Module):
9
  def __init__(self, config):
10
  super().__init__()
11
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
12
+ config, "embedding_size"
13
+ ):
14
  raise ValueError(
15
  f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
16
  f"heads ({config.num_attention_heads})"
 
25
  self.value = nn.Linear(config.hidden_size, self.all_head_size)
26
 
27
  self.dropout = nn.Dropout(
28
+ config.attention_probs_dropout_prob
29
+ if hasattr(config, "attention_probs_dropout_prob")
30
+ else 0
31
+ )
32
+ self.position_embedding_type = getattr(
33
+ config, "position_embedding_type", "absolute"
34
+ )
35
+ if (
36
+ self.position_embedding_type == "relative_key"
37
+ or self.position_embedding_type == "relative_key_query"
38
+ ):
39
  self.max_position_embeddings = config.max_position_embeddings
40
+ self.distance_embedding = nn.Embedding(
41
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
42
+ )
43
 
44
  self.is_decoder = config.is_decoder
45
 
46
  def transpose_for_scores(self, x):
47
+ new_x_shape = x.size()[:-1] + (
48
+ self.num_attention_heads,
49
+ self.attention_head_size,
50
+ )
51
  x = x.view(*new_x_shape)
52
  return x.permute(0, 2, 1, 3)
53
 
 
101
  # Take the dot product between "query" and "key" to get the raw attention scores.
102
  attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
103
 
104
+ if (
105
+ self.position_embedding_type == "relative_key"
106
+ or self.position_embedding_type == "relative_key_query"
107
+ ):
108
  seq_length = hidden_states.size()[1]
109
+ position_ids_l = torch.arange(
110
+ seq_length, dtype=torch.long, device=hidden_states.device
111
+ ).view(-1, 1)
112
+ position_ids_r = torch.arange(
113
+ seq_length, dtype=torch.long, device=hidden_states.device
114
+ ).view(1, -1)
115
  distance = position_ids_l - position_ids_r
116
+ positional_embedding = self.distance_embedding(
117
+ distance + self.max_position_embeddings - 1
118
+ )
119
+ positional_embedding = positional_embedding.to(
120
+ dtype=query_layer.dtype
121
+ ) # fp16 compatibility
122
 
123
  if self.position_embedding_type == "relative_key":
124
+ relative_position_scores = torch.einsum(
125
+ "bhld,lrd->bhlr", query_layer, positional_embedding
126
+ )
127
  attention_scores = attention_scores + relative_position_scores
128
  elif self.position_embedding_type == "relative_key_query":
129
+ relative_position_scores_query = torch.einsum(
130
+ "bhld,lrd->bhlr", query_layer, positional_embedding
131
+ )
132
+ relative_position_scores_key = torch.einsum(
133
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
134
+ )
135
+ attention_scores = (
136
+ attention_scores
137
+ + relative_position_scores_query
138
+ + relative_position_scores_key
139
+ )
140
 
141
  attention_scores = attention_scores / math.sqrt(self.attention_head_size)
142
  if attention_mask is not None:
 
160
  new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
161
  context_layer = context_layer.view(*new_context_layer_shape)
162
 
163
+ outputs = (
164
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
165
+ )
166
 
167
  if self.is_decoder:
168
  outputs = outputs + (past_key_value,)
 
174
  super(Encoder, self).__init__()
175
  self.opt = opt
176
  self.config = config
177
+ self.encoder = nn.ModuleList(
178
+ [SelfAttention(config, opt) for _ in range(layer_num)]
179
+ )
180
  self.tanh = torch.nn.Tanh()
181
 
182
  def forward(self, x):
anonymous_demo/utils/demo_utils.py CHANGED
@@ -22,10 +22,10 @@ from anonymous_demo import __version__
22
 
23
 
24
  def save_args(config, save_path):
25
- f = open(os.path.join(save_path), mode='w', encoding='utf8')
26
  for arg in config.args:
27
  if config.args_call_count[arg]:
28
- f.write('{}: {}\n'.format(arg, config.args[arg]))
29
  f.close()
30
 
31
 
@@ -33,20 +33,39 @@ def print_args(config, logger=None, mode=0):
33
  args = [key for key in sorted(config.args.keys())]
34
  for arg in args:
35
  if logger:
36
- logger.info('{0}:{1}\t-->\tCalling Count:{2}'.format(arg, config.args[arg], config.args_call_count[arg]))
 
 
 
 
37
  else:
38
- print('{0}:{1}\t-->\tCalling Count:{2}'.format(arg, config.args[arg], config.args_call_count[arg]))
 
 
 
 
39
 
40
 
41
  def check_and_fix_labels(label_set: set, label_name, all_data, opt):
42
- if '-100' in label_set:
43
-
44
- label_to_index = {origin_label: int(idx) - 1 if origin_label != '-100' else -100 for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
45
- index_to_label = {int(idx) - 1 if origin_label != '-100' else -100: origin_label for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
 
 
 
 
 
46
  else:
47
- label_to_index = {origin_label: int(idx) for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
48
- index_to_label = {int(idx): origin_label for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
49
- if 'index_to_label' not in opt.args:
 
 
 
 
 
 
50
  opt.index_to_label = index_to_label
51
  opt.label_to_index = label_to_index
52
 
@@ -54,7 +73,7 @@ def check_and_fix_labels(label_set: set, label_name, all_data, opt):
54
  opt.index_to_label.update(index_to_label)
55
  opt.label_to_index.update(label_to_index)
56
  num_label = {l: 0 for l in label_set}
57
- num_label['Sum'] = len(all_data)
58
  for item in all_data:
59
  try:
60
  num_label[item[label_name]] += 1
@@ -63,75 +82,91 @@ def check_and_fix_labels(label_set: set, label_name, all_data, opt):
63
  # print(e)
64
  num_label[item.polarity] += 1
65
  item.polarity = label_to_index[item.polarity]
66
- print('Dataset Label Details: {}'.format(num_label))
67
 
68
 
69
  def check_and_fix_IOB_labels(label_map, opt):
70
- index_to_IOB_label = {int(label_map[origin_label]): origin_label for origin_label in label_map}
 
 
71
  opt.index_to_IOB_label = index_to_IOB_label
72
 
73
 
74
  def get_device(auto_device):
75
- if isinstance(auto_device, str) and auto_device == 'allcuda':
76
- device = 'cuda'
77
  elif isinstance(auto_device, str):
78
  device = auto_device
79
  elif isinstance(auto_device, bool):
80
- device = auto_cuda() if auto_device else 'cpu'
81
  else:
82
  device = auto_cuda()
83
  try:
84
  torch.device(device)
85
  except RuntimeError as e:
86
- print(colored('Device assignment error: {}, redirect to CPU'.format(e), 'red'))
87
- device = 'cpu'
 
 
88
  device_name = auto_cuda_name()
89
  return device, device_name
90
 
91
 
92
  def _load_word_vec(path, word2idx=None, embed_dim=300):
93
- fin = open(path, 'r', encoding='utf-8', newline='\n', errors='ignore')
94
  word_vec = {}
95
- for line in tqdm.tqdm(fin.readlines(), postfix='Loading embedding file...'):
96
  tokens = line.rstrip().split()
97
- word, vec = ' '.join(tokens[:-embed_dim]), tokens[-embed_dim:]
98
  if word in word2idx.keys():
99
- word_vec[word] = np.asarray(vec, dtype='float32')
100
  return word_vec
101
 
102
 
103
  def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
104
- if not os.path.exists('run'):
105
- os.makedirs('run')
106
- embed_matrix_path = 'run/{}'.format(os.path.join(opt.dataset_name, dat_fname))
107
  if os.path.exists(embed_matrix_path):
108
- print(colored('Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)'.format(embed_matrix_path), 'green'))
109
- embedding_matrix = pickle.load(open(embed_matrix_path, 'rb'))
 
 
 
 
 
 
 
110
  else:
111
  glove_path = prepare_glove840_embedding(embed_matrix_path)
112
  embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
113
 
114
  word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)
115
 
116
- for word, i in tqdm.tqdm(word2idx.items(), postfix=colored('Building embedding_matrix {}'.format(dat_fname), 'yellow')):
 
 
 
117
  vec = word_vec.get(word)
118
  if vec is not None:
119
  embedding_matrix[i] = vec
120
- pickle.dump(embedding_matrix, open(embed_matrix_path, 'wb'))
121
  return embedding_matrix
122
 
123
 
124
- def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
 
 
125
  x = (np.ones(maxlen) * value).astype(dtype)
126
- if truncating == 'pre':
127
  trunc = sequence[-maxlen:]
128
  else:
129
  trunc = sequence[:maxlen]
130
  trunc = np.asarray(trunc, dtype=dtype)
131
- if padding == 'post':
132
- x[:len(trunc)] = trunc
133
  else:
134
- x[-len(trunc):] = trunc
135
  return x
136
 
137
 
@@ -145,7 +180,6 @@ def retry(f):
145
  def decorated(*args, **kwargs):
146
  count = 5
147
  while count:
148
-
149
  try:
150
  return f(*args, **kwargs)
151
  except (
@@ -158,7 +192,7 @@ def retry(f):
158
  requests.exceptions.SSLError,
159
  requests.exceptions.BaseHTTPError,
160
  ) as e:
161
- print(colored('Training Exception: {}, will retry later'.format(e)))
162
  time.sleep(60)
163
  count -= 1
164
 
@@ -168,14 +202,14 @@ def retry(f):
168
  def save_json(dic, save_path):
169
  if isinstance(dic, str):
170
  dic = eval(dic)
171
- with open(save_path, 'w', encoding='utf-8') as f:
172
  # f.write(str(dict))
173
  str_ = json.dumps(dic, ensure_ascii=False)
174
  f.write(str_)
175
 
176
 
177
  def load_json(save_path):
178
- with open(save_path, 'r', encoding='utf-8') as f:
179
  data = f.readline().strip()
180
  print(type(data), data)
181
  dic = json.loads(data)
@@ -184,14 +218,14 @@ def load_json(save_path):
184
 
185
  def init_optimizer(optimizer):
186
  optimizers = {
187
- 'adadelta': torch.optim.Adadelta, # default lr=1.0
188
- 'adagrad': torch.optim.Adagrad, # default lr=0.01
189
- 'adam': torch.optim.Adam, # default lr=0.001
190
- 'adamax': torch.optim.Adamax, # default lr=0.002
191
- 'asgd': torch.optim.ASGD, # default lr=0.01
192
- 'rmsprop': torch.optim.RMSprop, # default lr=0.01
193
- 'sgd': torch.optim.SGD,
194
- 'adamw': torch.optim.AdamW,
195
  torch.optim.Adadelta: torch.optim.Adadelta, # default lr=1.0
196
  torch.optim.Adagrad: torch.optim.Adagrad, # default lr=0.01
197
  torch.optim.Adam: torch.optim.Adam, # default lr=0.001
@@ -206,4 +240,8 @@ def init_optimizer(optimizer):
206
  elif hasattr(torch.optim, optimizer.__name__):
207
  return optimizer
208
  else:
209
- raise KeyError('Unsupported optimizer: {}. Please use string or the optimizer objects in torch.optim as your optimizer'.format(optimizer))
 
 
 
 
 
22
 
23
 
24
  def save_args(config, save_path):
25
+ f = open(os.path.join(save_path), mode="w", encoding="utf8")
26
  for arg in config.args:
27
  if config.args_call_count[arg]:
28
+ f.write("{}: {}\n".format(arg, config.args[arg]))
29
  f.close()
30
 
31
 
 
33
  args = [key for key in sorted(config.args.keys())]
34
  for arg in args:
35
  if logger:
36
+ logger.info(
37
+ "{0}:{1}\t-->\tCalling Count:{2}".format(
38
+ arg, config.args[arg], config.args_call_count[arg]
39
+ )
40
+ )
41
  else:
42
+ print(
43
+ "{0}:{1}\t-->\tCalling Count:{2}".format(
44
+ arg, config.args[arg], config.args_call_count[arg]
45
+ )
46
+ )
47
 
48
 
49
  def check_and_fix_labels(label_set: set, label_name, all_data, opt):
50
+ if "-100" in label_set:
51
+ label_to_index = {
52
+ origin_label: int(idx) - 1 if origin_label != "-100" else -100
53
+ for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
54
+ }
55
+ index_to_label = {
56
+ int(idx) - 1 if origin_label != "-100" else -100: origin_label
57
+ for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
58
+ }
59
  else:
60
+ label_to_index = {
61
+ origin_label: int(idx)
62
+ for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
63
+ }
64
+ index_to_label = {
65
+ int(idx): origin_label
66
+ for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
67
+ }
68
+ if "index_to_label" not in opt.args:
69
  opt.index_to_label = index_to_label
70
  opt.label_to_index = label_to_index
71
 
 
73
  opt.index_to_label.update(index_to_label)
74
  opt.label_to_index.update(label_to_index)
75
  num_label = {l: 0 for l in label_set}
76
+ num_label["Sum"] = len(all_data)
77
  for item in all_data:
78
  try:
79
  num_label[item[label_name]] += 1
 
82
  # print(e)
83
  num_label[item.polarity] += 1
84
  item.polarity = label_to_index[item.polarity]
85
+ print("Dataset Label Details: {}".format(num_label))
86
 
87
 
88
  def check_and_fix_IOB_labels(label_map, opt):
89
+ index_to_IOB_label = {
90
+ int(label_map[origin_label]): origin_label for origin_label in label_map
91
+ }
92
  opt.index_to_IOB_label = index_to_IOB_label
93
 
94
 
95
  def get_device(auto_device):
96
+ if isinstance(auto_device, str) and auto_device == "allcuda":
97
+ device = "cuda"
98
  elif isinstance(auto_device, str):
99
  device = auto_device
100
  elif isinstance(auto_device, bool):
101
+ device = auto_cuda() if auto_device else "cpu"
102
  else:
103
  device = auto_cuda()
104
  try:
105
  torch.device(device)
106
  except RuntimeError as e:
107
+ print(
108
+ colored("Device assignment error: {}, redirect to CPU".format(e), "red")
109
+ )
110
+ device = "cpu"
111
  device_name = auto_cuda_name()
112
  return device, device_name
113
 
114
 
115
  def _load_word_vec(path, word2idx=None, embed_dim=300):
116
+ fin = open(path, "r", encoding="utf-8", newline="\n", errors="ignore")
117
  word_vec = {}
118
+ for line in tqdm.tqdm(fin.readlines(), postfix="Loading embedding file..."):
119
  tokens = line.rstrip().split()
120
+ word, vec = " ".join(tokens[:-embed_dim]), tokens[-embed_dim:]
121
  if word in word2idx.keys():
122
+ word_vec[word] = np.asarray(vec, dtype="float32")
123
  return word_vec
124
 
125
 
126
  def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
127
+ if not os.path.exists("run"):
128
+ os.makedirs("run")
129
+ embed_matrix_path = "run/{}".format(os.path.join(opt.dataset_name, dat_fname))
130
  if os.path.exists(embed_matrix_path):
131
+ print(
132
+ colored(
133
+ "Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)".format(
134
+ embed_matrix_path
135
+ ),
136
+ "green",
137
+ )
138
+ )
139
+ embedding_matrix = pickle.load(open(embed_matrix_path, "rb"))
140
  else:
141
  glove_path = prepare_glove840_embedding(embed_matrix_path)
142
  embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
143
 
144
  word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)
145
 
146
+ for word, i in tqdm.tqdm(
147
+ word2idx.items(),
148
+ postfix=colored("Building embedding_matrix {}".format(dat_fname), "yellow"),
149
+ ):
150
  vec = word_vec.get(word)
151
  if vec is not None:
152
  embedding_matrix[i] = vec
153
+ pickle.dump(embedding_matrix, open(embed_matrix_path, "wb"))
154
  return embedding_matrix
155
 
156
 
157
+ def pad_and_truncate(
158
+ sequence, maxlen, dtype="int64", padding="post", truncating="post", value=0
159
+ ):
160
  x = (np.ones(maxlen) * value).astype(dtype)
161
+ if truncating == "pre":
162
  trunc = sequence[-maxlen:]
163
  else:
164
  trunc = sequence[:maxlen]
165
  trunc = np.asarray(trunc, dtype=dtype)
166
+ if padding == "post":
167
+ x[: len(trunc)] = trunc
168
  else:
169
+ x[-len(trunc) :] = trunc
170
  return x
171
 
172
 
 
180
  def decorated(*args, **kwargs):
181
  count = 5
182
  while count:
 
183
  try:
184
  return f(*args, **kwargs)
185
  except (
 
192
  requests.exceptions.SSLError,
193
  requests.exceptions.BaseHTTPError,
194
  ) as e:
195
+ print(colored("Training Exception: {}, will retry later".format(e)))
196
  time.sleep(60)
197
  count -= 1
198
 
 
202
  def save_json(dic, save_path):
203
  if isinstance(dic, str):
204
  dic = eval(dic)
205
+ with open(save_path, "w", encoding="utf-8") as f:
206
  # f.write(str(dict))
207
  str_ = json.dumps(dic, ensure_ascii=False)
208
  f.write(str_)
209
 
210
 
211
  def load_json(save_path):
212
+ with open(save_path, "r", encoding="utf-8") as f:
213
  data = f.readline().strip()
214
  print(type(data), data)
215
  dic = json.loads(data)
 
218
 
219
  def init_optimizer(optimizer):
220
  optimizers = {
221
+ "adadelta": torch.optim.Adadelta, # default lr=1.0
222
+ "adagrad": torch.optim.Adagrad, # default lr=0.01
223
+ "adam": torch.optim.Adam, # default lr=0.001
224
+ "adamax": torch.optim.Adamax, # default lr=0.002
225
+ "asgd": torch.optim.ASGD, # default lr=0.01
226
+ "rmsprop": torch.optim.RMSprop, # default lr=0.01
227
+ "sgd": torch.optim.SGD,
228
+ "adamw": torch.optim.AdamW,
229
  torch.optim.Adadelta: torch.optim.Adadelta, # default lr=1.0
230
  torch.optim.Adagrad: torch.optim.Adagrad, # default lr=0.01
231
  torch.optim.Adam: torch.optim.Adam, # default lr=0.001
 
240
  elif hasattr(torch.optim, optimizer.__name__):
241
  return optimizer
242
  else:
243
+ raise KeyError(
244
+ "Unsupported optimizer: {}. Please use string or the optimizer objects in torch.optim as your optimizer".format(
245
+ optimizer
246
+ )
247
+ )
anonymous_demo/utils/logger.py CHANGED
@@ -5,22 +5,22 @@ import time
5
 
6
  import termcolor
7
 
8
- today = time.strftime('%Y%m%d %H%M%S', time.localtime(time.time()))
9
 
10
 
11
- def get_logger(log_path, log_name='', log_type='training_log'):
12
  if not log_path:
13
  log_dir = os.path.join(log_path, "logs")
14
  else:
15
- log_dir = os.path.join('.', "logs")
16
 
17
- full_path = os.path.join(log_dir, log_name + '_' + today)
18
  if not os.path.exists(full_path):
19
  os.makedirs(full_path)
20
  log_path = os.path.join(full_path, "{}.log".format(log_type))
21
  logger = logging.getLogger(log_name)
22
  if not logger.handlers:
23
- formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
24
 
25
  file_handler = logging.FileHandler(log_path, encoding="utf8")
26
  file_handler.setFormatter(formatter)
 
5
 
6
  import termcolor
7
 
8
+ today = time.strftime("%Y%m%d %H%M%S", time.localtime(time.time()))
9
 
10
 
11
+ def get_logger(log_path, log_name="", log_type="training_log"):
12
  if not log_path:
13
  log_dir = os.path.join(log_path, "logs")
14
  else:
15
+ log_dir = os.path.join(".", "logs")
16
 
17
+ full_path = os.path.join(log_dir, log_name + "_" + today)
18
  if not os.path.exists(full_path):
19
  os.makedirs(full_path)
20
  log_path = os.path.join(full_path, "{}.log".format(log_type))
21
  logger = logging.getLogger(log_name)
22
  if not logger.handlers:
23
+ formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")
24
 
25
  file_handler = logging.FileHandler(log_path, encoding="utf8")
26
  file_handler.setFormatter(formatter)
app.py CHANGED
@@ -18,6 +18,7 @@ from textattack.attack_recipes import (
18
  IGAWang2019,
19
  GeneticAlgorithmAlzantot2018,
20
  DeepWordBugGao2018,
 
21
  )
22
  from textattack.attack_results import SuccessfulAttackResult
23
  from textattack.datasets import Dataset
@@ -88,9 +89,10 @@ attack_recipes = {
88
  "iga": IGAWang2019,
89
  "GA": GeneticAlgorithmAlzantot2018,
90
  "wordbugger": DeepWordBugGao2018,
 
91
  }
92
 
93
- for attacker in ["pwws", "bae", "textfooler"]:
94
  for dataset in [
95
  "agnews10k",
96
  "amazon",
@@ -389,7 +391,10 @@ with demo:
389
  "- To our best knowledge, Reactive Perturbation Defocusing is a novel approach in adversarial defense "
390
  ". RPD significantly (>10% defense accuracy improvement) outperforms the state-of-the-art methods."
391
  )
392
-
 
 
 
393
 
394
  gr.Markdown("## <p align='center'>Natural Example Input</p>")
395
  with gr.Group():
@@ -400,7 +405,14 @@ with demo:
400
  label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
401
  )
402
  input_attacker = gr.Radio(
403
- choices=["BAE", "PWWS", "TextFooler"],
 
 
 
 
 
 
 
404
  value="TextFooler",
405
  label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
406
  )
@@ -414,7 +426,6 @@ with demo:
414
  placeholder="Original label...", label="Original Label"
415
  )
416
 
417
-
418
  button_gen = gr.Button(
419
  "Generate an adversarial example and repair using RPD (No GPU, Time:3-10 mins )",
420
  variant="primary",
@@ -432,11 +443,14 @@ with demo:
432
  output_adv_example = gr.Textbox(label="Adversarial Example")
433
  output_adv_label = gr.Textbox(label="Perturbed Label")
434
  with gr.Row():
435
- output_repaired_example = gr.Textbox(label="Repaired Adversarial Example by RPD")
 
 
436
  output_repaired_label = gr.Textbox(label="Repaired Label")
437
 
438
-
439
- gr.Markdown("## <p align='center'>The Output of Reactive Perturbation Defocusing</p>")
 
440
  with gr.Group():
441
  output_is_adv_df = gr.DataFrame(label="Adversarial Example Detection Result")
442
  gr.Markdown(
@@ -444,9 +458,7 @@ with demo:
444
  "The perturbed_label is the predicted label of the adversarial example. "
445
  "The confidence field represents the confidence of the predicted adversarial example detection. "
446
  )
447
- output_df = gr.DataFrame(
448
- label="Repaired Standard Classification Result"
449
- )
450
  gr.Markdown(
451
  "If is_repaired=true, it has been repaired by RPD. "
452
  "The pred_label field indicates the standard classification result. "
@@ -454,20 +466,19 @@ with demo:
454
  "The is_correct field indicates whether the predicted label is correct."
455
  )
456
 
457
-
458
  gr.Markdown("## <p align='center'>Example Comparisons</p>")
459
  ori_text_diff = gr.HighlightedText(
460
- label="The Original Natural Example",
461
- combine_adjacent=True,
462
- )
463
  adv_text_diff = gr.HighlightedText(
464
- label="Character Editions of Adversarial Example Compared to the Natural Example",
465
- combine_adjacent=True,
466
- )
467
  restored_text_diff = gr.HighlightedText(
468
- label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
469
- combine_adjacent=True,
470
- )
471
 
472
  # Bind functions to buttons
473
  button_gen.click(
 
18
  IGAWang2019,
19
  GeneticAlgorithmAlzantot2018,
20
  DeepWordBugGao2018,
21
+ CLARE2020,
22
  )
23
  from textattack.attack_results import SuccessfulAttackResult
24
  from textattack.datasets import Dataset
 
89
  "iga": IGAWang2019,
90
  "GA": GeneticAlgorithmAlzantot2018,
91
  "wordbugger": DeepWordBugGao2018,
92
+ 'clare': CLARE2020,
93
  }
94
 
95
+ for attacker in ["pwws", "bae", "textfooler", "pso", "wordbugger", 'clare']:
96
  for dataset in [
97
  "agnews10k",
98
  "amazon",
 
391
  "- To our best knowledge, Reactive Perturbation Defocusing is a novel approach in adversarial defense "
392
  ". RPD significantly (>10% defense accuracy improvement) outperforms the state-of-the-art methods."
393
  )
394
+ gr.Markdown(
395
+ "- The DeepWordBug, IGA, GA, PSO, and CLARE attackers are very slow on CPU Devices."
396
+ " And they are unknown attackers to RPD's adversarial detector. "
397
+ )
398
 
399
  gr.Markdown("## <p align='center'>Natural Example Input</p>")
400
  with gr.Group():
 
405
  label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
406
  )
407
  input_attacker = gr.Radio(
408
+ choices=[
409
+ "BAE",
410
+ "PWWS",
411
+ "TextFooler",
412
+ "WordBugger",
413
+ "PSO",
414
+ "CLARE",
415
+ ],
416
  value="TextFooler",
417
  label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
418
  )
 
426
  placeholder="Original label...", label="Original Label"
427
  )
428
 
 
429
  button_gen = gr.Button(
430
  "Generate an adversarial example and repair using RPD (No GPU, Time:3-10 mins )",
431
  variant="primary",
 
443
  output_adv_example = gr.Textbox(label="Adversarial Example")
444
  output_adv_label = gr.Textbox(label="Perturbed Label")
445
  with gr.Row():
446
+ output_repaired_example = gr.Textbox(
447
+ label="Repaired Adversarial Example by RPD"
448
+ )
449
  output_repaired_label = gr.Textbox(label="Repaired Label")
450
 
451
+ gr.Markdown(
452
+ "## <p align='center'>The Output of Reactive Perturbation Defocusing</p>"
453
+ )
454
  with gr.Group():
455
  output_is_adv_df = gr.DataFrame(label="Adversarial Example Detection Result")
456
  gr.Markdown(
 
458
  "The perturbed_label is the predicted label of the adversarial example. "
459
  "The confidence field represents the confidence of the predicted adversarial example detection. "
460
  )
461
+ output_df = gr.DataFrame(label="Repaired Standard Classification Result")
 
 
462
  gr.Markdown(
463
  "If is_repaired=true, it has been repaired by RPD. "
464
  "The pred_label field indicates the standard classification result. "
 
466
  "The is_correct field indicates whether the predicted label is correct."
467
  )
468
 
 
469
  gr.Markdown("## <p align='center'>Example Comparisons</p>")
470
  ori_text_diff = gr.HighlightedText(
471
+ label="The Original Natural Example",
472
+ combine_adjacent=True,
473
+ )
474
  adv_text_diff = gr.HighlightedText(
475
+ label="Character Editions of Adversarial Example Compared to the Natural Example",
476
+ combine_adjacent=True,
477
+ )
478
  restored_text_diff = gr.HighlightedText(
479
+ label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
480
+ combine_adjacent=True,
481
+ )
482
 
483
  # Bind functions to buttons
484
  button_gen.click(
requirements.txt CHANGED
@@ -16,4 +16,4 @@ transformers>4.20.0
16
  torch>1.0.0
17
  sentencepiece
18
  tensorflow_text
19
- textattack
 
16
  torch>1.0.0
17
  sentencepiece
18
  tensorflow_text
19
+ textattack[tensorflow]
textattack/attack_recipes/morpheus_tan_2020.py CHANGED
@@ -27,7 +27,6 @@ class MorpheusTan2020(AttackRecipe):
27
 
28
  @staticmethod
29
  def build(model_wrapper):
30
-
31
  #
32
  # Goal is to minimize BLEU score between the model output given for the
33
  # perturbed input sequence and the reference translation
 
27
 
28
  @staticmethod
29
  def build(model_wrapper):
 
30
  #
31
  # Goal is to minimize BLEU score between the model output given for the
32
  # perturbed input sequence and the reference translation
textattack/attack_recipes/seq2sick_cheng_2018_blackbox.py CHANGED
@@ -31,7 +31,6 @@ class Seq2SickCheng2018BlackBox(AttackRecipe):
31
 
32
  @staticmethod
33
  def build(model_wrapper, goal_function="non_overlapping"):
34
-
35
  #
36
  # Goal is non-overlapping output.
37
  #
 
31
 
32
  @staticmethod
33
  def build(model_wrapper, goal_function="non_overlapping"):
 
34
  #
35
  # Goal is non-overlapping output.
36
  #
textattack/attacker.py CHANGED
@@ -105,8 +105,8 @@ class Attacker:
105
  def simple_attack(self, text, label):
106
  """Internal method that carries out attack.
107
 
108
- No parallel processing is involved.
109
- """
110
  if torch.cuda.is_available():
111
  self.attack.cuda_()
112
 
@@ -120,9 +120,11 @@ class Attacker:
120
  except Exception as e:
121
  raise e
122
  # return
123
- if (isinstance(result, SkippedAttackResult) and self.attack_args.attack_n) or (
124
- not isinstance(result, SuccessfulAttackResult)
125
- and self.attack_args.num_successful_examples
 
 
126
  ):
127
  return
128
  else:
 
105
  def simple_attack(self, text, label):
106
  """Internal method that carries out attack.
107
 
108
+ No parallel processing is involved.
109
+ """
110
  if torch.cuda.is_available():
111
  self.attack.cuda_()
112
 
 
120
  except Exception as e:
121
  raise e
122
  # return
123
+ if (
124
+ isinstance(result, SkippedAttackResult) and self.attack_args.attack_n
125
+ ) or (
126
+ not isinstance(result, SuccessfulAttackResult)
127
+ and self.attack_args.num_successful_examples
128
  ):
129
  return
130
  else:
textattack/commands/augment_command.py CHANGED
@@ -32,7 +32,6 @@ class AugmentCommand(TextAttackCommand):
32
 
33
  args = textattack.AugmenterArgs(**vars(args))
34
  if args.interactive:
35
-
36
  print("\nRunning in interactive mode...\n")
37
  augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
38
  pct_words_to_swap=args.pct_words_to_swap,
 
32
 
33
  args = textattack.AugmenterArgs(**vars(args))
34
  if args.interactive:
 
35
  print("\nRunning in interactive mode...\n")
36
  augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
37
  pct_words_to_swap=args.pct_words_to_swap,
textattack/commands/eval_model_command.py CHANGED
@@ -56,7 +56,7 @@ class EvalModelCommand(TextAttackCommand):
56
  while i < min(args.num_examples, len(dataset)):
57
  dataset_batch = dataset[i : min(args.num_examples, i + args.batch_size)]
58
  batch_inputs = []
59
- for (text_input, ground_truth_output) in dataset_batch:
60
  attacked_text = textattack.shared.AttackedText(text_input)
61
  batch_inputs.append(attacked_text.tokenizer_input)
62
  ground_truth_outputs.append(ground_truth_output)
 
56
  while i < min(args.num_examples, len(dataset)):
57
  dataset_batch = dataset[i : min(args.num_examples, i + args.batch_size)]
58
  batch_inputs = []
59
+ for text_input, ground_truth_output in dataset_batch:
60
  attacked_text = textattack.shared.AttackedText(text_input)
61
  batch_inputs.append(attacked_text.tokenizer_input)
62
  ground_truth_outputs.append(ground_truth_output)
textattack/constraints/overlap/max_words_perturbed.py CHANGED
@@ -38,7 +38,6 @@ class MaxWordsPerturbed(Constraint):
38
  self.max_percent = max_percent
39
 
40
  def _check_constraint(self, transformed_text, reference_text):
41
-
42
  num_words_diff = len(transformed_text.all_words_diff(reference_text))
43
  if self.max_percent:
44
  min_num_words = min(len(transformed_text.words), len(reference_text.words))
 
38
  self.max_percent = max_percent
39
 
40
  def _check_constraint(self, transformed_text, reference_text):
 
41
  num_words_diff = len(transformed_text.all_words_diff(reference_text))
42
  if self.max_percent:
43
  min_num_words = min(len(transformed_text.words), len(reference_text.words))
textattack/goal_function_results/classification_goal_function_result.py CHANGED
@@ -26,7 +26,6 @@ class ClassificationGoalFunctionResult(GoalFunctionResult):
26
  num_queries,
27
  ground_truth_output,
28
  ):
29
-
30
  super().__init__(
31
  attacked_text,
32
  raw_output,
 
26
  num_queries,
27
  ground_truth_output,
28
  ):
 
29
  super().__init__(
30
  attacked_text,
31
  raw_output,
textattack/goal_function_results/text_to_text_goal_function_result.py CHANGED
@@ -23,7 +23,6 @@ class TextToTextGoalFunctionResult(GoalFunctionResult):
23
  num_queries,
24
  ground_truth_output,
25
  ):
26
-
27
  super().__init__(
28
  attacked_text,
29
  raw_output,
 
23
  num_queries,
24
  ground_truth_output,
25
  ):
 
26
  super().__init__(
27
  attacked_text,
28
  raw_output,
textattack/loggers/weights_and_biases_logger.py CHANGED
@@ -13,7 +13,6 @@ class WeightsAndBiasesLogger(Logger):
13
  """Logs attack results to Weights & Biases."""
14
 
15
  def __init__(self, **kwargs):
16
-
17
  global wandb
18
  wandb = LazyLoader("wandb", globals(), "wandb")
19
 
 
13
  """Logs attack results to Weights & Biases."""
14
 
15
  def __init__(self, **kwargs):
 
16
  global wandb
17
  wandb = LazyLoader("wandb", globals(), "wandb")
18
 
textattack/metrics/quality_metrics/perplexity.py CHANGED
@@ -94,7 +94,6 @@ class Perplexity(Metric):
94
  return self.all_metrics
95
 
96
  def calc_ppl(self, texts):
97
-
98
  with torch.no_grad():
99
  text = " ".join(texts)
100
  eval_loss = []
 
94
  return self.all_metrics
95
 
96
  def calc_ppl(self, texts):
 
97
  with torch.no_grad():
98
  text = " ".join(texts)
99
  eval_loss = []
textattack/models/wrappers/demo_model_wrapper.py CHANGED
@@ -2,14 +2,14 @@ from textattack.models.wrappers import HuggingFaceModelWrapper
2
 
3
 
4
  class TADModelWrapper(HuggingFaceModelWrapper):
5
- """ Transformers sentiment analysis pipeline returns a list of responses
6
- like
7
 
8
- [{'label': 'POSITIVE', 'score': 0.7817379832267761}]
9
 
10
- We need to convert that to a format TextAttack understands, like
11
 
12
- [[0.218262017, 0.7817379832267761]
13
  """
14
 
15
  def __init__(self, model):
@@ -19,6 +19,6 @@ class TADModelWrapper(HuggingFaceModelWrapper):
19
  outputs = []
20
  for text_input in text_inputs:
21
  raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
22
- outputs.append(raw_outputs['probs'])
23
 
24
  return outputs
 
2
 
3
 
4
  class TADModelWrapper(HuggingFaceModelWrapper):
5
+ """Transformers sentiment analysis pipeline returns a list of responses
6
+ like
7
 
8
+ [{'label': 'POSITIVE', 'score': 0.7817379832267761}]
9
 
10
+ We need to convert that to a format TextAttack understands, like
11
 
12
+ [[0.218262017, 0.7817379832267761]
13
  """
14
 
15
  def __init__(self, model):
 
19
  outputs = []
20
  for text_input in text_inputs:
21
  raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
22
+ outputs.append(raw_outputs["probs"])
23
 
24
  return outputs
textattack/reactive_defense/reactive_defender.py CHANGED
@@ -4,7 +4,6 @@ from textattack.shared.utils import ReprMixin
4
 
5
 
6
  class ReactiveDefender(ReprMixin, ABC):
7
-
8
  def __init__(self, **kwargs):
9
  pass
10
 
 
4
 
5
 
6
  class ReactiveDefender(ReprMixin, ABC):
 
7
  def __init__(self, **kwargs):
8
  pass
9
 
textattack/reactive_defense/tad_reactive_defender.py CHANGED
@@ -5,21 +5,24 @@ from textattack.reactive_defense.reactive_defender import ReactiveDefender
5
 
6
 
7
  class TADReactiveDefender(ReactiveDefender):
8
- """ Transformers sentiment analysis pipeline returns a list of responses
9
- like
10
 
11
- [{'label': 'POSITIVE', 'score': 0.7817379832267761}]
12
 
13
- We need to convert that to a format TextAttack understands, like
14
 
15
- [[0.218262017, 0.7817379832267761]
16
  """
17
 
18
- def __init__(self, ckpt='tad-sst2', **kwargs):
19
  super().__init__(**kwargs)
20
- self.tad_classifier = TADCheckpointManager.get_tad_text_classifier(checkpoint=DEMO_MODELS[ckpt],
21
- auto_device=True)
 
22
 
23
  def reactive_defense(self, text, **kwargs):
24
- res = self.tad_classifier.infer(text, defense='pwws', print_result=False, **kwargs)
 
 
25
  return res
 
5
 
6
 
7
  class TADReactiveDefender(ReactiveDefender):
8
+ """Transformers sentiment analysis pipeline returns a list of responses
9
+ like
10
 
11
+ [{'label': 'POSITIVE', 'score': 0.7817379832267761}]
12
 
13
+ We need to convert that to a format TextAttack understands, like
14
 
15
+ [[0.218262017, 0.7817379832267761]
16
  """
17
 
18
+ def __init__(self, ckpt="tad-sst2", **kwargs):
19
  super().__init__(**kwargs)
20
+ self.tad_classifier = TADCheckpointManager.get_tad_text_classifier(
21
+ checkpoint=DEMO_MODELS[ckpt], auto_device=True
22
+ )
23
 
24
  def reactive_defense(self, text, **kwargs):
25
+ res = self.tad_classifier.infer(
26
+ text, defense="pwws", print_result=False, **kwargs
27
+ )
28
  return res
textattack/search_methods/greedy_word_swap_wir.py CHANGED
@@ -65,7 +65,6 @@ class GreedyWordSwapWIR(SearchMethod):
65
  # compute the largest change in score we can find by swapping each word
66
  delta_ps = []
67
  for idx in indices_to_order:
68
-
69
  # Exit Loop when search_over is True - but we need to make sure delta_ps
70
  # is the same size as softmax_saliency_scores
71
  if search_over:
 
65
  # compute the largest change in score we can find by swapping each word
66
  delta_ps = []
67
  for idx in indices_to_order:
 
68
  # Exit Loop when search_over is True - but we need to make sure delta_ps
69
  # is the same size as softmax_saliency_scores
70
  if search_over:
textattack/shared/validators.py CHANGED
@@ -24,7 +24,10 @@ MODELS_BY_GOAL_FUNCTIONS = {
24
  r"^textattack.models.helpers.word_cnn_for_classification.*",
25
  r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
26
  ],
27
- (NonOverlappingOutput, MinimizeBleu,): [
 
 
 
28
  r"^textattack.models.helpers.t5_for_text_to_text.*",
29
  ],
30
  }
 
24
  r"^textattack.models.helpers.word_cnn_for_classification.*",
25
  r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
26
  ],
27
+ (
28
+ NonOverlappingOutput,
29
+ MinimizeBleu,
30
+ ): [
31
  r"^textattack.models.helpers.t5_for_text_to_text.*",
32
  ],
33
  }
textattack/trainer.py CHANGED
@@ -398,6 +398,7 @@ class Trainer:
398
  Returns:
399
  :obj:`torch.utils.data.DataLoader`
400
  """
 
401
  # TODO: Add pairing option where we can pair original examples with adversarial examples.
402
  # Helper functions for collating data
403
  def collate_fn(data):
@@ -406,7 +407,6 @@ class Trainer:
406
  is_adv_sample = []
407
  for item in data:
408
  if "_example_type" in item[0].keys():
409
-
410
  # Get example type value from OrderedDict and remove it
411
 
412
  adv = item[0].pop("_example_type")
@@ -460,6 +460,7 @@ class Trainer:
460
  Returns:
461
  :obj:`torch.utils.data.DataLoader`
462
  """
 
463
  # Helper functions for collating data
464
  def collate_fn(data):
465
  input_texts = []
 
398
  Returns:
399
  :obj:`torch.utils.data.DataLoader`
400
  """
401
+
402
  # TODO: Add pairing option where we can pair original examples with adversarial examples.
403
  # Helper functions for collating data
404
  def collate_fn(data):
 
407
  is_adv_sample = []
408
  for item in data:
409
  if "_example_type" in item[0].keys():
 
410
  # Get example type value from OrderedDict and remove it
411
 
412
  adv = item[0].pop("_example_type")
 
460
  Returns:
461
  :obj:`torch.utils.data.DataLoader`
462
  """
463
+
464
  # Helper functions for collating data
465
  def collate_fn(data):
466
  input_texts = []
textattack/training_args.py CHANGED
@@ -547,7 +547,6 @@ class _CommandLineTrainingArgs:
547
  train_dataset.output_column == "label"
548
  and eval_dataset.output_column == "label"
549
  ):
550
-
551
  train_dataset_labels = train_dataset._dataset["label"]
552
 
553
  eval_dataset_labels = eval_dataset._dataset["label"]
 
547
  train_dataset.output_column == "label"
548
  and eval_dataset.output_column == "label"
549
  ):
 
550
  train_dataset_labels = train_dataset._dataset["label"]
551
 
552
  eval_dataset_labels = eval_dataset._dataset["label"]
textattack/transformations/word_swaps/word_swap_change_name.py CHANGED
@@ -64,7 +64,6 @@ class WordSwapChangeName(WordSwap):
64
  return transformed_texts
65
 
66
  def _get_replacement_words(self, word, word_part_of_speech):
67
-
68
  replacement_words = []
69
  tag = word_part_of_speech
70
  if (
 
64
  return transformed_texts
65
 
66
  def _get_replacement_words(self, word, word_part_of_speech):
 
67
  replacement_words = []
68
  tag = word_part_of_speech
69
  if (
textattack/transformations/word_swaps/word_swap_change_number.py CHANGED
@@ -70,7 +70,7 @@ class WordSwapChangeNumber(WordSwap):
70
 
71
  # replace original numbers with new numbers
72
  transformed_texts = []
73
- for (idx, word) in num_words:
74
  replacement_words = self._get_new_number(word)
75
  for r in replacement_words:
76
  if r == word:
 
70
 
71
  # replace original numbers with new numbers
72
  transformed_texts = []
73
+ for idx, word in num_words:
74
  replacement_words = self._get_new_number(word)
75
  for r in replacement_words:
76
  if r == word: