anonymous8
commited on
Commit
•
d65ddc0
1
Parent(s):
ecdc8b8
update
Browse files- anonymous_demo/__init__.py +2 -2
- anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py +47 -42
- anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py +12 -9
- anonymous_demo/core/tad/prediction/tad_classifier.py +305 -177
- anonymous_demo/functional/checkpoint/checkpoint_manager.py +4 -5
- anonymous_demo/functional/config/config_manager.py +10 -12
- anonymous_demo/functional/config/tad_config_manager.py +132 -124
- anonymous_demo/functional/dataset/__init__.py +1 -1
- anonymous_demo/functional/dataset/dataset_manager.py +30 -6
- anonymous_demo/network/lcf_pooler.py +4 -2
- anonymous_demo/network/lsa.py +34 -13
- anonymous_demo/network/sa_encoder.py +57 -17
- anonymous_demo/utils/demo_utils.py +86 -48
- anonymous_demo/utils/logger.py +5 -5
- app.py +31 -20
- requirements.txt +1 -1
- textattack/attack_recipes/morpheus_tan_2020.py +0 -1
- textattack/attack_recipes/seq2sick_cheng_2018_blackbox.py +0 -1
- textattack/attacker.py +7 -5
- textattack/commands/augment_command.py +0 -1
- textattack/commands/eval_model_command.py +1 -1
- textattack/constraints/overlap/max_words_perturbed.py +0 -1
- textattack/goal_function_results/classification_goal_function_result.py +0 -1
- textattack/goal_function_results/text_to_text_goal_function_result.py +0 -1
- textattack/loggers/weights_and_biases_logger.py +0 -1
- textattack/metrics/quality_metrics/perplexity.py +0 -1
- textattack/models/wrappers/demo_model_wrapper.py +6 -6
- textattack/reactive_defense/reactive_defender.py +0 -1
- textattack/reactive_defense/tad_reactive_defender.py +12 -9
- textattack/search_methods/greedy_word_swap_wir.py +0 -1
- textattack/shared/validators.py +4 -1
- textattack/trainer.py +2 -1
- textattack/training_args.py +0 -1
- textattack/transformations/word_swaps/word_swap_change_name.py +0 -1
- textattack/transformations/word_swaps/word_swap_change_number.py +1 -1
anonymous_demo/__init__.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
__version__ =
|
2 |
|
3 |
-
__name__ =
|
4 |
|
5 |
from anonymous_demo.functional import TADCheckpointManager
|
|
|
1 |
+
__version__ = "1.0.0"
|
2 |
|
3 |
+
__name__ = "anonymous_demo"
|
4 |
|
5 |
from anonymous_demo.functional import TADCheckpointManager
|
anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py
CHANGED
@@ -6,26 +6,30 @@ from transformers import AutoTokenizer
|
|
6 |
|
7 |
class Tokenizer4Pretraining:
|
8 |
def __init__(self, max_seq_len, opt, **kwargs):
|
9 |
-
if kwargs.pop(
|
10 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
11 |
-
|
|
|
|
|
12 |
else:
|
13 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
14 |
-
|
|
|
15 |
self.max_seq_len = max_seq_len
|
16 |
|
17 |
-
def text_to_sequence(self, text, reverse=False, padding=
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
class BERTTADDataset(Dataset):
|
24 |
-
|
25 |
def __init__(self, tokenizer, opt):
|
26 |
-
self.bert_baseline_input_colses = {
|
27 |
-
'bert': ['text_bert_indices']
|
28 |
-
}
|
29 |
|
30 |
self.tokenizer = tokenizer
|
31 |
self.opt = opt
|
@@ -40,33 +44,39 @@ class BERTTADDataset(Dataset):
|
|
40 |
def process_data(self, samples, ignore_error=True):
|
41 |
all_data = []
|
42 |
if len(samples) > 100:
|
43 |
-
it = tqdm.tqdm(
|
|
|
|
|
44 |
else:
|
45 |
it = samples
|
46 |
for text in it:
|
47 |
try:
|
48 |
# handle for empty lines in inference datasets
|
49 |
-
if text is None or
|
50 |
-
raise RuntimeError(
|
51 |
|
52 |
-
if
|
53 |
-
text, _, labels = text.strip().partition(
|
54 |
text = text.strip()
|
55 |
-
if labels.count(
|
56 |
-
label, is_adv, adv_train_label = labels.strip().split(
|
57 |
-
label, is_adv, adv_train_label =
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
60 |
label, is_adv = label.strip(), is_adv.strip()
|
61 |
-
adv_train_label =
|
62 |
-
elif labels.count(
|
63 |
label = labels.strip()
|
64 |
-
adv_train_label =
|
65 |
-
is_adv =
|
66 |
else:
|
67 |
-
label =
|
68 |
-
adv_train_label =
|
69 |
-
is_adv =
|
70 |
|
71 |
label = int(label)
|
72 |
adv_train_label = int(adv_train_label)
|
@@ -78,19 +88,14 @@ class BERTTADDataset(Dataset):
|
|
78 |
adv_train_label = -100
|
79 |
is_adv = -100
|
80 |
|
81 |
-
text_indices = self.tokenizer.text_to_sequence(
|
82 |
|
83 |
data = {
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
'adv_train_label': adv_train_label,
|
91 |
-
|
92 |
-
'is_adv': is_adv,
|
93 |
-
|
94 |
# 'label': self.opt.label_to_index.get(label, -100) if isinstance(label, str) else label,
|
95 |
#
|
96 |
# 'adv_train_label': self.opt.adv_train_label_to_index.get(adv_train_label, -100) if isinstance(adv_train_label, str) else adv_train_label,
|
@@ -102,7 +107,7 @@ class BERTTADDataset(Dataset):
|
|
102 |
|
103 |
except Exception as e:
|
104 |
if ignore_error:
|
105 |
-
print(
|
106 |
else:
|
107 |
raise e
|
108 |
|
|
|
6 |
|
7 |
class Tokenizer4Pretraining:
|
8 |
def __init__(self, max_seq_len, opt, **kwargs):
|
9 |
+
if kwargs.pop("offline", False):
|
10 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
11 |
+
find_cwd_dir(opt.pretrained_bert.split("/")[-1]),
|
12 |
+
do_lower_case="uncased" in opt.pretrained_bert,
|
13 |
+
)
|
14 |
else:
|
15 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
16 |
+
opt.pretrained_bert, do_lower_case="uncased" in opt.pretrained_bert
|
17 |
+
)
|
18 |
self.max_seq_len = max_seq_len
|
19 |
|
20 |
+
def text_to_sequence(self, text, reverse=False, padding="post", truncating="post"):
|
21 |
+
return self.tokenizer.encode(
|
22 |
+
text,
|
23 |
+
truncation=True,
|
24 |
+
padding="max_length",
|
25 |
+
max_length=self.max_seq_len,
|
26 |
+
return_tensors="pt",
|
27 |
+
)
|
28 |
|
29 |
|
30 |
class BERTTADDataset(Dataset):
|
|
|
31 |
def __init__(self, tokenizer, opt):
|
32 |
+
self.bert_baseline_input_colses = {"bert": ["text_bert_indices"]}
|
|
|
|
|
33 |
|
34 |
self.tokenizer = tokenizer
|
35 |
self.opt = opt
|
|
|
44 |
def process_data(self, samples, ignore_error=True):
|
45 |
all_data = []
|
46 |
if len(samples) > 100:
|
47 |
+
it = tqdm.tqdm(
|
48 |
+
samples, postfix="preparing text classification inference dataloader..."
|
49 |
+
)
|
50 |
else:
|
51 |
it = samples
|
52 |
for text in it:
|
53 |
try:
|
54 |
# handle for empty lines in inference datasets
|
55 |
+
if text is None or "" == text.strip():
|
56 |
+
raise RuntimeError("Invalid Input!")
|
57 |
|
58 |
+
if "!ref!" in text:
|
59 |
+
text, _, labels = text.strip().partition("!ref!")
|
60 |
text = text.strip()
|
61 |
+
if labels.count(",") == 2:
|
62 |
+
label, is_adv, adv_train_label = labels.strip().split(",")
|
63 |
+
label, is_adv, adv_train_label = (
|
64 |
+
label.strip(),
|
65 |
+
is_adv.strip(),
|
66 |
+
adv_train_label.strip(),
|
67 |
+
)
|
68 |
+
elif labels.count(",") == 1:
|
69 |
+
label, is_adv = labels.strip().split(",")
|
70 |
label, is_adv = label.strip(), is_adv.strip()
|
71 |
+
adv_train_label = "-100"
|
72 |
+
elif labels.count(",") == 0:
|
73 |
label = labels.strip()
|
74 |
+
adv_train_label = "-100"
|
75 |
+
is_adv = "-100"
|
76 |
else:
|
77 |
+
label = "-100"
|
78 |
+
adv_train_label = "-100"
|
79 |
+
is_adv = "-100"
|
80 |
|
81 |
label = int(label)
|
82 |
adv_train_label = int(adv_train_label)
|
|
|
88 |
adv_train_label = -100
|
89 |
is_adv = -100
|
90 |
|
91 |
+
text_indices = self.tokenizer.text_to_sequence("{}".format(text))
|
92 |
|
93 |
data = {
|
94 |
+
"text_bert_indices": text_indices[0],
|
95 |
+
"text_raw": text,
|
96 |
+
"label": label,
|
97 |
+
"adv_train_label": adv_train_label,
|
98 |
+
"is_adv": is_adv,
|
|
|
|
|
|
|
|
|
|
|
99 |
# 'label': self.opt.label_to_index.get(label, -100) if isinstance(label, str) else label,
|
100 |
#
|
101 |
# 'adv_train_label': self.opt.adv_train_label_to_index.get(adv_train_label, -100) if isinstance(adv_train_label, str) else adv_train_label,
|
|
|
107 |
|
108 |
except Exception as e:
|
109 |
if ignore_error:
|
110 |
+
print("Ignore error while processing:", text)
|
111 |
else:
|
112 |
raise e
|
113 |
|
anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py
CHANGED
@@ -6,7 +6,7 @@ from anonymous_demo.network.sa_encoder import Encoder
|
|
6 |
|
7 |
|
8 |
class TADBERT(nn.Module):
|
9 |
-
inputs = [
|
10 |
|
11 |
def __init__(self, bert, opt):
|
12 |
super(TADBERT, self).__init__()
|
@@ -23,21 +23,24 @@ class TADBERT(nn.Module):
|
|
23 |
|
24 |
def forward(self, inputs):
|
25 |
text_raw_indices = inputs[0]
|
26 |
-
last_hidden_state = self.bert(text_raw_indices)[
|
27 |
|
28 |
sent_logits = self.dense1(self.pooler(last_hidden_state))
|
29 |
advdet_logits = self.dense2(self.pooler(last_hidden_state))
|
30 |
adv_tr_logits = self.dense3(self.pooler(last_hidden_state))
|
31 |
|
32 |
att_score = torch.nn.functional.normalize(
|
33 |
-
last_hidden_state.abs().sum(dim=1, keepdim=False)
|
34 |
-
|
|
|
|
|
|
|
35 |
|
36 |
outputs = {
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
}
|
43 |
return outputs
|
|
|
6 |
|
7 |
|
8 |
class TADBERT(nn.Module):
|
9 |
+
inputs = ["text_bert_indices"]
|
10 |
|
11 |
def __init__(self, bert, opt):
|
12 |
super(TADBERT, self).__init__()
|
|
|
23 |
|
24 |
def forward(self, inputs):
|
25 |
text_raw_indices = inputs[0]
|
26 |
+
last_hidden_state = self.bert(text_raw_indices)["last_hidden_state"]
|
27 |
|
28 |
sent_logits = self.dense1(self.pooler(last_hidden_state))
|
29 |
advdet_logits = self.dense2(self.pooler(last_hidden_state))
|
30 |
adv_tr_logits = self.dense3(self.pooler(last_hidden_state))
|
31 |
|
32 |
att_score = torch.nn.functional.normalize(
|
33 |
+
last_hidden_state.abs().sum(dim=1, keepdim=False)
|
34 |
+
- last_hidden_state.abs().min(dim=1, keepdim=True)[0],
|
35 |
+
p=1,
|
36 |
+
dim=1,
|
37 |
+
)
|
38 |
|
39 |
outputs = {
|
40 |
+
"sent_logits": sent_logits,
|
41 |
+
"advdet_logits": advdet_logits,
|
42 |
+
"adv_tr_logits": adv_tr_logits,
|
43 |
+
"last_hidden_state": last_hidden_state,
|
44 |
+
"att_score": att_score,
|
45 |
}
|
46 |
return outputs
|
anonymous_demo/core/tad/prediction/tad_classifier.py
CHANGED
@@ -9,21 +9,43 @@ from findfile import find_file, find_cwd_dir
|
|
9 |
from termcolor import colored
|
10 |
|
11 |
from torch.utils.data import DataLoader
|
12 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
from ....functional.dataset.dataset_manager import detect_infer_dataset
|
15 |
|
16 |
from ..models import BERTTADModelList
|
17 |
-
from ..classic.__bert__.dataset_utils.data_utils_for_inference import
|
|
|
|
|
|
|
18 |
|
19 |
-
from ....utils.demo_utils import
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
def init_attacker(tad_classifier, defense):
|
23 |
try:
|
24 |
from textattack import Attacker
|
25 |
-
from textattack.attack_recipes import
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
from textattack.datasets import Dataset
|
28 |
from textattack.models.wrappers import HuggingFaceModelWrapper
|
29 |
|
@@ -34,36 +56,36 @@ def init_attacker(tad_classifier, defense):
|
|
34 |
def __call__(self, text_inputs, **kwargs):
|
35 |
outputs = []
|
36 |
for text_input in text_inputs:
|
37 |
-
raw_outputs = self.model.infer(
|
38 |
-
|
|
|
|
|
39 |
return outputs
|
40 |
|
41 |
class SentAttacker:
|
42 |
-
|
43 |
def __init__(self, model, recipe_class=BAEGarg2019):
|
44 |
model = model
|
45 |
model_wrapper = DemoModelWrapper(model)
|
46 |
|
47 |
recipe = recipe_class.build(model_wrapper)
|
48 |
|
49 |
-
_dataset = [(
|
50 |
_dataset = Dataset(_dataset)
|
51 |
|
52 |
self.attacker = Attacker(recipe, _dataset)
|
53 |
|
54 |
attackers = {
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
}
|
63 |
return SentAttacker(tad_classifier, attackers[defense])
|
64 |
except Exception as e:
|
65 |
-
|
66 |
-
print('Original error:', e)
|
67 |
|
68 |
|
69 |
def get_mlm_and_tokenizer(text_classifier, config):
|
@@ -72,10 +94,10 @@ def get_mlm_and_tokenizer(text_classifier, config):
|
|
72 |
else:
|
73 |
base_model = text_classifier.bert.base_model
|
74 |
pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)
|
75 |
-
if
|
76 |
MLM = DebertaV2ForMaskedLM(pretrained_config)
|
77 |
MLM.deberta = base_model
|
78 |
-
elif
|
79 |
MLM = RobertaForMaskedLM(pretrained_config)
|
80 |
MLM.roberta = base_model
|
81 |
else:
|
@@ -86,64 +108,85 @@ def get_mlm_and_tokenizer(text_classifier, config):
|
|
86 |
|
87 |
class TADTextClassifier:
|
88 |
def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
self.cal_perplexity = cal_perplexity
|
93 |
# load from a training
|
94 |
if not isinstance(model_arg, str):
|
95 |
-
print(
|
96 |
self.model = model_arg[0]
|
97 |
self.opt = model_arg[1]
|
98 |
self.tokenizer = model_arg[2]
|
99 |
else:
|
100 |
try:
|
101 |
-
if
|
102 |
raise ValueError(
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
self.opt = pickle.load(f)
|
117 |
-
self.opt.device = get_device(kwargs.pop(
|
118 |
|
119 |
if state_dict_path or model_path:
|
120 |
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
121 |
if state_dict_path:
|
122 |
-
if kwargs.pop(
|
123 |
self.bert = AutoModel.from_pretrained(
|
124 |
-
find_cwd_dir(
|
|
|
|
|
|
|
125 |
else:
|
126 |
-
self.bert = AutoModel.from_pretrained(
|
|
|
|
|
127 |
self.model = self.opt.model(self.bert, self.opt)
|
128 |
-
self.model.load_state_dict(
|
|
|
|
|
129 |
elif model_path:
|
130 |
-
self.model = torch.load(model_path, map_location=
|
131 |
|
132 |
try:
|
133 |
-
self.tokenizer = Tokenizer4Pretraining(
|
134 |
-
|
|
|
135 |
except ValueError:
|
136 |
if tokenizer_path:
|
137 |
-
with open(tokenizer_path, mode=
|
138 |
self.tokenizer = pickle.load(f)
|
139 |
else:
|
140 |
raise TransformerConnectionError()
|
141 |
|
142 |
except Exception as e:
|
143 |
-
raise RuntimeError(
|
|
|
|
|
|
|
|
|
144 |
|
145 |
self.infer_dataloader = None
|
146 |
-
self.opt.eval_batch_size = kwargs.pop(
|
147 |
|
148 |
self.opt.initializer = self.opt.initializer
|
149 |
|
@@ -158,19 +201,19 @@ class TADTextClassifier:
|
|
158 |
def to(self, device=None):
|
159 |
self.opt.device = device
|
160 |
self.model.to(device)
|
161 |
-
if hasattr(self,
|
162 |
self.MLM.to(self.opt.device)
|
163 |
|
164 |
def cpu(self):
|
165 |
-
self.opt.device =
|
166 |
-
self.model.to(
|
167 |
-
if hasattr(self,
|
168 |
-
self.MLM.to(
|
169 |
|
170 |
-
def cuda(self, device=
|
171 |
self.opt.device = device
|
172 |
self.model.to(device)
|
173 |
-
if hasattr(self,
|
174 |
self.MLM.to(device)
|
175 |
|
176 |
def _log_write_args(self):
|
@@ -182,55 +225,67 @@ class TADTextClassifier:
|
|
182 |
else:
|
183 |
n_nontrainable_params += n_params
|
184 |
print(
|
185 |
-
|
|
|
|
|
|
|
186 |
for arg in vars(self.opt):
|
187 |
if getattr(self.opt, arg) is not None:
|
188 |
-
print(
|
189 |
-
|
190 |
-
def batch_infer(
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
save_path = os.path.join(os.getcwd(),
|
199 |
-
|
200 |
-
target_file = detect_infer_dataset(target_file, task=
|
201 |
if not target_file:
|
202 |
-
raise FileNotFoundError(
|
203 |
|
204 |
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
205 |
dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
|
206 |
|
207 |
dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)
|
208 |
-
self.infer_dataloader = DataLoader(
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
220 |
dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
|
221 |
|
222 |
if text:
|
223 |
dataset.prepare_infer_sample(text, ignore_error=ignore_error)
|
224 |
else:
|
225 |
-
raise RuntimeError(
|
226 |
-
self.infer_dataloader = DataLoader(
|
|
|
|
|
227 |
return self._infer(print_result=print_result, defense=defense)[0]
|
228 |
|
229 |
def _infer(self, save_path=None, print_result=True, defense=None):
|
230 |
-
|
231 |
_params = filter(lambda p: p.requires_grad, self.model.parameters())
|
232 |
|
233 |
-
correct = {True:
|
234 |
results = []
|
235 |
|
236 |
with torch.no_grad():
|
@@ -241,86 +296,130 @@ class TADTextClassifier:
|
|
241 |
n_advdet_correct = 0
|
242 |
n_advdet_labeled = 0
|
243 |
if len(self.infer_dataloader.dataset) >= 100:
|
244 |
-
it = tqdm.tqdm(self.infer_dataloader, postfix=
|
245 |
else:
|
246 |
it = self.infer_dataloader
|
247 |
for _, sample in enumerate(it):
|
248 |
-
inputs = [
|
|
|
|
|
249 |
outputs = self.model(inputs)
|
250 |
-
logits, advdet_logits, adv_tr_logits =
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
|
259 |
pred_label = int(prob.argmax(axis=-1))
|
260 |
pred_is_adv_label = int(advdet_prob.argmax(axis=-1))
|
261 |
pred_adv_tr_label = int(adv_tr_prob.argmax(axis=-1))
|
262 |
-
ref_label =
|
263 |
-
|
264 |
-
sample[
|
265 |
-
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
|
268 |
if self.cal_perplexity:
|
269 |
ids = self.MLM_tokenizer(text_raw, return_tensors="pt")
|
270 |
-
ids[
|
271 |
ids = ids.to(self.opt.device)
|
272 |
-
loss = self.MLM(**ids)[
|
273 |
-
perplexity = float(torch.exp(loss / ids[
|
274 |
else:
|
275 |
-
perplexity =
|
276 |
|
277 |
result = {
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
}
|
298 |
if defense:
|
299 |
try:
|
300 |
-
if not hasattr(self,
|
301 |
-
self.sent_attacker = init_attacker(
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
else:
|
312 |
-
result[
|
313 |
-
result[
|
314 |
|
315 |
except Exception as e:
|
316 |
-
print(
|
|
|
|
|
|
|
|
|
317 |
time.sleep(10)
|
318 |
-
raise RuntimeError(
|
319 |
|
320 |
if ref_label != -100:
|
321 |
n_labeled += 1
|
322 |
|
323 |
-
if result[
|
324 |
n_correct += 1
|
325 |
|
326 |
if ref_is_adv_label != -100:
|
@@ -333,56 +432,85 @@ class TADTextClassifier:
|
|
333 |
try:
|
334 |
if print_result:
|
335 |
for ex_id, result in enumerate(results):
|
336 |
-
text_printing = result[
|
337 |
-
text_info =
|
338 |
-
if result[
|
339 |
-
if not result[
|
340 |
-
text_info +=
|
341 |
-
|
342 |
-
|
343 |
-
|
|
|
|
|
344 |
text_info += colored(
|
345 |
-
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
347 |
else:
|
348 |
text_info += colored(
|
349 |
-
|
350 |
-
|
|
|
|
|
|
|
|
|
|
|
351 |
|
352 |
# AdvDet
|
353 |
-
if result[
|
354 |
-
if not result[
|
355 |
-
text_info +=
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
|
|
|
|
|
|
|
|
365 |
else:
|
366 |
-
text_info += colored(
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
|
|
|
|
372 |
text_printing += text_info
|
373 |
if self.cal_perplexity:
|
374 |
-
text_printing += colored(
|
375 |
-
|
|
|
|
|
|
|
376 |
if save_path:
|
377 |
-
with open(save_path,
|
378 |
json.dump(str(results), fout, ensure_ascii=False)
|
379 |
-
print(
|
380 |
except Exception as e:
|
381 |
-
print(
|
382 |
|
383 |
if len(results) > 1:
|
384 |
-
print(
|
385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
386 |
|
387 |
return results
|
388 |
|
|
|
9 |
from termcolor import colored
|
10 |
|
11 |
from torch.utils.data import DataLoader
|
12 |
+
from transformers import (
|
13 |
+
AutoTokenizer,
|
14 |
+
AutoModel,
|
15 |
+
AutoConfig,
|
16 |
+
DebertaV2ForMaskedLM,
|
17 |
+
RobertaForMaskedLM,
|
18 |
+
BertForMaskedLM,
|
19 |
+
)
|
20 |
|
21 |
from ....functional.dataset.dataset_manager import detect_infer_dataset
|
22 |
|
23 |
from ..models import BERTTADModelList
|
24 |
+
from ..classic.__bert__.dataset_utils.data_utils_for_inference import (
|
25 |
+
BERTTADDataset,
|
26 |
+
Tokenizer4Pretraining,
|
27 |
+
)
|
28 |
|
29 |
+
from ....utils.demo_utils import (
|
30 |
+
print_args,
|
31 |
+
TransformerConnectionError,
|
32 |
+
get_device,
|
33 |
+
build_embedding_matrix,
|
34 |
+
)
|
35 |
|
36 |
|
37 |
def init_attacker(tad_classifier, defense):
|
38 |
try:
|
39 |
from textattack import Attacker
|
40 |
+
from textattack.attack_recipes import (
|
41 |
+
BAEGarg2019,
|
42 |
+
PWWSRen2019,
|
43 |
+
TextFoolerJin2019,
|
44 |
+
PSOZang2020,
|
45 |
+
IGAWang2019,
|
46 |
+
GeneticAlgorithmAlzantot2018,
|
47 |
+
DeepWordBugGao2018,
|
48 |
+
)
|
49 |
from textattack.datasets import Dataset
|
50 |
from textattack.models.wrappers import HuggingFaceModelWrapper
|
51 |
|
|
|
56 |
def __call__(self, text_inputs, **kwargs):
|
57 |
outputs = []
|
58 |
for text_input in text_inputs:
|
59 |
+
raw_outputs = self.model.infer(
|
60 |
+
text_input, print_result=False, **kwargs
|
61 |
+
)
|
62 |
+
outputs.append(raw_outputs["probs"])
|
63 |
return outputs
|
64 |
|
65 |
class SentAttacker:
|
|
|
66 |
def __init__(self, model, recipe_class=BAEGarg2019):
|
67 |
model = model
|
68 |
model_wrapper = DemoModelWrapper(model)
|
69 |
|
70 |
recipe = recipe_class.build(model_wrapper)
|
71 |
|
72 |
+
_dataset = [("", 0)]
|
73 |
_dataset = Dataset(_dataset)
|
74 |
|
75 |
self.attacker = Attacker(recipe, _dataset)
|
76 |
|
77 |
attackers = {
|
78 |
+
"bae": BAEGarg2019,
|
79 |
+
"pwws": PWWSRen2019,
|
80 |
+
"textfooler": TextFoolerJin2019,
|
81 |
+
"pso": PSOZang2020,
|
82 |
+
"iga": IGAWang2019,
|
83 |
+
"ga": GeneticAlgorithmAlzantot2018,
|
84 |
+
"wordbugger": DeepWordBugGao2018,
|
85 |
}
|
86 |
return SentAttacker(tad_classifier, attackers[defense])
|
87 |
except Exception as e:
|
88 |
+
print("Original error:", e)
|
|
|
89 |
|
90 |
|
91 |
def get_mlm_and_tokenizer(text_classifier, config):
|
|
|
94 |
else:
|
95 |
base_model = text_classifier.bert.base_model
|
96 |
pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)
|
97 |
+
if "deberta-v3" in config.pretrained_bert:
|
98 |
MLM = DebertaV2ForMaskedLM(pretrained_config)
|
99 |
MLM.deberta = base_model
|
100 |
+
elif "roberta" in config.pretrained_bert:
|
101 |
MLM = RobertaForMaskedLM(pretrained_config)
|
102 |
MLM.roberta = base_model
|
103 |
else:
|
|
|
108 |
|
109 |
class TADTextClassifier:
|
110 |
def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
|
111 |
+
"""
|
112 |
+
from_train_model: load inference model from trained model
|
113 |
+
"""
|
114 |
self.cal_perplexity = cal_perplexity
|
115 |
# load from a training
|
116 |
if not isinstance(model_arg, str):
|
117 |
+
print("Load text classifier from training")
|
118 |
self.model = model_arg[0]
|
119 |
self.opt = model_arg[1]
|
120 |
self.tokenizer = model_arg[2]
|
121 |
else:
|
122 |
try:
|
123 |
+
if "fine-tuned" in model_arg:
|
124 |
raise ValueError(
|
125 |
+
"Do not support to directly load a fine-tuned model, please load a .state_dict or .model instead!"
|
126 |
+
)
|
127 |
+
print("Load text classifier from", model_arg)
|
128 |
+
state_dict_path = find_file(
|
129 |
+
model_arg, key=".state_dict", exclude_key=["__MACOSX"]
|
130 |
+
)
|
131 |
+
model_path = find_file(
|
132 |
+
model_arg, key=".model", exclude_key=["__MACOSX"]
|
133 |
+
)
|
134 |
+
tokenizer_path = find_file(
|
135 |
+
model_arg, key=".tokenizer", exclude_key=["__MACOSX"]
|
136 |
+
)
|
137 |
+
config_path = find_file(
|
138 |
+
model_arg, key=".config", exclude_key=["__MACOSX"]
|
139 |
+
)
|
140 |
+
|
141 |
+
print("config: {}".format(config_path))
|
142 |
+
print("state_dict: {}".format(state_dict_path))
|
143 |
+
print("model: {}".format(model_path))
|
144 |
+
print("tokenizer: {}".format(tokenizer_path))
|
145 |
+
|
146 |
+
with open(config_path, mode="rb") as f:
|
147 |
self.opt = pickle.load(f)
|
148 |
+
self.opt.device = get_device(kwargs.pop("auto_device", True))[0]
|
149 |
|
150 |
if state_dict_path or model_path:
|
151 |
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
152 |
if state_dict_path:
|
153 |
+
if kwargs.pop("offline", False):
|
154 |
self.bert = AutoModel.from_pretrained(
|
155 |
+
find_cwd_dir(
|
156 |
+
self.opt.pretrained_bert.split("/")[-1]
|
157 |
+
)
|
158 |
+
)
|
159 |
else:
|
160 |
+
self.bert = AutoModel.from_pretrained(
|
161 |
+
self.opt.pretrained_bert
|
162 |
+
)
|
163 |
self.model = self.opt.model(self.bert, self.opt)
|
164 |
+
self.model.load_state_dict(
|
165 |
+
torch.load(state_dict_path, map_location="cpu")
|
166 |
+
)
|
167 |
elif model_path:
|
168 |
+
self.model = torch.load(model_path, map_location="cpu")
|
169 |
|
170 |
try:
|
171 |
+
self.tokenizer = Tokenizer4Pretraining(
|
172 |
+
max_seq_len=self.opt.max_seq_len, opt=self.opt, **kwargs
|
173 |
+
)
|
174 |
except ValueError:
|
175 |
if tokenizer_path:
|
176 |
+
with open(tokenizer_path, mode="rb") as f:
|
177 |
self.tokenizer = pickle.load(f)
|
178 |
else:
|
179 |
raise TransformerConnectionError()
|
180 |
|
181 |
except Exception as e:
|
182 |
+
raise RuntimeError(
|
183 |
+
"Exception: {} Fail to load the model from {}! ".format(
|
184 |
+
e, model_arg
|
185 |
+
)
|
186 |
+
)
|
187 |
|
188 |
self.infer_dataloader = None
|
189 |
+
self.opt.eval_batch_size = kwargs.pop("eval_batch_size", 128)
|
190 |
|
191 |
self.opt.initializer = self.opt.initializer
|
192 |
|
|
|
201 |
def to(self, device=None):
|
202 |
self.opt.device = device
|
203 |
self.model.to(device)
|
204 |
+
if hasattr(self, "MLM"):
|
205 |
self.MLM.to(self.opt.device)
|
206 |
|
207 |
def cpu(self):
|
208 |
+
self.opt.device = "cpu"
|
209 |
+
self.model.to("cpu")
|
210 |
+
if hasattr(self, "MLM"):
|
211 |
+
self.MLM.to("cpu")
|
212 |
|
213 |
+
def cuda(self, device="cuda:0"):
|
214 |
self.opt.device = device
|
215 |
self.model.to(device)
|
216 |
+
if hasattr(self, "MLM"):
|
217 |
self.MLM.to(device)
|
218 |
|
219 |
def _log_write_args(self):
|
|
|
225 |
else:
|
226 |
n_nontrainable_params += n_params
|
227 |
print(
|
228 |
+
"n_trainable_params: {0}, n_nontrainable_params: {1}".format(
|
229 |
+
n_trainable_params, n_nontrainable_params
|
230 |
+
)
|
231 |
+
)
|
232 |
for arg in vars(self.opt):
|
233 |
if getattr(self.opt, arg) is not None:
|
234 |
+
print(">>> {0}: {1}".format(arg, getattr(self.opt, arg)))
|
235 |
+
|
236 |
+
def batch_infer(
|
237 |
+
self,
|
238 |
+
target_file=None,
|
239 |
+
print_result=True,
|
240 |
+
save_result=False,
|
241 |
+
ignore_error=True,
|
242 |
+
defense: str = None,
|
243 |
+
):
|
244 |
+
save_path = os.path.join(os.getcwd(), "tad_text_classification.result.json")
|
245 |
+
|
246 |
+
target_file = detect_infer_dataset(target_file, task="text_defense")
|
247 |
if not target_file:
|
248 |
+
raise FileNotFoundError("Can not find inference datasets!")
|
249 |
|
250 |
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
251 |
dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
|
252 |
|
253 |
dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)
|
254 |
+
self.infer_dataloader = DataLoader(
|
255 |
+
dataset=dataset,
|
256 |
+
batch_size=self.opt.eval_batch_size,
|
257 |
+
pin_memory=True,
|
258 |
+
shuffle=False,
|
259 |
+
)
|
260 |
+
return self._infer(
|
261 |
+
save_path=save_path if save_result else None,
|
262 |
+
print_result=print_result,
|
263 |
+
defense=defense,
|
264 |
+
)
|
265 |
+
|
266 |
+
def infer(
|
267 |
+
self,
|
268 |
+
text: str = None,
|
269 |
+
print_result=True,
|
270 |
+
ignore_error=True,
|
271 |
+
defense: str = None,
|
272 |
+
):
|
273 |
if hasattr(BERTTADModelList, self.opt.model.__name__):
|
274 |
dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)
|
275 |
|
276 |
if text:
|
277 |
dataset.prepare_infer_sample(text, ignore_error=ignore_error)
|
278 |
else:
|
279 |
+
raise RuntimeError("Please specify your datasets path!")
|
280 |
+
self.infer_dataloader = DataLoader(
|
281 |
+
dataset=dataset, batch_size=self.opt.eval_batch_size, shuffle=False
|
282 |
+
)
|
283 |
return self._infer(print_result=print_result, defense=defense)[0]
|
284 |
|
285 |
def _infer(self, save_path=None, print_result=True, defense=None):
|
|
|
286 |
_params = filter(lambda p: p.requires_grad, self.model.parameters())
|
287 |
|
288 |
+
correct = {True: "Correct", False: "Wrong"}
|
289 |
results = []
|
290 |
|
291 |
with torch.no_grad():
|
|
|
296 |
n_advdet_correct = 0
|
297 |
n_advdet_labeled = 0
|
298 |
if len(self.infer_dataloader.dataset) >= 100:
|
299 |
+
it = tqdm.tqdm(self.infer_dataloader, postfix="inferring...")
|
300 |
else:
|
301 |
it = self.infer_dataloader
|
302 |
for _, sample in enumerate(it):
|
303 |
+
inputs = [
|
304 |
+
sample[col].to(self.opt.device) for col in self.opt.inputs_cols
|
305 |
+
]
|
306 |
outputs = self.model(inputs)
|
307 |
+
logits, advdet_logits, adv_tr_logits = (
|
308 |
+
outputs["sent_logits"],
|
309 |
+
outputs["advdet_logits"],
|
310 |
+
outputs["adv_tr_logits"],
|
311 |
+
)
|
312 |
+
probs, advdet_probs, adv_tr_probs = (
|
313 |
+
torch.softmax(logits, dim=-1),
|
314 |
+
torch.softmax(advdet_logits, dim=-1),
|
315 |
+
torch.softmax(adv_tr_logits, dim=-1),
|
316 |
+
)
|
317 |
+
|
318 |
+
for i, (prob, advdet_prob, adv_tr_prob) in enumerate(
|
319 |
+
zip(probs, advdet_probs, adv_tr_probs)
|
320 |
+
):
|
321 |
+
text_raw = sample["text_raw"][i]
|
322 |
|
323 |
pred_label = int(prob.argmax(axis=-1))
|
324 |
pred_is_adv_label = int(advdet_prob.argmax(axis=-1))
|
325 |
pred_adv_tr_label = int(adv_tr_prob.argmax(axis=-1))
|
326 |
+
ref_label = (
|
327 |
+
int(sample["label"][i])
|
328 |
+
if int(sample["label"][i]) in self.opt.index_to_label
|
329 |
+
else ""
|
330 |
+
)
|
331 |
+
ref_is_adv_label = (
|
332 |
+
int(sample["is_adv"][i])
|
333 |
+
if int(sample["is_adv"][i]) in self.opt.index_to_is_adv
|
334 |
+
else ""
|
335 |
+
)
|
336 |
+
ref_adv_tr_label = (
|
337 |
+
int(sample["adv_train_label"][i])
|
338 |
+
if int(sample["adv_train_label"][i])
|
339 |
+
in self.opt.index_to_adv_train_label
|
340 |
+
else ""
|
341 |
+
)
|
342 |
|
343 |
if self.cal_perplexity:
|
344 |
ids = self.MLM_tokenizer(text_raw, return_tensors="pt")
|
345 |
+
ids["labels"] = ids["input_ids"].clone()
|
346 |
ids = ids.to(self.opt.device)
|
347 |
+
loss = self.MLM(**ids)["loss"]
|
348 |
+
perplexity = float(torch.exp(loss / ids["input_ids"].size(1)))
|
349 |
else:
|
350 |
+
perplexity = "N.A."
|
351 |
|
352 |
result = {
|
353 |
+
"text": text_raw,
|
354 |
+
"label": self.opt.index_to_label[pred_label],
|
355 |
+
"probs": prob.cpu().numpy(),
|
356 |
+
"confidence": float(max(prob)),
|
357 |
+
"ref_label": self.opt.index_to_label[ref_label]
|
358 |
+
if isinstance(ref_label, int)
|
359 |
+
else ref_label,
|
360 |
+
"ref_label_check": correct[pred_label == ref_label]
|
361 |
+
if ref_label != -100
|
362 |
+
else "",
|
363 |
+
"is_fixed": False,
|
364 |
+
"is_adv_label": self.opt.index_to_is_adv[pred_is_adv_label],
|
365 |
+
"is_adv_probs": advdet_prob.cpu().numpy(),
|
366 |
+
"is_adv_confidence": float(max(advdet_prob)),
|
367 |
+
"ref_is_adv_label": self.opt.index_to_is_adv[ref_is_adv_label]
|
368 |
+
if isinstance(ref_is_adv_label, int)
|
369 |
+
else ref_is_adv_label,
|
370 |
+
"ref_is_adv_check": correct[
|
371 |
+
pred_is_adv_label == ref_is_adv_label
|
372 |
+
]
|
373 |
+
if ref_is_adv_label != -100
|
374 |
+
and isinstance(ref_is_adv_label, int)
|
375 |
+
else "",
|
376 |
+
"pred_adv_tr_label": self.opt.index_to_label[pred_adv_tr_label],
|
377 |
+
"ref_adv_tr_label": self.opt.index_to_label[ref_adv_tr_label],
|
378 |
+
"perplexity": perplexity,
|
379 |
}
|
380 |
if defense:
|
381 |
try:
|
382 |
+
if not hasattr(self, "sent_attacker"):
|
383 |
+
self.sent_attacker = init_attacker(
|
384 |
+
self, defense.lower()
|
385 |
+
)
|
386 |
+
if result["is_adv_label"] == "1":
|
387 |
+
res = self.sent_attacker.attacker.simple_attack(
|
388 |
+
text_raw, int(result["label"])
|
389 |
+
)
|
390 |
+
new_infer_res = self.infer(
|
391 |
+
res.perturbed_result.attacked_text.text,
|
392 |
+
print_result=False,
|
393 |
+
)
|
394 |
+
result["perturbed_label"] = result["label"]
|
395 |
+
result["label"] = new_infer_res["label"]
|
396 |
+
result["probs"] = new_infer_res["probs"]
|
397 |
+
result["ref_label_check"] = (
|
398 |
+
correct[int(result["label"]) == ref_label]
|
399 |
+
if ref_label != -100
|
400 |
+
else ""
|
401 |
+
)
|
402 |
+
result[
|
403 |
+
"restored_text"
|
404 |
+
] = res.perturbed_result.attacked_text.text
|
405 |
+
result["is_fixed"] = True
|
406 |
else:
|
407 |
+
result["restored_text"] = ""
|
408 |
+
result["is_fixed"] = False
|
409 |
|
410 |
except Exception as e:
|
411 |
+
print(
|
412 |
+
"Error:{}, try install TextAttack and tensorflow_text after 10 seconds...".format(
|
413 |
+
e
|
414 |
+
)
|
415 |
+
)
|
416 |
time.sleep(10)
|
417 |
+
raise RuntimeError("Installation done, please run again...")
|
418 |
|
419 |
if ref_label != -100:
|
420 |
n_labeled += 1
|
421 |
|
422 |
+
if result["label"] == result["ref_label"]:
|
423 |
n_correct += 1
|
424 |
|
425 |
if ref_is_adv_label != -100:
|
|
|
432 |
try:
|
433 |
if print_result:
|
434 |
for ex_id, result in enumerate(results):
|
435 |
+
text_printing = result["text"][:]
|
436 |
+
text_info = ""
|
437 |
+
if result["label"] != "-100":
|
438 |
+
if not result["ref_label"]:
|
439 |
+
text_info += " -> <CLS:{}(ref:{} confidence:{})>".format(
|
440 |
+
result["label"],
|
441 |
+
result["ref_label"],
|
442 |
+
result["confidence"],
|
443 |
+
)
|
444 |
+
elif result["label"] == result["ref_label"]:
|
445 |
text_info += colored(
|
446 |
+
" -> <CLS:{}(ref:{} confidence:{})>".format(
|
447 |
+
result["label"],
|
448 |
+
result["ref_label"],
|
449 |
+
result["confidence"],
|
450 |
+
),
|
451 |
+
"green",
|
452 |
+
)
|
453 |
else:
|
454 |
text_info += colored(
|
455 |
+
" -> <CLS:{}(ref:{} confidence:{})>".format(
|
456 |
+
result["label"],
|
457 |
+
result["ref_label"],
|
458 |
+
result["confidence"],
|
459 |
+
),
|
460 |
+
"red",
|
461 |
+
)
|
462 |
|
463 |
# AdvDet
|
464 |
+
if result["is_adv_label"] != "-100":
|
465 |
+
if not result["ref_is_adv_label"]:
|
466 |
+
text_info += " -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
467 |
+
result["is_adv_label"],
|
468 |
+
result["ref_is_adv_check"],
|
469 |
+
result["is_adv_confidence"],
|
470 |
+
)
|
471 |
+
elif result["is_adv_label"] == result["ref_is_adv_label"]:
|
472 |
+
text_info += colored(
|
473 |
+
" -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
474 |
+
result["is_adv_label"],
|
475 |
+
result["ref_is_adv_label"],
|
476 |
+
result["is_adv_confidence"],
|
477 |
+
),
|
478 |
+
"green",
|
479 |
+
)
|
480 |
else:
|
481 |
+
text_info += colored(
|
482 |
+
" -> <AdvDet:{}(ref:{} confidence:{})>".format(
|
483 |
+
result["is_adv_label"],
|
484 |
+
result["ref_is_adv_label"],
|
485 |
+
result["is_adv_confidence"],
|
486 |
+
),
|
487 |
+
"red",
|
488 |
+
)
|
489 |
text_printing += text_info
|
490 |
if self.cal_perplexity:
|
491 |
+
text_printing += colored(
|
492 |
+
" --> <perplexity:{}>".format(result["perplexity"]),
|
493 |
+
"yellow",
|
494 |
+
)
|
495 |
+
print("Example {}: {}".format(ex_id, text_printing))
|
496 |
if save_path:
|
497 |
+
with open(save_path, "w", encoding="utf8") as fout:
|
498 |
json.dump(str(results), fout, ensure_ascii=False)
|
499 |
+
print("inference result saved in: {}".format(save_path))
|
500 |
except Exception as e:
|
501 |
+
print("Can not save result: {}, Exception: {}".format(text_raw, e))
|
502 |
|
503 |
if len(results) > 1:
|
504 |
+
print(
|
505 |
+
"CLS Acc:{}%".format(100 * n_correct / n_labeled if n_labeled else "")
|
506 |
+
)
|
507 |
+
print(
|
508 |
+
"AdvDet Acc:{}%".format(
|
509 |
+
100 * n_advdet_correct / n_advdet_labeled
|
510 |
+
if n_advdet_labeled
|
511 |
+
else ""
|
512 |
+
)
|
513 |
+
)
|
514 |
|
515 |
return results
|
516 |
|
anonymous_demo/functional/checkpoint/checkpoint_manager.py
CHANGED
@@ -12,9 +12,8 @@ class CheckpointManager:
|
|
12 |
class TADCheckpointManager(CheckpointManager):
|
13 |
@staticmethod
|
14 |
@retry
|
15 |
-
def get_tad_text_classifier(checkpoint: str = None,
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
tad_text_classifier = TADTextClassifier(checkpoint, eval_batch_size=eval_batch_size, **kwargs)
|
20 |
return tad_text_classifier
|
|
|
12 |
class TADCheckpointManager(CheckpointManager):
|
13 |
@staticmethod
|
14 |
@retry
|
15 |
+
def get_tad_text_classifier(checkpoint: str = None, eval_batch_size=128, **kwargs):
|
16 |
+
tad_text_classifier = TADTextClassifier(
|
17 |
+
checkpoint, eval_batch_size=eval_batch_size, **kwargs
|
18 |
+
)
|
|
|
19 |
return tad_text_classifier
|
anonymous_demo/functional/config/config_manager.py
CHANGED
@@ -10,7 +10,6 @@ def config_check(args):
|
|
10 |
|
11 |
|
12 |
class ConfigManager(Namespace):
|
13 |
-
|
14 |
def __init__(self, args=None, **kwargs):
|
15 |
"""
|
16 |
The ConfigManager is a subclass of argparse.Namespace and based on parameter dict and count the call-frequency of each parameter
|
@@ -29,36 +28,35 @@ class ConfigManager(Namespace):
|
|
29 |
self.args_call_count = {arg: 0 for arg in args}
|
30 |
|
31 |
def __getattribute__(self, arg_name):
|
32 |
-
if arg_name ==
|
33 |
return super().__getattribute__(arg_name)
|
34 |
try:
|
35 |
-
value = super().__getattribute__(
|
36 |
-
args_call_count = super().__getattribute__(
|
37 |
args_call_count[arg_name] += 1
|
38 |
-
super().__setattr__(
|
39 |
return value
|
40 |
|
41 |
except Exception as e:
|
42 |
-
|
43 |
return super().__getattribute__(arg_name)
|
44 |
|
45 |
def __setattr__(self, arg_name, value):
|
46 |
-
if arg_name ==
|
47 |
super().__setattr__(arg_name, value)
|
48 |
return
|
49 |
try:
|
50 |
-
args = super().__getattribute__(
|
51 |
args[arg_name] = value
|
52 |
-
super().__setattr__(
|
53 |
-
args_call_count = super().__getattribute__(
|
54 |
|
55 |
if arg_name in args_call_count:
|
56 |
# args_call_count[arg_name] += 1
|
57 |
-
super().__setattr__(
|
58 |
|
59 |
else:
|
60 |
args_call_count[arg_name] = 0
|
61 |
-
super().__setattr__(
|
62 |
|
63 |
except Exception as e:
|
64 |
super().__setattr__(arg_name, value)
|
|
|
10 |
|
11 |
|
12 |
class ConfigManager(Namespace):
|
|
|
13 |
def __init__(self, args=None, **kwargs):
|
14 |
"""
|
15 |
The ConfigManager is a subclass of argparse.Namespace and based on parameter dict and count the call-frequency of each parameter
|
|
|
28 |
self.args_call_count = {arg: 0 for arg in args}
|
29 |
|
30 |
def __getattribute__(self, arg_name):
|
31 |
+
if arg_name == "args" or arg_name == "args_call_count":
|
32 |
return super().__getattribute__(arg_name)
|
33 |
try:
|
34 |
+
value = super().__getattribute__("args")[arg_name]
|
35 |
+
args_call_count = super().__getattribute__("args_call_count")
|
36 |
args_call_count[arg_name] += 1
|
37 |
+
super().__setattr__("args_call_count", args_call_count)
|
38 |
return value
|
39 |
|
40 |
except Exception as e:
|
|
|
41 |
return super().__getattribute__(arg_name)
|
42 |
|
43 |
def __setattr__(self, arg_name, value):
|
44 |
+
if arg_name == "args" or arg_name == "args_call_count":
|
45 |
super().__setattr__(arg_name, value)
|
46 |
return
|
47 |
try:
|
48 |
+
args = super().__getattribute__("args")
|
49 |
args[arg_name] = value
|
50 |
+
super().__setattr__("args", args)
|
51 |
+
args_call_count = super().__getattribute__("args_call_count")
|
52 |
|
53 |
if arg_name in args_call_count:
|
54 |
# args_call_count[arg_name] += 1
|
55 |
+
super().__setattr__("args_call_count", args_call_count)
|
56 |
|
57 |
else:
|
58 |
args_call_count[arg_name] = 0
|
59 |
+
super().__setattr__("args_call_count", args_call_count)
|
60 |
|
61 |
except Exception as e:
|
62 |
super().__setattr__(arg_name, value)
|
anonymous_demo/functional/config/tad_config_manager.py
CHANGED
@@ -3,116 +3,121 @@ import copy
|
|
3 |
from anonymous_demo.functional.config.config_manager import ConfigManager
|
4 |
from anonymous_demo.core.tad.classic.__bert__.models import TADBERT
|
5 |
|
6 |
-
_tad_config_template = {
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
|
118 |
class TADConfigManager(ConfigManager):
|
@@ -148,47 +153,50 @@ class TADConfigManager(ConfigManager):
|
|
148 |
@staticmethod
|
149 |
def set_tad_config(configType: str, newitem: dict):
|
150 |
if isinstance(newitem, dict):
|
151 |
-
if configType ==
|
152 |
_tad_config_template.update(newitem)
|
153 |
-
elif configType ==
|
154 |
_tad_config_base.update(newitem)
|
155 |
-
elif configType ==
|
156 |
_tad_config_english.update(newitem)
|
157 |
-
elif configType ==
|
158 |
_tad_config_chinese.update(newitem)
|
159 |
-
elif configType ==
|
160 |
_tad_config_multilingual.update(newitem)
|
161 |
-
elif configType ==
|
162 |
_tad_config_glove.update(newitem)
|
163 |
else:
|
164 |
raise ValueError(
|
165 |
-
"Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove"
|
|
|
166 |
else:
|
167 |
-
raise TypeError(
|
|
|
|
|
168 |
|
169 |
@staticmethod
|
170 |
def set_tad_config_template(newitem):
|
171 |
-
TADConfigManager.set_tad_config(
|
172 |
|
173 |
@staticmethod
|
174 |
def set_tad_config_base(newitem):
|
175 |
-
TADConfigManager.set_tad_config(
|
176 |
|
177 |
@staticmethod
|
178 |
def set_tad_config_english(newitem):
|
179 |
-
TADConfigManager.set_tad_config(
|
180 |
|
181 |
@staticmethod
|
182 |
def set_tad_config_chinese(newitem):
|
183 |
-
TADConfigManager.set_tad_config(
|
184 |
|
185 |
@staticmethod
|
186 |
def set_tad_config_multilingual(newitem):
|
187 |
-
TADConfigManager.set_tad_config(
|
188 |
|
189 |
@staticmethod
|
190 |
def set_tad_config_glove(newitem):
|
191 |
-
TADConfigManager.set_tad_config(
|
192 |
|
193 |
@staticmethod
|
194 |
def get_tad_config_template() -> ConfigManager:
|
|
|
3 |
from anonymous_demo.functional.config.config_manager import ConfigManager
|
4 |
from anonymous_demo.core.tad.classic.__bert__.models import TADBERT
|
5 |
|
6 |
+
_tad_config_template = {
|
7 |
+
"model": TADBERT,
|
8 |
+
"optimizer": "adamw",
|
9 |
+
"learning_rate": 0.00002,
|
10 |
+
"patience": 99999,
|
11 |
+
"pretrained_bert": "microsoft/mdeberta-v3-base",
|
12 |
+
"cache_dataset": True,
|
13 |
+
"warmup_step": -1,
|
14 |
+
"show_metric": False,
|
15 |
+
"max_seq_len": 80,
|
16 |
+
"dropout": 0,
|
17 |
+
"l2reg": 0.000001,
|
18 |
+
"num_epoch": 10,
|
19 |
+
"batch_size": 16,
|
20 |
+
"initializer": "xavier_uniform_",
|
21 |
+
"seed": 52,
|
22 |
+
"polarities_dim": 3,
|
23 |
+
"log_step": 10,
|
24 |
+
"evaluate_begin": 0,
|
25 |
+
"cross_validate_fold": -1,
|
26 |
+
"use_amp": False,
|
27 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
28 |
+
}
|
29 |
+
|
30 |
+
_tad_config_base = {
|
31 |
+
"model": TADBERT,
|
32 |
+
"optimizer": "adamw",
|
33 |
+
"learning_rate": 0.00002,
|
34 |
+
"pretrained_bert": "microsoft/deberta-v3-base",
|
35 |
+
"cache_dataset": True,
|
36 |
+
"warmup_step": -1,
|
37 |
+
"show_metric": False,
|
38 |
+
"max_seq_len": 80,
|
39 |
+
"patience": 99999,
|
40 |
+
"dropout": 0,
|
41 |
+
"l2reg": 0.000001,
|
42 |
+
"num_epoch": 10,
|
43 |
+
"batch_size": 16,
|
44 |
+
"initializer": "xavier_uniform_",
|
45 |
+
"seed": 52,
|
46 |
+
"polarities_dim": 3,
|
47 |
+
"log_step": 10,
|
48 |
+
"evaluate_begin": 0,
|
49 |
+
"cross_validate_fold": -1
|
50 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
51 |
+
}
|
52 |
+
|
53 |
+
_tad_config_english = {
|
54 |
+
"model": TADBERT,
|
55 |
+
"optimizer": "adamw",
|
56 |
+
"learning_rate": 0.00002,
|
57 |
+
"patience": 99999,
|
58 |
+
"pretrained_bert": "microsoft/deberta-v3-base",
|
59 |
+
"cache_dataset": True,
|
60 |
+
"warmup_step": -1,
|
61 |
+
"show_metric": False,
|
62 |
+
"max_seq_len": 80,
|
63 |
+
"dropout": 0,
|
64 |
+
"l2reg": 0.000001,
|
65 |
+
"num_epoch": 10,
|
66 |
+
"batch_size": 16,
|
67 |
+
"initializer": "xavier_uniform_",
|
68 |
+
"seed": 52,
|
69 |
+
"polarities_dim": 3,
|
70 |
+
"log_step": 10,
|
71 |
+
"evaluate_begin": 0,
|
72 |
+
"cross_validate_fold": -1
|
73 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
74 |
+
}
|
75 |
+
|
76 |
+
_tad_config_multilingual = {
|
77 |
+
"model": TADBERT,
|
78 |
+
"optimizer": "adamw",
|
79 |
+
"learning_rate": 0.00002,
|
80 |
+
"patience": 99999,
|
81 |
+
"pretrained_bert": "microsoft/mdeberta-v3-base",
|
82 |
+
"cache_dataset": True,
|
83 |
+
"warmup_step": -1,
|
84 |
+
"show_metric": False,
|
85 |
+
"max_seq_len": 80,
|
86 |
+
"dropout": 0,
|
87 |
+
"l2reg": 0.000001,
|
88 |
+
"num_epoch": 10,
|
89 |
+
"batch_size": 16,
|
90 |
+
"initializer": "xavier_uniform_",
|
91 |
+
"seed": 52,
|
92 |
+
"polarities_dim": 3,
|
93 |
+
"log_step": 10,
|
94 |
+
"evaluate_begin": 0,
|
95 |
+
"cross_validate_fold": -1
|
96 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
97 |
+
}
|
98 |
+
|
99 |
+
_tad_config_chinese = {
|
100 |
+
"model": TADBERT,
|
101 |
+
"optimizer": "adamw",
|
102 |
+
"learning_rate": 0.00002,
|
103 |
+
"patience": 99999,
|
104 |
+
"cache_dataset": True,
|
105 |
+
"warmup_step": -1,
|
106 |
+
"show_metric": False,
|
107 |
+
"pretrained_bert": "bert-base-chinese",
|
108 |
+
"max_seq_len": 80,
|
109 |
+
"dropout": 0,
|
110 |
+
"l2reg": 0.000001,
|
111 |
+
"num_epoch": 10,
|
112 |
+
"batch_size": 16,
|
113 |
+
"initializer": "xavier_uniform_",
|
114 |
+
"seed": 52,
|
115 |
+
"polarities_dim": 3,
|
116 |
+
"log_step": 10,
|
117 |
+
"evaluate_begin": 0,
|
118 |
+
"cross_validate_fold": -1
|
119 |
+
# split train and test datasets into 5 folds and repeat 3 training
|
120 |
+
}
|
121 |
|
122 |
|
123 |
class TADConfigManager(ConfigManager):
|
|
|
153 |
@staticmethod
|
154 |
def set_tad_config(configType: str, newitem: dict):
|
155 |
if isinstance(newitem, dict):
|
156 |
+
if configType == "template":
|
157 |
_tad_config_template.update(newitem)
|
158 |
+
elif configType == "base":
|
159 |
_tad_config_base.update(newitem)
|
160 |
+
elif configType == "english":
|
161 |
_tad_config_english.update(newitem)
|
162 |
+
elif configType == "chinese":
|
163 |
_tad_config_chinese.update(newitem)
|
164 |
+
elif configType == "multilingual":
|
165 |
_tad_config_multilingual.update(newitem)
|
166 |
+
elif configType == "glove":
|
167 |
_tad_config_glove.update(newitem)
|
168 |
else:
|
169 |
raise ValueError(
|
170 |
+
"Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove"
|
171 |
+
)
|
172 |
else:
|
173 |
+
raise TypeError(
|
174 |
+
"Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}"
|
175 |
+
)
|
176 |
|
177 |
@staticmethod
|
178 |
def set_tad_config_template(newitem):
|
179 |
+
TADConfigManager.set_tad_config("template", newitem)
|
180 |
|
181 |
@staticmethod
|
182 |
def set_tad_config_base(newitem):
|
183 |
+
TADConfigManager.set_tad_config("base", newitem)
|
184 |
|
185 |
@staticmethod
|
186 |
def set_tad_config_english(newitem):
|
187 |
+
TADConfigManager.set_tad_config("english", newitem)
|
188 |
|
189 |
@staticmethod
|
190 |
def set_tad_config_chinese(newitem):
|
191 |
+
TADConfigManager.set_tad_config("chinese", newitem)
|
192 |
|
193 |
@staticmethod
|
194 |
def set_tad_config_multilingual(newitem):
|
195 |
+
TADConfigManager.set_tad_config("multilingual", newitem)
|
196 |
|
197 |
@staticmethod
|
198 |
def set_tad_config_glove(newitem):
|
199 |
+
TADConfigManager.set_tad_config("glove", newitem)
|
200 |
|
201 |
@staticmethod
|
202 |
def get_tad_config_template() -> ConfigManager:
|
anonymous_demo/functional/dataset/__init__.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
from anonymous_demo.functional.dataset.dataset_manager import
|
|
|
1 |
+
from anonymous_demo.functional.dataset.dataset_manager import detect_infer_dataset
|
anonymous_demo/functional/dataset/dataset_manager.py
CHANGED
@@ -1,11 +1,24 @@
|
|
1 |
import os
|
2 |
from findfile import find_files, find_dir
|
3 |
|
4 |
-
filter_key_words = [
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
-
def detect_infer_dataset(dataset_path, task=
|
9 |
dataset_file = []
|
10 |
if isinstance(dataset_path, str) and os.path.isfile(dataset_path):
|
11 |
dataset_file.append(dataset_path)
|
@@ -13,9 +26,20 @@ def detect_infer_dataset(dataset_path, task='apc'):
|
|
13 |
|
14 |
for d in dataset_path:
|
15 |
if not os.path.exists(d):
|
16 |
-
search_path = find_dir(
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
else:
|
19 |
-
dataset_file += find_files(
|
|
|
|
|
20 |
|
21 |
return dataset_file
|
|
|
1 |
import os
|
2 |
from findfile import find_files, find_dir
|
3 |
|
4 |
+
filter_key_words = [
|
5 |
+
".py",
|
6 |
+
".md",
|
7 |
+
"readme",
|
8 |
+
"log",
|
9 |
+
"result",
|
10 |
+
"zip",
|
11 |
+
".state_dict",
|
12 |
+
".model",
|
13 |
+
".png",
|
14 |
+
"acc_",
|
15 |
+
"f1_",
|
16 |
+
".backup",
|
17 |
+
".bak",
|
18 |
+
]
|
19 |
|
20 |
|
21 |
+
def detect_infer_dataset(dataset_path, task="apc"):
|
22 |
dataset_file = []
|
23 |
if isinstance(dataset_path, str) and os.path.isfile(dataset_path):
|
24 |
dataset_file.append(dataset_path)
|
|
|
26 |
|
27 |
for d in dataset_path:
|
28 |
if not os.path.exists(d):
|
29 |
+
search_path = find_dir(
|
30 |
+
os.getcwd(),
|
31 |
+
[d, task, "dataset"],
|
32 |
+
exclude_key=filter_key_words,
|
33 |
+
disable_alert=False,
|
34 |
+
)
|
35 |
+
dataset_file += find_files(
|
36 |
+
search_path,
|
37 |
+
[".inference", d],
|
38 |
+
exclude_key=["train."] + filter_key_words,
|
39 |
+
)
|
40 |
else:
|
41 |
+
dataset_file += find_files(
|
42 |
+
d, [".inference", task], exclude_key=["train."] + filter_key_words
|
43 |
+
)
|
44 |
|
45 |
return dataset_file
|
anonymous_demo/network/lcf_pooler.py
CHANGED
@@ -14,10 +14,12 @@ class LCF_Pooler(nn.Module):
|
|
14 |
device = hidden_states.device
|
15 |
lcf_vec = lcf_vec.detach().cpu().numpy()
|
16 |
|
17 |
-
pooled_output = numpy.zeros(
|
|
|
|
|
18 |
hidden_states = hidden_states.detach().cpu().numpy()
|
19 |
for i, vec in enumerate(lcf_vec):
|
20 |
-
lcf_ids = [j for j in range(len(vec)) if sum(vec[j] - 1.) == 0]
|
21 |
pooled_output[i] = hidden_states[i][lcf_ids[len(lcf_ids) // 2]]
|
22 |
|
23 |
pooled_output = torch.Tensor(pooled_output).to(device)
|
|
|
14 |
device = hidden_states.device
|
15 |
lcf_vec = lcf_vec.detach().cpu().numpy()
|
16 |
|
17 |
+
pooled_output = numpy.zeros(
|
18 |
+
(hidden_states.shape[0], hidden_states.shape[2]), dtype=numpy.float32
|
19 |
+
)
|
20 |
hidden_states = hidden_states.detach().cpu().numpy()
|
21 |
for i, vec in enumerate(lcf_vec):
|
22 |
+
lcf_ids = [j for j in range(len(vec)) if sum(vec[j] - 1.0) == 0]
|
23 |
pooled_output[i] = hidden_states[i][lcf_ids[len(lcf_ids) // 2]]
|
24 |
|
25 |
pooled_output = torch.Tensor(pooled_output).to(device)
|
anonymous_demo/network/lsa.py
CHANGED
@@ -16,8 +16,17 @@ class LSA(nn.Module):
|
|
16 |
self.eta1 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
|
17 |
self.eta2 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
|
18 |
|
19 |
-
def forward(
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# # --------------------------------------------------- #
|
23 |
lcf_features = torch.mul(global_context_features, lcf_matrix)
|
@@ -29,24 +38,36 @@ class LSA(nn.Module):
|
|
29 |
right_lcf_features = torch.mul(masked_global_context_features, right_lcf_matrix)
|
30 |
right_lcf_features = self.encoder_right(right_lcf_features)
|
31 |
# # --------------------------------------------------- #
|
32 |
-
if
|
33 |
if self.eta1 <= 0 and self.opt.eta != -1:
|
34 |
torch.nn.init.uniform_(self.eta1)
|
35 |
-
print(
|
36 |
if self.eta2 <= 0 and self.opt.eta != -1:
|
37 |
torch.nn.init.uniform_(self.eta2)
|
38 |
-
print(
|
39 |
if self.opt.eta >= 0:
|
40 |
-
cat_features = torch.cat(
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
else:
|
43 |
-
cat_features = torch.cat(
|
|
|
|
|
44 |
sent_out = self.linear_window_3h(cat_features)
|
45 |
-
elif
|
46 |
-
sent_out = self.linear_window_2h(
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
49 |
else:
|
50 |
-
raise KeyError(
|
51 |
|
52 |
return sent_out
|
|
|
16 |
self.eta1 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
|
17 |
self.eta2 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
|
18 |
|
19 |
+
def forward(
|
20 |
+
self,
|
21 |
+
global_context_features,
|
22 |
+
spc_mask_vec,
|
23 |
+
lcf_matrix,
|
24 |
+
left_lcf_matrix,
|
25 |
+
right_lcf_matrix,
|
26 |
+
):
|
27 |
+
masked_global_context_features = torch.mul(
|
28 |
+
spc_mask_vec, global_context_features
|
29 |
+
)
|
30 |
|
31 |
# # --------------------------------------------------- #
|
32 |
lcf_features = torch.mul(global_context_features, lcf_matrix)
|
|
|
38 |
right_lcf_features = torch.mul(masked_global_context_features, right_lcf_matrix)
|
39 |
right_lcf_features = self.encoder_right(right_lcf_features)
|
40 |
# # --------------------------------------------------- #
|
41 |
+
if "lr" == self.opt.window or "rl" == self.opt.window:
|
42 |
if self.eta1 <= 0 and self.opt.eta != -1:
|
43 |
torch.nn.init.uniform_(self.eta1)
|
44 |
+
print("reset eta1 to: {}".format(self.eta1.item()))
|
45 |
if self.eta2 <= 0 and self.opt.eta != -1:
|
46 |
torch.nn.init.uniform_(self.eta2)
|
47 |
+
print("reset eta2 to: {}".format(self.eta2.item()))
|
48 |
if self.opt.eta >= 0:
|
49 |
+
cat_features = torch.cat(
|
50 |
+
(
|
51 |
+
lcf_features,
|
52 |
+
self.eta1 * left_lcf_features,
|
53 |
+
self.eta2 * right_lcf_features,
|
54 |
+
),
|
55 |
+
-1,
|
56 |
+
)
|
57 |
else:
|
58 |
+
cat_features = torch.cat(
|
59 |
+
(lcf_features, left_lcf_features, right_lcf_features), -1
|
60 |
+
)
|
61 |
sent_out = self.linear_window_3h(cat_features)
|
62 |
+
elif "l" == self.opt.window:
|
63 |
+
sent_out = self.linear_window_2h(
|
64 |
+
torch.cat((lcf_features, self.eta1 * left_lcf_features), -1)
|
65 |
+
)
|
66 |
+
elif "r" == self.opt.window:
|
67 |
+
sent_out = self.linear_window_2h(
|
68 |
+
torch.cat((lcf_features, self.eta2 * right_lcf_features), -1)
|
69 |
+
)
|
70 |
else:
|
71 |
+
raise KeyError("Invalid parameter:", self.opt.window)
|
72 |
|
73 |
return sent_out
|
anonymous_demo/network/sa_encoder.py
CHANGED
@@ -8,7 +8,9 @@ import torch.nn as nn
|
|
8 |
class BertSelfAttention(nn.Module):
|
9 |
def __init__(self, config):
|
10 |
super().__init__()
|
11 |
-
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
|
|
|
|
12 |
raise ValueError(
|
13 |
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
14 |
f"heads ({config.num_attention_heads})"
|
@@ -23,16 +25,29 @@ class BertSelfAttention(nn.Module):
|
|
23 |
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
24 |
|
25 |
self.dropout = nn.Dropout(
|
26 |
-
config.attention_probs_dropout_prob
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
self.max_position_embeddings = config.max_position_embeddings
|
30 |
-
self.distance_embedding = nn.Embedding(
|
|
|
|
|
31 |
|
32 |
self.is_decoder = config.is_decoder
|
33 |
|
34 |
def transpose_for_scores(self, x):
|
35 |
-
new_x_shape = x.size()[:-1] + (
|
|
|
|
|
|
|
36 |
x = x.view(*new_x_shape)
|
37 |
return x.permute(0, 2, 1, 3)
|
38 |
|
@@ -86,21 +101,42 @@ class BertSelfAttention(nn.Module):
|
|
86 |
# Take the dot product between "query" and "key" to get the raw attention scores.
|
87 |
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
88 |
|
89 |
-
if
|
|
|
|
|
|
|
90 |
seq_length = hidden_states.size()[1]
|
91 |
-
position_ids_l = torch.arange(
|
92 |
-
|
|
|
|
|
|
|
|
|
93 |
distance = position_ids_l - position_ids_r
|
94 |
-
positional_embedding = self.distance_embedding(
|
95 |
-
|
|
|
|
|
|
|
|
|
96 |
|
97 |
if self.position_embedding_type == "relative_key":
|
98 |
-
relative_position_scores = torch.einsum(
|
|
|
|
|
99 |
attention_scores = attention_scores + relative_position_scores
|
100 |
elif self.position_embedding_type == "relative_key_query":
|
101 |
-
relative_position_scores_query = torch.einsum(
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
106 |
if attention_mask is not None:
|
@@ -124,7 +160,9 @@ class BertSelfAttention(nn.Module):
|
|
124 |
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
125 |
context_layer = context_layer.view(*new_context_layer_shape)
|
126 |
|
127 |
-
outputs = (
|
|
|
|
|
128 |
|
129 |
if self.is_decoder:
|
130 |
outputs = outputs + (past_key_value,)
|
@@ -136,7 +174,9 @@ class Encoder(nn.Module):
|
|
136 |
super(Encoder, self).__init__()
|
137 |
self.opt = opt
|
138 |
self.config = config
|
139 |
-
self.encoder = nn.ModuleList(
|
|
|
|
|
140 |
self.tanh = torch.nn.Tanh()
|
141 |
|
142 |
def forward(self, x):
|
|
|
8 |
class BertSelfAttention(nn.Module):
|
9 |
def __init__(self, config):
|
10 |
super().__init__()
|
11 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
12 |
+
config, "embedding_size"
|
13 |
+
):
|
14 |
raise ValueError(
|
15 |
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
16 |
f"heads ({config.num_attention_heads})"
|
|
|
25 |
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
26 |
|
27 |
self.dropout = nn.Dropout(
|
28 |
+
config.attention_probs_dropout_prob
|
29 |
+
if hasattr(config, "attention_probs_dropout_prob")
|
30 |
+
else 0
|
31 |
+
)
|
32 |
+
self.position_embedding_type = getattr(
|
33 |
+
config, "position_embedding_type", "absolute"
|
34 |
+
)
|
35 |
+
if (
|
36 |
+
self.position_embedding_type == "relative_key"
|
37 |
+
or self.position_embedding_type == "relative_key_query"
|
38 |
+
):
|
39 |
self.max_position_embeddings = config.max_position_embeddings
|
40 |
+
self.distance_embedding = nn.Embedding(
|
41 |
+
2 * config.max_position_embeddings - 1, self.attention_head_size
|
42 |
+
)
|
43 |
|
44 |
self.is_decoder = config.is_decoder
|
45 |
|
46 |
def transpose_for_scores(self, x):
|
47 |
+
new_x_shape = x.size()[:-1] + (
|
48 |
+
self.num_attention_heads,
|
49 |
+
self.attention_head_size,
|
50 |
+
)
|
51 |
x = x.view(*new_x_shape)
|
52 |
return x.permute(0, 2, 1, 3)
|
53 |
|
|
|
101 |
# Take the dot product between "query" and "key" to get the raw attention scores.
|
102 |
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
103 |
|
104 |
+
if (
|
105 |
+
self.position_embedding_type == "relative_key"
|
106 |
+
or self.position_embedding_type == "relative_key_query"
|
107 |
+
):
|
108 |
seq_length = hidden_states.size()[1]
|
109 |
+
position_ids_l = torch.arange(
|
110 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
111 |
+
).view(-1, 1)
|
112 |
+
position_ids_r = torch.arange(
|
113 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
114 |
+
).view(1, -1)
|
115 |
distance = position_ids_l - position_ids_r
|
116 |
+
positional_embedding = self.distance_embedding(
|
117 |
+
distance + self.max_position_embeddings - 1
|
118 |
+
)
|
119 |
+
positional_embedding = positional_embedding.to(
|
120 |
+
dtype=query_layer.dtype
|
121 |
+
) # fp16 compatibility
|
122 |
|
123 |
if self.position_embedding_type == "relative_key":
|
124 |
+
relative_position_scores = torch.einsum(
|
125 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
126 |
+
)
|
127 |
attention_scores = attention_scores + relative_position_scores
|
128 |
elif self.position_embedding_type == "relative_key_query":
|
129 |
+
relative_position_scores_query = torch.einsum(
|
130 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
131 |
+
)
|
132 |
+
relative_position_scores_key = torch.einsum(
|
133 |
+
"bhrd,lrd->bhlr", key_layer, positional_embedding
|
134 |
+
)
|
135 |
+
attention_scores = (
|
136 |
+
attention_scores
|
137 |
+
+ relative_position_scores_query
|
138 |
+
+ relative_position_scores_key
|
139 |
+
)
|
140 |
|
141 |
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
142 |
if attention_mask is not None:
|
|
|
160 |
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
161 |
context_layer = context_layer.view(*new_context_layer_shape)
|
162 |
|
163 |
+
outputs = (
|
164 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
165 |
+
)
|
166 |
|
167 |
if self.is_decoder:
|
168 |
outputs = outputs + (past_key_value,)
|
|
|
174 |
super(Encoder, self).__init__()
|
175 |
self.opt = opt
|
176 |
self.config = config
|
177 |
+
self.encoder = nn.ModuleList(
|
178 |
+
[SelfAttention(config, opt) for _ in range(layer_num)]
|
179 |
+
)
|
180 |
self.tanh = torch.nn.Tanh()
|
181 |
|
182 |
def forward(self, x):
|
anonymous_demo/utils/demo_utils.py
CHANGED
@@ -22,10 +22,10 @@ from anonymous_demo import __version__
|
|
22 |
|
23 |
|
24 |
def save_args(config, save_path):
|
25 |
-
f = open(os.path.join(save_path), mode=
|
26 |
for arg in config.args:
|
27 |
if config.args_call_count[arg]:
|
28 |
-
f.write(
|
29 |
f.close()
|
30 |
|
31 |
|
@@ -33,20 +33,39 @@ def print_args(config, logger=None, mode=0):
|
|
33 |
args = [key for key in sorted(config.args.keys())]
|
34 |
for arg in args:
|
35 |
if logger:
|
36 |
-
logger.info(
|
|
|
|
|
|
|
|
|
37 |
else:
|
38 |
-
print(
|
|
|
|
|
|
|
|
|
39 |
|
40 |
|
41 |
def check_and_fix_labels(label_set: set, label_name, all_data, opt):
|
42 |
-
if
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
46 |
else:
|
47 |
-
label_to_index = {
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
opt.index_to_label = index_to_label
|
51 |
opt.label_to_index = label_to_index
|
52 |
|
@@ -54,7 +73,7 @@ def check_and_fix_labels(label_set: set, label_name, all_data, opt):
|
|
54 |
opt.index_to_label.update(index_to_label)
|
55 |
opt.label_to_index.update(label_to_index)
|
56 |
num_label = {l: 0 for l in label_set}
|
57 |
-
num_label[
|
58 |
for item in all_data:
|
59 |
try:
|
60 |
num_label[item[label_name]] += 1
|
@@ -63,75 +82,91 @@ def check_and_fix_labels(label_set: set, label_name, all_data, opt):
|
|
63 |
# print(e)
|
64 |
num_label[item.polarity] += 1
|
65 |
item.polarity = label_to_index[item.polarity]
|
66 |
-
print(
|
67 |
|
68 |
|
69 |
def check_and_fix_IOB_labels(label_map, opt):
|
70 |
-
index_to_IOB_label = {
|
|
|
|
|
71 |
opt.index_to_IOB_label = index_to_IOB_label
|
72 |
|
73 |
|
74 |
def get_device(auto_device):
|
75 |
-
if isinstance(auto_device, str) and auto_device ==
|
76 |
-
device =
|
77 |
elif isinstance(auto_device, str):
|
78 |
device = auto_device
|
79 |
elif isinstance(auto_device, bool):
|
80 |
-
device = auto_cuda() if auto_device else
|
81 |
else:
|
82 |
device = auto_cuda()
|
83 |
try:
|
84 |
torch.device(device)
|
85 |
except RuntimeError as e:
|
86 |
-
print(
|
87 |
-
|
|
|
|
|
88 |
device_name = auto_cuda_name()
|
89 |
return device, device_name
|
90 |
|
91 |
|
92 |
def _load_word_vec(path, word2idx=None, embed_dim=300):
|
93 |
-
fin = open(path,
|
94 |
word_vec = {}
|
95 |
-
for line in tqdm.tqdm(fin.readlines(), postfix=
|
96 |
tokens = line.rstrip().split()
|
97 |
-
word, vec =
|
98 |
if word in word2idx.keys():
|
99 |
-
word_vec[word] = np.asarray(vec, dtype=
|
100 |
return word_vec
|
101 |
|
102 |
|
103 |
def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
|
104 |
-
if not os.path.exists(
|
105 |
-
os.makedirs(
|
106 |
-
embed_matrix_path =
|
107 |
if os.path.exists(embed_matrix_path):
|
108 |
-
print(
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
else:
|
111 |
glove_path = prepare_glove840_embedding(embed_matrix_path)
|
112 |
embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
|
113 |
|
114 |
word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)
|
115 |
|
116 |
-
for word, i in tqdm.tqdm(
|
|
|
|
|
|
|
117 |
vec = word_vec.get(word)
|
118 |
if vec is not None:
|
119 |
embedding_matrix[i] = vec
|
120 |
-
pickle.dump(embedding_matrix, open(embed_matrix_path,
|
121 |
return embedding_matrix
|
122 |
|
123 |
|
124 |
-
def pad_and_truncate(
|
|
|
|
|
125 |
x = (np.ones(maxlen) * value).astype(dtype)
|
126 |
-
if truncating ==
|
127 |
trunc = sequence[-maxlen:]
|
128 |
else:
|
129 |
trunc = sequence[:maxlen]
|
130 |
trunc = np.asarray(trunc, dtype=dtype)
|
131 |
-
if padding ==
|
132 |
-
x[:len(trunc)] = trunc
|
133 |
else:
|
134 |
-
x[-len(trunc):] = trunc
|
135 |
return x
|
136 |
|
137 |
|
@@ -145,7 +180,6 @@ def retry(f):
|
|
145 |
def decorated(*args, **kwargs):
|
146 |
count = 5
|
147 |
while count:
|
148 |
-
|
149 |
try:
|
150 |
return f(*args, **kwargs)
|
151 |
except (
|
@@ -158,7 +192,7 @@ def retry(f):
|
|
158 |
requests.exceptions.SSLError,
|
159 |
requests.exceptions.BaseHTTPError,
|
160 |
) as e:
|
161 |
-
print(colored(
|
162 |
time.sleep(60)
|
163 |
count -= 1
|
164 |
|
@@ -168,14 +202,14 @@ def retry(f):
|
|
168 |
def save_json(dic, save_path):
|
169 |
if isinstance(dic, str):
|
170 |
dic = eval(dic)
|
171 |
-
with open(save_path,
|
172 |
# f.write(str(dict))
|
173 |
str_ = json.dumps(dic, ensure_ascii=False)
|
174 |
f.write(str_)
|
175 |
|
176 |
|
177 |
def load_json(save_path):
|
178 |
-
with open(save_path,
|
179 |
data = f.readline().strip()
|
180 |
print(type(data), data)
|
181 |
dic = json.loads(data)
|
@@ -184,14 +218,14 @@ def load_json(save_path):
|
|
184 |
|
185 |
def init_optimizer(optimizer):
|
186 |
optimizers = {
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
torch.optim.Adadelta: torch.optim.Adadelta, # default lr=1.0
|
196 |
torch.optim.Adagrad: torch.optim.Adagrad, # default lr=0.01
|
197 |
torch.optim.Adam: torch.optim.Adam, # default lr=0.001
|
@@ -206,4 +240,8 @@ def init_optimizer(optimizer):
|
|
206 |
elif hasattr(torch.optim, optimizer.__name__):
|
207 |
return optimizer
|
208 |
else:
|
209 |
-
raise KeyError(
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
|
24 |
def save_args(config, save_path):
|
25 |
+
f = open(os.path.join(save_path), mode="w", encoding="utf8")
|
26 |
for arg in config.args:
|
27 |
if config.args_call_count[arg]:
|
28 |
+
f.write("{}: {}\n".format(arg, config.args[arg]))
|
29 |
f.close()
|
30 |
|
31 |
|
|
|
33 |
args = [key for key in sorted(config.args.keys())]
|
34 |
for arg in args:
|
35 |
if logger:
|
36 |
+
logger.info(
|
37 |
+
"{0}:{1}\t-->\tCalling Count:{2}".format(
|
38 |
+
arg, config.args[arg], config.args_call_count[arg]
|
39 |
+
)
|
40 |
+
)
|
41 |
else:
|
42 |
+
print(
|
43 |
+
"{0}:{1}\t-->\tCalling Count:{2}".format(
|
44 |
+
arg, config.args[arg], config.args_call_count[arg]
|
45 |
+
)
|
46 |
+
)
|
47 |
|
48 |
|
49 |
def check_and_fix_labels(label_set: set, label_name, all_data, opt):
|
50 |
+
if "-100" in label_set:
|
51 |
+
label_to_index = {
|
52 |
+
origin_label: int(idx) - 1 if origin_label != "-100" else -100
|
53 |
+
for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
|
54 |
+
}
|
55 |
+
index_to_label = {
|
56 |
+
int(idx) - 1 if origin_label != "-100" else -100: origin_label
|
57 |
+
for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
|
58 |
+
}
|
59 |
else:
|
60 |
+
label_to_index = {
|
61 |
+
origin_label: int(idx)
|
62 |
+
for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
|
63 |
+
}
|
64 |
+
index_to_label = {
|
65 |
+
int(idx): origin_label
|
66 |
+
for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
|
67 |
+
}
|
68 |
+
if "index_to_label" not in opt.args:
|
69 |
opt.index_to_label = index_to_label
|
70 |
opt.label_to_index = label_to_index
|
71 |
|
|
|
73 |
opt.index_to_label.update(index_to_label)
|
74 |
opt.label_to_index.update(label_to_index)
|
75 |
num_label = {l: 0 for l in label_set}
|
76 |
+
num_label["Sum"] = len(all_data)
|
77 |
for item in all_data:
|
78 |
try:
|
79 |
num_label[item[label_name]] += 1
|
|
|
82 |
# print(e)
|
83 |
num_label[item.polarity] += 1
|
84 |
item.polarity = label_to_index[item.polarity]
|
85 |
+
print("Dataset Label Details: {}".format(num_label))
|
86 |
|
87 |
|
88 |
def check_and_fix_IOB_labels(label_map, opt):
|
89 |
+
index_to_IOB_label = {
|
90 |
+
int(label_map[origin_label]): origin_label for origin_label in label_map
|
91 |
+
}
|
92 |
opt.index_to_IOB_label = index_to_IOB_label
|
93 |
|
94 |
|
95 |
def get_device(auto_device):
|
96 |
+
if isinstance(auto_device, str) and auto_device == "allcuda":
|
97 |
+
device = "cuda"
|
98 |
elif isinstance(auto_device, str):
|
99 |
device = auto_device
|
100 |
elif isinstance(auto_device, bool):
|
101 |
+
device = auto_cuda() if auto_device else "cpu"
|
102 |
else:
|
103 |
device = auto_cuda()
|
104 |
try:
|
105 |
torch.device(device)
|
106 |
except RuntimeError as e:
|
107 |
+
print(
|
108 |
+
colored("Device assignment error: {}, redirect to CPU".format(e), "red")
|
109 |
+
)
|
110 |
+
device = "cpu"
|
111 |
device_name = auto_cuda_name()
|
112 |
return device, device_name
|
113 |
|
114 |
|
115 |
def _load_word_vec(path, word2idx=None, embed_dim=300):
|
116 |
+
fin = open(path, "r", encoding="utf-8", newline="\n", errors="ignore")
|
117 |
word_vec = {}
|
118 |
+
for line in tqdm.tqdm(fin.readlines(), postfix="Loading embedding file..."):
|
119 |
tokens = line.rstrip().split()
|
120 |
+
word, vec = " ".join(tokens[:-embed_dim]), tokens[-embed_dim:]
|
121 |
if word in word2idx.keys():
|
122 |
+
word_vec[word] = np.asarray(vec, dtype="float32")
|
123 |
return word_vec
|
124 |
|
125 |
|
126 |
def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
|
127 |
+
if not os.path.exists("run"):
|
128 |
+
os.makedirs("run")
|
129 |
+
embed_matrix_path = "run/{}".format(os.path.join(opt.dataset_name, dat_fname))
|
130 |
if os.path.exists(embed_matrix_path):
|
131 |
+
print(
|
132 |
+
colored(
|
133 |
+
"Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)".format(
|
134 |
+
embed_matrix_path
|
135 |
+
),
|
136 |
+
"green",
|
137 |
+
)
|
138 |
+
)
|
139 |
+
embedding_matrix = pickle.load(open(embed_matrix_path, "rb"))
|
140 |
else:
|
141 |
glove_path = prepare_glove840_embedding(embed_matrix_path)
|
142 |
embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
|
143 |
|
144 |
word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)
|
145 |
|
146 |
+
for word, i in tqdm.tqdm(
|
147 |
+
word2idx.items(),
|
148 |
+
postfix=colored("Building embedding_matrix {}".format(dat_fname), "yellow"),
|
149 |
+
):
|
150 |
vec = word_vec.get(word)
|
151 |
if vec is not None:
|
152 |
embedding_matrix[i] = vec
|
153 |
+
pickle.dump(embedding_matrix, open(embed_matrix_path, "wb"))
|
154 |
return embedding_matrix
|
155 |
|
156 |
|
157 |
+
def pad_and_truncate(
|
158 |
+
sequence, maxlen, dtype="int64", padding="post", truncating="post", value=0
|
159 |
+
):
|
160 |
x = (np.ones(maxlen) * value).astype(dtype)
|
161 |
+
if truncating == "pre":
|
162 |
trunc = sequence[-maxlen:]
|
163 |
else:
|
164 |
trunc = sequence[:maxlen]
|
165 |
trunc = np.asarray(trunc, dtype=dtype)
|
166 |
+
if padding == "post":
|
167 |
+
x[: len(trunc)] = trunc
|
168 |
else:
|
169 |
+
x[-len(trunc) :] = trunc
|
170 |
return x
|
171 |
|
172 |
|
|
|
180 |
def decorated(*args, **kwargs):
|
181 |
count = 5
|
182 |
while count:
|
|
|
183 |
try:
|
184 |
return f(*args, **kwargs)
|
185 |
except (
|
|
|
192 |
requests.exceptions.SSLError,
|
193 |
requests.exceptions.BaseHTTPError,
|
194 |
) as e:
|
195 |
+
print(colored("Training Exception: {}, will retry later".format(e)))
|
196 |
time.sleep(60)
|
197 |
count -= 1
|
198 |
|
|
|
202 |
def save_json(dic, save_path):
|
203 |
if isinstance(dic, str):
|
204 |
dic = eval(dic)
|
205 |
+
with open(save_path, "w", encoding="utf-8") as f:
|
206 |
# f.write(str(dict))
|
207 |
str_ = json.dumps(dic, ensure_ascii=False)
|
208 |
f.write(str_)
|
209 |
|
210 |
|
211 |
def load_json(save_path):
|
212 |
+
with open(save_path, "r", encoding="utf-8") as f:
|
213 |
data = f.readline().strip()
|
214 |
print(type(data), data)
|
215 |
dic = json.loads(data)
|
|
|
218 |
|
219 |
def init_optimizer(optimizer):
|
220 |
optimizers = {
|
221 |
+
"adadelta": torch.optim.Adadelta, # default lr=1.0
|
222 |
+
"adagrad": torch.optim.Adagrad, # default lr=0.01
|
223 |
+
"adam": torch.optim.Adam, # default lr=0.001
|
224 |
+
"adamax": torch.optim.Adamax, # default lr=0.002
|
225 |
+
"asgd": torch.optim.ASGD, # default lr=0.01
|
226 |
+
"rmsprop": torch.optim.RMSprop, # default lr=0.01
|
227 |
+
"sgd": torch.optim.SGD,
|
228 |
+
"adamw": torch.optim.AdamW,
|
229 |
torch.optim.Adadelta: torch.optim.Adadelta, # default lr=1.0
|
230 |
torch.optim.Adagrad: torch.optim.Adagrad, # default lr=0.01
|
231 |
torch.optim.Adam: torch.optim.Adam, # default lr=0.001
|
|
|
240 |
elif hasattr(torch.optim, optimizer.__name__):
|
241 |
return optimizer
|
242 |
else:
|
243 |
+
raise KeyError(
|
244 |
+
"Unsupported optimizer: {}. Please use string or the optimizer objects in torch.optim as your optimizer".format(
|
245 |
+
optimizer
|
246 |
+
)
|
247 |
+
)
|
anonymous_demo/utils/logger.py
CHANGED
@@ -5,22 +5,22 @@ import time
|
|
5 |
|
6 |
import termcolor
|
7 |
|
8 |
-
today = time.strftime(
|
9 |
|
10 |
|
11 |
-
def get_logger(log_path, log_name=
|
12 |
if not log_path:
|
13 |
log_dir = os.path.join(log_path, "logs")
|
14 |
else:
|
15 |
-
log_dir = os.path.join(
|
16 |
|
17 |
-
full_path = os.path.join(log_dir, log_name +
|
18 |
if not os.path.exists(full_path):
|
19 |
os.makedirs(full_path)
|
20 |
log_path = os.path.join(full_path, "{}.log".format(log_type))
|
21 |
logger = logging.getLogger(log_name)
|
22 |
if not logger.handlers:
|
23 |
-
formatter = logging.Formatter(
|
24 |
|
25 |
file_handler = logging.FileHandler(log_path, encoding="utf8")
|
26 |
file_handler.setFormatter(formatter)
|
|
|
5 |
|
6 |
import termcolor
|
7 |
|
8 |
+
today = time.strftime("%Y%m%d %H%M%S", time.localtime(time.time()))
|
9 |
|
10 |
|
11 |
+
def get_logger(log_path, log_name="", log_type="training_log"):
|
12 |
if not log_path:
|
13 |
log_dir = os.path.join(log_path, "logs")
|
14 |
else:
|
15 |
+
log_dir = os.path.join(".", "logs")
|
16 |
|
17 |
+
full_path = os.path.join(log_dir, log_name + "_" + today)
|
18 |
if not os.path.exists(full_path):
|
19 |
os.makedirs(full_path)
|
20 |
log_path = os.path.join(full_path, "{}.log".format(log_type))
|
21 |
logger = logging.getLogger(log_name)
|
22 |
if not logger.handlers:
|
23 |
+
formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")
|
24 |
|
25 |
file_handler = logging.FileHandler(log_path, encoding="utf8")
|
26 |
file_handler.setFormatter(formatter)
|
app.py
CHANGED
@@ -18,6 +18,7 @@ from textattack.attack_recipes import (
|
|
18 |
IGAWang2019,
|
19 |
GeneticAlgorithmAlzantot2018,
|
20 |
DeepWordBugGao2018,
|
|
|
21 |
)
|
22 |
from textattack.attack_results import SuccessfulAttackResult
|
23 |
from textattack.datasets import Dataset
|
@@ -88,9 +89,10 @@ attack_recipes = {
|
|
88 |
"iga": IGAWang2019,
|
89 |
"GA": GeneticAlgorithmAlzantot2018,
|
90 |
"wordbugger": DeepWordBugGao2018,
|
|
|
91 |
}
|
92 |
|
93 |
-
for attacker in ["pwws", "bae", "textfooler"]:
|
94 |
for dataset in [
|
95 |
"agnews10k",
|
96 |
"amazon",
|
@@ -389,7 +391,10 @@ with demo:
|
|
389 |
"- To our best knowledge, Reactive Perturbation Defocusing is a novel approach in adversarial defense "
|
390 |
". RPD significantly (>10% defense accuracy improvement) outperforms the state-of-the-art methods."
|
391 |
)
|
392 |
-
|
|
|
|
|
|
|
393 |
|
394 |
gr.Markdown("## <p align='center'>Natural Example Input</p>")
|
395 |
with gr.Group():
|
@@ -400,7 +405,14 @@ with demo:
|
|
400 |
label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
|
401 |
)
|
402 |
input_attacker = gr.Radio(
|
403 |
-
choices=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
value="TextFooler",
|
405 |
label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
|
406 |
)
|
@@ -414,7 +426,6 @@ with demo:
|
|
414 |
placeholder="Original label...", label="Original Label"
|
415 |
)
|
416 |
|
417 |
-
|
418 |
button_gen = gr.Button(
|
419 |
"Generate an adversarial example and repair using RPD (No GPU, Time:3-10 mins )",
|
420 |
variant="primary",
|
@@ -432,11 +443,14 @@ with demo:
|
|
432 |
output_adv_example = gr.Textbox(label="Adversarial Example")
|
433 |
output_adv_label = gr.Textbox(label="Perturbed Label")
|
434 |
with gr.Row():
|
435 |
-
output_repaired_example = gr.Textbox(
|
|
|
|
|
436 |
output_repaired_label = gr.Textbox(label="Repaired Label")
|
437 |
|
438 |
-
|
439 |
-
|
|
|
440 |
with gr.Group():
|
441 |
output_is_adv_df = gr.DataFrame(label="Adversarial Example Detection Result")
|
442 |
gr.Markdown(
|
@@ -444,9 +458,7 @@ with demo:
|
|
444 |
"The perturbed_label is the predicted label of the adversarial example. "
|
445 |
"The confidence field represents the confidence of the predicted adversarial example detection. "
|
446 |
)
|
447 |
-
output_df = gr.DataFrame(
|
448 |
-
label="Repaired Standard Classification Result"
|
449 |
-
)
|
450 |
gr.Markdown(
|
451 |
"If is_repaired=true, it has been repaired by RPD. "
|
452 |
"The pred_label field indicates the standard classification result. "
|
@@ -454,20 +466,19 @@ with demo:
|
|
454 |
"The is_correct field indicates whether the predicted label is correct."
|
455 |
)
|
456 |
|
457 |
-
|
458 |
gr.Markdown("## <p align='center'>Example Comparisons</p>")
|
459 |
ori_text_diff = gr.HighlightedText(
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
adv_text_diff = gr.HighlightedText(
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
restored_text_diff = gr.HighlightedText(
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
|
472 |
# Bind functions to buttons
|
473 |
button_gen.click(
|
|
|
18 |
IGAWang2019,
|
19 |
GeneticAlgorithmAlzantot2018,
|
20 |
DeepWordBugGao2018,
|
21 |
+
CLARE2020,
|
22 |
)
|
23 |
from textattack.attack_results import SuccessfulAttackResult
|
24 |
from textattack.datasets import Dataset
|
|
|
89 |
"iga": IGAWang2019,
|
90 |
"GA": GeneticAlgorithmAlzantot2018,
|
91 |
"wordbugger": DeepWordBugGao2018,
|
92 |
+
'clare': CLARE2020,
|
93 |
}
|
94 |
|
95 |
+
for attacker in ["pwws", "bae", "textfooler", "pso", "wordbugger", 'clare']:
|
96 |
for dataset in [
|
97 |
"agnews10k",
|
98 |
"amazon",
|
|
|
391 |
"- To our best knowledge, Reactive Perturbation Defocusing is a novel approach in adversarial defense "
|
392 |
". RPD significantly (>10% defense accuracy improvement) outperforms the state-of-the-art methods."
|
393 |
)
|
394 |
+
gr.Markdown(
|
395 |
+
"- The DeepWordBug, IGA, GA, PSO, and CLARE attackers are very slow on CPU Devices."
|
396 |
+
" And they are unknown attackers to RPD's adversarial detector. "
|
397 |
+
)
|
398 |
|
399 |
gr.Markdown("## <p align='center'>Natural Example Input</p>")
|
400 |
with gr.Group():
|
|
|
405 |
label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
|
406 |
)
|
407 |
input_attacker = gr.Radio(
|
408 |
+
choices=[
|
409 |
+
"BAE",
|
410 |
+
"PWWS",
|
411 |
+
"TextFooler",
|
412 |
+
"WordBugger",
|
413 |
+
"PSO",
|
414 |
+
"CLARE",
|
415 |
+
],
|
416 |
value="TextFooler",
|
417 |
label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
|
418 |
)
|
|
|
426 |
placeholder="Original label...", label="Original Label"
|
427 |
)
|
428 |
|
|
|
429 |
button_gen = gr.Button(
|
430 |
"Generate an adversarial example and repair using RPD (No GPU, Time:3-10 mins )",
|
431 |
variant="primary",
|
|
|
443 |
output_adv_example = gr.Textbox(label="Adversarial Example")
|
444 |
output_adv_label = gr.Textbox(label="Perturbed Label")
|
445 |
with gr.Row():
|
446 |
+
output_repaired_example = gr.Textbox(
|
447 |
+
label="Repaired Adversarial Example by RPD"
|
448 |
+
)
|
449 |
output_repaired_label = gr.Textbox(label="Repaired Label")
|
450 |
|
451 |
+
gr.Markdown(
|
452 |
+
"## <p align='center'>The Output of Reactive Perturbation Defocusing</p>"
|
453 |
+
)
|
454 |
with gr.Group():
|
455 |
output_is_adv_df = gr.DataFrame(label="Adversarial Example Detection Result")
|
456 |
gr.Markdown(
|
|
|
458 |
"The perturbed_label is the predicted label of the adversarial example. "
|
459 |
"The confidence field represents the confidence of the predicted adversarial example detection. "
|
460 |
)
|
461 |
+
output_df = gr.DataFrame(label="Repaired Standard Classification Result")
|
|
|
|
|
462 |
gr.Markdown(
|
463 |
"If is_repaired=true, it has been repaired by RPD. "
|
464 |
"The pred_label field indicates the standard classification result. "
|
|
|
466 |
"The is_correct field indicates whether the predicted label is correct."
|
467 |
)
|
468 |
|
|
|
469 |
gr.Markdown("## <p align='center'>Example Comparisons</p>")
|
470 |
ori_text_diff = gr.HighlightedText(
|
471 |
+
label="The Original Natural Example",
|
472 |
+
combine_adjacent=True,
|
473 |
+
)
|
474 |
adv_text_diff = gr.HighlightedText(
|
475 |
+
label="Character Editions of Adversarial Example Compared to the Natural Example",
|
476 |
+
combine_adjacent=True,
|
477 |
+
)
|
478 |
restored_text_diff = gr.HighlightedText(
|
479 |
+
label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
|
480 |
+
combine_adjacent=True,
|
481 |
+
)
|
482 |
|
483 |
# Bind functions to buttons
|
484 |
button_gen.click(
|
requirements.txt
CHANGED
@@ -16,4 +16,4 @@ transformers>4.20.0
|
|
16 |
torch>1.0.0
|
17 |
sentencepiece
|
18 |
tensorflow_text
|
19 |
-
textattack
|
|
|
16 |
torch>1.0.0
|
17 |
sentencepiece
|
18 |
tensorflow_text
|
19 |
+
textattack[tensorflow]
|
textattack/attack_recipes/morpheus_tan_2020.py
CHANGED
@@ -27,7 +27,6 @@ class MorpheusTan2020(AttackRecipe):
|
|
27 |
|
28 |
@staticmethod
|
29 |
def build(model_wrapper):
|
30 |
-
|
31 |
#
|
32 |
# Goal is to minimize BLEU score between the model output given for the
|
33 |
# perturbed input sequence and the reference translation
|
|
|
27 |
|
28 |
@staticmethod
|
29 |
def build(model_wrapper):
|
|
|
30 |
#
|
31 |
# Goal is to minimize BLEU score between the model output given for the
|
32 |
# perturbed input sequence and the reference translation
|
textattack/attack_recipes/seq2sick_cheng_2018_blackbox.py
CHANGED
@@ -31,7 +31,6 @@ class Seq2SickCheng2018BlackBox(AttackRecipe):
|
|
31 |
|
32 |
@staticmethod
|
33 |
def build(model_wrapper, goal_function="non_overlapping"):
|
34 |
-
|
35 |
#
|
36 |
# Goal is non-overlapping output.
|
37 |
#
|
|
|
31 |
|
32 |
@staticmethod
|
33 |
def build(model_wrapper, goal_function="non_overlapping"):
|
|
|
34 |
#
|
35 |
# Goal is non-overlapping output.
|
36 |
#
|
textattack/attacker.py
CHANGED
@@ -105,8 +105,8 @@ class Attacker:
|
|
105 |
def simple_attack(self, text, label):
|
106 |
"""Internal method that carries out attack.
|
107 |
|
108 |
-
|
109 |
-
|
110 |
if torch.cuda.is_available():
|
111 |
self.attack.cuda_()
|
112 |
|
@@ -120,9 +120,11 @@ class Attacker:
|
|
120 |
except Exception as e:
|
121 |
raise e
|
122 |
# return
|
123 |
-
if (
|
124 |
-
|
125 |
-
|
|
|
|
|
126 |
):
|
127 |
return
|
128 |
else:
|
|
|
105 |
def simple_attack(self, text, label):
|
106 |
"""Internal method that carries out attack.
|
107 |
|
108 |
+
No parallel processing is involved.
|
109 |
+
"""
|
110 |
if torch.cuda.is_available():
|
111 |
self.attack.cuda_()
|
112 |
|
|
|
120 |
except Exception as e:
|
121 |
raise e
|
122 |
# return
|
123 |
+
if (
|
124 |
+
isinstance(result, SkippedAttackResult) and self.attack_args.attack_n
|
125 |
+
) or (
|
126 |
+
not isinstance(result, SuccessfulAttackResult)
|
127 |
+
and self.attack_args.num_successful_examples
|
128 |
):
|
129 |
return
|
130 |
else:
|
textattack/commands/augment_command.py
CHANGED
@@ -32,7 +32,6 @@ class AugmentCommand(TextAttackCommand):
|
|
32 |
|
33 |
args = textattack.AugmenterArgs(**vars(args))
|
34 |
if args.interactive:
|
35 |
-
|
36 |
print("\nRunning in interactive mode...\n")
|
37 |
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
|
38 |
pct_words_to_swap=args.pct_words_to_swap,
|
|
|
32 |
|
33 |
args = textattack.AugmenterArgs(**vars(args))
|
34 |
if args.interactive:
|
|
|
35 |
print("\nRunning in interactive mode...\n")
|
36 |
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
|
37 |
pct_words_to_swap=args.pct_words_to_swap,
|
textattack/commands/eval_model_command.py
CHANGED
@@ -56,7 +56,7 @@ class EvalModelCommand(TextAttackCommand):
|
|
56 |
while i < min(args.num_examples, len(dataset)):
|
57 |
dataset_batch = dataset[i : min(args.num_examples, i + args.batch_size)]
|
58 |
batch_inputs = []
|
59 |
-
for
|
60 |
attacked_text = textattack.shared.AttackedText(text_input)
|
61 |
batch_inputs.append(attacked_text.tokenizer_input)
|
62 |
ground_truth_outputs.append(ground_truth_output)
|
|
|
56 |
while i < min(args.num_examples, len(dataset)):
|
57 |
dataset_batch = dataset[i : min(args.num_examples, i + args.batch_size)]
|
58 |
batch_inputs = []
|
59 |
+
for text_input, ground_truth_output in dataset_batch:
|
60 |
attacked_text = textattack.shared.AttackedText(text_input)
|
61 |
batch_inputs.append(attacked_text.tokenizer_input)
|
62 |
ground_truth_outputs.append(ground_truth_output)
|
textattack/constraints/overlap/max_words_perturbed.py
CHANGED
@@ -38,7 +38,6 @@ class MaxWordsPerturbed(Constraint):
|
|
38 |
self.max_percent = max_percent
|
39 |
|
40 |
def _check_constraint(self, transformed_text, reference_text):
|
41 |
-
|
42 |
num_words_diff = len(transformed_text.all_words_diff(reference_text))
|
43 |
if self.max_percent:
|
44 |
min_num_words = min(len(transformed_text.words), len(reference_text.words))
|
|
|
38 |
self.max_percent = max_percent
|
39 |
|
40 |
def _check_constraint(self, transformed_text, reference_text):
|
|
|
41 |
num_words_diff = len(transformed_text.all_words_diff(reference_text))
|
42 |
if self.max_percent:
|
43 |
min_num_words = min(len(transformed_text.words), len(reference_text.words))
|
textattack/goal_function_results/classification_goal_function_result.py
CHANGED
@@ -26,7 +26,6 @@ class ClassificationGoalFunctionResult(GoalFunctionResult):
|
|
26 |
num_queries,
|
27 |
ground_truth_output,
|
28 |
):
|
29 |
-
|
30 |
super().__init__(
|
31 |
attacked_text,
|
32 |
raw_output,
|
|
|
26 |
num_queries,
|
27 |
ground_truth_output,
|
28 |
):
|
|
|
29 |
super().__init__(
|
30 |
attacked_text,
|
31 |
raw_output,
|
textattack/goal_function_results/text_to_text_goal_function_result.py
CHANGED
@@ -23,7 +23,6 @@ class TextToTextGoalFunctionResult(GoalFunctionResult):
|
|
23 |
num_queries,
|
24 |
ground_truth_output,
|
25 |
):
|
26 |
-
|
27 |
super().__init__(
|
28 |
attacked_text,
|
29 |
raw_output,
|
|
|
23 |
num_queries,
|
24 |
ground_truth_output,
|
25 |
):
|
|
|
26 |
super().__init__(
|
27 |
attacked_text,
|
28 |
raw_output,
|
textattack/loggers/weights_and_biases_logger.py
CHANGED
@@ -13,7 +13,6 @@ class WeightsAndBiasesLogger(Logger):
|
|
13 |
"""Logs attack results to Weights & Biases."""
|
14 |
|
15 |
def __init__(self, **kwargs):
|
16 |
-
|
17 |
global wandb
|
18 |
wandb = LazyLoader("wandb", globals(), "wandb")
|
19 |
|
|
|
13 |
"""Logs attack results to Weights & Biases."""
|
14 |
|
15 |
def __init__(self, **kwargs):
|
|
|
16 |
global wandb
|
17 |
wandb = LazyLoader("wandb", globals(), "wandb")
|
18 |
|
textattack/metrics/quality_metrics/perplexity.py
CHANGED
@@ -94,7 +94,6 @@ class Perplexity(Metric):
|
|
94 |
return self.all_metrics
|
95 |
|
96 |
def calc_ppl(self, texts):
|
97 |
-
|
98 |
with torch.no_grad():
|
99 |
text = " ".join(texts)
|
100 |
eval_loss = []
|
|
|
94 |
return self.all_metrics
|
95 |
|
96 |
def calc_ppl(self, texts):
|
|
|
97 |
with torch.no_grad():
|
98 |
text = " ".join(texts)
|
99 |
eval_loss = []
|
textattack/models/wrappers/demo_model_wrapper.py
CHANGED
@@ -2,14 +2,14 @@ from textattack.models.wrappers import HuggingFaceModelWrapper
|
|
2 |
|
3 |
|
4 |
class TADModelWrapper(HuggingFaceModelWrapper):
|
5 |
-
"""
|
6 |
-
|
7 |
|
8 |
-
|
9 |
|
10 |
-
|
11 |
|
12 |
-
|
13 |
"""
|
14 |
|
15 |
def __init__(self, model):
|
@@ -19,6 +19,6 @@ class TADModelWrapper(HuggingFaceModelWrapper):
|
|
19 |
outputs = []
|
20 |
for text_input in text_inputs:
|
21 |
raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
|
22 |
-
outputs.append(raw_outputs[
|
23 |
|
24 |
return outputs
|
|
|
2 |
|
3 |
|
4 |
class TADModelWrapper(HuggingFaceModelWrapper):
|
5 |
+
"""Transformers sentiment analysis pipeline returns a list of responses
|
6 |
+
like
|
7 |
|
8 |
+
[{'label': 'POSITIVE', 'score': 0.7817379832267761}]
|
9 |
|
10 |
+
We need to convert that to a format TextAttack understands, like
|
11 |
|
12 |
+
[[0.218262017, 0.7817379832267761]
|
13 |
"""
|
14 |
|
15 |
def __init__(self, model):
|
|
|
19 |
outputs = []
|
20 |
for text_input in text_inputs:
|
21 |
raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
|
22 |
+
outputs.append(raw_outputs["probs"])
|
23 |
|
24 |
return outputs
|
textattack/reactive_defense/reactive_defender.py
CHANGED
@@ -4,7 +4,6 @@ from textattack.shared.utils import ReprMixin
|
|
4 |
|
5 |
|
6 |
class ReactiveDefender(ReprMixin, ABC):
|
7 |
-
|
8 |
def __init__(self, **kwargs):
|
9 |
pass
|
10 |
|
|
|
4 |
|
5 |
|
6 |
class ReactiveDefender(ReprMixin, ABC):
|
|
|
7 |
def __init__(self, **kwargs):
|
8 |
pass
|
9 |
|
textattack/reactive_defense/tad_reactive_defender.py
CHANGED
@@ -5,21 +5,24 @@ from textattack.reactive_defense.reactive_defender import ReactiveDefender
|
|
5 |
|
6 |
|
7 |
class TADReactiveDefender(ReactiveDefender):
|
8 |
-
"""
|
9 |
-
|
10 |
|
11 |
-
|
12 |
|
13 |
-
|
14 |
|
15 |
-
|
16 |
"""
|
17 |
|
18 |
-
def __init__(self, ckpt=
|
19 |
super().__init__(**kwargs)
|
20 |
-
self.tad_classifier = TADCheckpointManager.get_tad_text_classifier(
|
21 |
-
|
|
|
22 |
|
23 |
def reactive_defense(self, text, **kwargs):
|
24 |
-
res = self.tad_classifier.infer(
|
|
|
|
|
25 |
return res
|
|
|
5 |
|
6 |
|
7 |
class TADReactiveDefender(ReactiveDefender):
|
8 |
+
"""Transformers sentiment analysis pipeline returns a list of responses
|
9 |
+
like
|
10 |
|
11 |
+
[{'label': 'POSITIVE', 'score': 0.7817379832267761}]
|
12 |
|
13 |
+
We need to convert that to a format TextAttack understands, like
|
14 |
|
15 |
+
[[0.218262017, 0.7817379832267761]
|
16 |
"""
|
17 |
|
18 |
+
def __init__(self, ckpt="tad-sst2", **kwargs):
|
19 |
super().__init__(**kwargs)
|
20 |
+
self.tad_classifier = TADCheckpointManager.get_tad_text_classifier(
|
21 |
+
checkpoint=DEMO_MODELS[ckpt], auto_device=True
|
22 |
+
)
|
23 |
|
24 |
def reactive_defense(self, text, **kwargs):
|
25 |
+
res = self.tad_classifier.infer(
|
26 |
+
text, defense="pwws", print_result=False, **kwargs
|
27 |
+
)
|
28 |
return res
|
textattack/search_methods/greedy_word_swap_wir.py
CHANGED
@@ -65,7 +65,6 @@ class GreedyWordSwapWIR(SearchMethod):
|
|
65 |
# compute the largest change in score we can find by swapping each word
|
66 |
delta_ps = []
|
67 |
for idx in indices_to_order:
|
68 |
-
|
69 |
# Exit Loop when search_over is True - but we need to make sure delta_ps
|
70 |
# is the same size as softmax_saliency_scores
|
71 |
if search_over:
|
|
|
65 |
# compute the largest change in score we can find by swapping each word
|
66 |
delta_ps = []
|
67 |
for idx in indices_to_order:
|
|
|
68 |
# Exit Loop when search_over is True - but we need to make sure delta_ps
|
69 |
# is the same size as softmax_saliency_scores
|
70 |
if search_over:
|
textattack/shared/validators.py
CHANGED
@@ -24,7 +24,10 @@ MODELS_BY_GOAL_FUNCTIONS = {
|
|
24 |
r"^textattack.models.helpers.word_cnn_for_classification.*",
|
25 |
r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
|
26 |
],
|
27 |
-
(
|
|
|
|
|
|
|
28 |
r"^textattack.models.helpers.t5_for_text_to_text.*",
|
29 |
],
|
30 |
}
|
|
|
24 |
r"^textattack.models.helpers.word_cnn_for_classification.*",
|
25 |
r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
|
26 |
],
|
27 |
+
(
|
28 |
+
NonOverlappingOutput,
|
29 |
+
MinimizeBleu,
|
30 |
+
): [
|
31 |
r"^textattack.models.helpers.t5_for_text_to_text.*",
|
32 |
],
|
33 |
}
|
textattack/trainer.py
CHANGED
@@ -398,6 +398,7 @@ class Trainer:
|
|
398 |
Returns:
|
399 |
:obj:`torch.utils.data.DataLoader`
|
400 |
"""
|
|
|
401 |
# TODO: Add pairing option where we can pair original examples with adversarial examples.
|
402 |
# Helper functions for collating data
|
403 |
def collate_fn(data):
|
@@ -406,7 +407,6 @@ class Trainer:
|
|
406 |
is_adv_sample = []
|
407 |
for item in data:
|
408 |
if "_example_type" in item[0].keys():
|
409 |
-
|
410 |
# Get example type value from OrderedDict and remove it
|
411 |
|
412 |
adv = item[0].pop("_example_type")
|
@@ -460,6 +460,7 @@ class Trainer:
|
|
460 |
Returns:
|
461 |
:obj:`torch.utils.data.DataLoader`
|
462 |
"""
|
|
|
463 |
# Helper functions for collating data
|
464 |
def collate_fn(data):
|
465 |
input_texts = []
|
|
|
398 |
Returns:
|
399 |
:obj:`torch.utils.data.DataLoader`
|
400 |
"""
|
401 |
+
|
402 |
# TODO: Add pairing option where we can pair original examples with adversarial examples.
|
403 |
# Helper functions for collating data
|
404 |
def collate_fn(data):
|
|
|
407 |
is_adv_sample = []
|
408 |
for item in data:
|
409 |
if "_example_type" in item[0].keys():
|
|
|
410 |
# Get example type value from OrderedDict and remove it
|
411 |
|
412 |
adv = item[0].pop("_example_type")
|
|
|
460 |
Returns:
|
461 |
:obj:`torch.utils.data.DataLoader`
|
462 |
"""
|
463 |
+
|
464 |
# Helper functions for collating data
|
465 |
def collate_fn(data):
|
466 |
input_texts = []
|
textattack/training_args.py
CHANGED
@@ -547,7 +547,6 @@ class _CommandLineTrainingArgs:
|
|
547 |
train_dataset.output_column == "label"
|
548 |
and eval_dataset.output_column == "label"
|
549 |
):
|
550 |
-
|
551 |
train_dataset_labels = train_dataset._dataset["label"]
|
552 |
|
553 |
eval_dataset_labels = eval_dataset._dataset["label"]
|
|
|
547 |
train_dataset.output_column == "label"
|
548 |
and eval_dataset.output_column == "label"
|
549 |
):
|
|
|
550 |
train_dataset_labels = train_dataset._dataset["label"]
|
551 |
|
552 |
eval_dataset_labels = eval_dataset._dataset["label"]
|
textattack/transformations/word_swaps/word_swap_change_name.py
CHANGED
@@ -64,7 +64,6 @@ class WordSwapChangeName(WordSwap):
|
|
64 |
return transformed_texts
|
65 |
|
66 |
def _get_replacement_words(self, word, word_part_of_speech):
|
67 |
-
|
68 |
replacement_words = []
|
69 |
tag = word_part_of_speech
|
70 |
if (
|
|
|
64 |
return transformed_texts
|
65 |
|
66 |
def _get_replacement_words(self, word, word_part_of_speech):
|
|
|
67 |
replacement_words = []
|
68 |
tag = word_part_of_speech
|
69 |
if (
|
textattack/transformations/word_swaps/word_swap_change_number.py
CHANGED
@@ -70,7 +70,7 @@ class WordSwapChangeNumber(WordSwap):
|
|
70 |
|
71 |
# replace original numbers with new numbers
|
72 |
transformed_texts = []
|
73 |
-
for
|
74 |
replacement_words = self._get_new_number(word)
|
75 |
for r in replacement_words:
|
76 |
if r == word:
|
|
|
70 |
|
71 |
# replace original numbers with new numbers
|
72 |
transformed_texts = []
|
73 |
+
for idx, word in num_words:
|
74 |
replacement_words = self._get_new_number(word)
|
75 |
for r in replacement_words:
|
76 |
if r == word:
|