andreslu committed on
Commit
0f14897
1 Parent(s): 5abcb03

Upload 25 files

evaluation.py ADDED
@@ -0,0 +1,254 @@
+ import argparse
+ import logging
+ import re
+ from datetime import datetime
+ import os
+
+ import numpy as np
+ import torch
+ from nltk import bleu, meteor
+ from rouge_score.rouge_scorer import RougeScorer
+ from tqdm import tqdm
+ from src.distinct_n.distinct_n.metrics import distinct_n_corpus_level as distinct_n
+
+ from inductor import BartInductor, CometInductor
+
+ FILES = {
+     'amie-yago2': 'data/RE-datasets/AMIE-yago2.txt',
+     'rules-yago2': 'data/RE-datasets/RuLES-yago2.txt',
+     'openrule155': 'data/OpenRule155.txt',
+     'fewrel': 'data/RE/fewrel-5.txt',
+     'semeval': 'data/RE/semeval-5.txt',
+     'TREx': 'data/RE/trex-5.txt',
+     'nyt10': 'data/RE/nyt10-5.txt',
+     'google-re': 'data/RE/google-re-5.txt',
+     'wiki80': 'data/RE/wiki80-5.txt',
+ }
+
+
+ if not os.path.exists('logs/'):
+     os.mkdir('logs/')
+
+ logging.basicConfig(
+     filename='logs/evaluation-{}.log'.format(str(datetime.now())),
+     format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+     datefmt='%m/%d/%Y %H:%M:%S',
+     level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ def print_config(config):
+     config = vars(config)
+     logger.info("**************** MODEL CONFIGURATION ****************")
+     for key in sorted(config.keys()):
+         val = config[key]
+         keystr = "{}".format(key) + (" " * (25 - len(key)))
+         logger.info("{} --> {}".format(keystr, val))
+     logger.info("**************** MODEL CONFIGURATION ****************")
+
+
+ scorer = RougeScorer(['rougeL'], use_stemmer=True)
+
+
+ def rouge(references, hypothesis):
+     # Score the hypothesis against every reference and keep the best ROUGE-L F-measure.
+     scores = []
+     for reference in references:
+         scores.append(scorer.score(reference, hypothesis)['rougeL'][2])
+     return max(scores)
+
+
+ class RelationExtractionEvaluator(object):
+     def __init__(self, args):
+         self.args = args
+         if self.args.inductor == 'rule':
+             self.inductor = BartInductor(
+                 group_beam=self.args.group_beam,
+                 continue_pretrain_instance_generator=self.args.mlm_training,
+                 continue_pretrain_hypo_generator=self.args.bart_training,
+                 if_then=self.args.if_then,
+             )
+         elif self.args.inductor == 'comet':
+             self.inductor = CometInductor()
+
+     def clean(self, text):
+         # Drop anything generated after the second <mask> so the template ends with "<mask>."
+         segments = text.split('<mask>')
+         if len(segments) == 3 and segments[2].startswith('.'):
+             return '<mask>'.join(segments[:2]) + '<mask>.'
+         else:
+             return text
+
+     def clean_references(self, texts):
+         for i, text in enumerate(texts):
+             if text.endswith(" ."):
+                 texts[i] = text.replace(" .", ".")
+         return texts
+
+     def self_bleu(self, hypothesis):
+         # Pairwise BLEU-2 of each hypothesis against all the others; lower means more diverse output.
+         bleus = []
+         for i in range(len(hypothesis)):
+             bleus.append(bleu(
+                 hypothesis[:i] + hypothesis[i + 1:],
+                 hypothesis[i],
+                 weights=(0.5, 0.5)))
+         return np.mean(bleus)
+
+     def evaluate(self, task):
+         with torch.no_grad():
+             self.metrics = {
+                 "bleu-4": [],
+                 "bleu-3": [],
+                 "bleu-2": [],
+                 "bleu-1": [],
+                 "METEOR": [],
+                 "ROUGE-L": [],
+                 "self-BLEU-2": [],
+             }
+             with open(FILES[task], 'r', encoding='utf-8') as file:
+                 data = file.readlines()
+                 with tqdm(total=len(data)) as pbar:
+                     for row in data:
+                         pbar.update(1)
+                         row = row.strip().split('\t')
+                         inputs, head, tail, relations = row[0], row[1], row[2], row[3]
+                         inputs = inputs.strip()
+
+                         if relations.startswith('[') and relations.endswith(']'):
+                             inputs = re.sub("<A>|<B>", "<mask>", inputs)
+                             references = [relation.replace('<A>', '<mask>').replace('<B>', '<mask>').lower().strip() for relation in eval(relations)]
+                         else:
+                             references = [relations.replace('[X]', '<mask>').replace('[Y]', '<mask>').lower().strip()]
+                         references = self.clean_references(references)
+                         hypothesis = self.inductor.generate(inputs, k=10, topk=10)
+
+                         logger.info("***********Input************")
+                         logger.info(inputs)
+                         logger.info("*********Hypothesis*********")
+                         for i, hypo in enumerate(hypothesis):
+                             hypothesis[i] = self.clean(hypo.lower().strip())
+                             logger.info(hypo)
+                         logger.info("****************************")
+                         logger.info("*********References*********")
+                         logger.info(references)
+                         logger.info("****************************")
+
+                         if len(hypothesis) == 0:
+                             for k in self.metrics.keys():
+                                 if k != 'self-BLEU-2':
+                                     self.metrics[k].append(0.)
+                         else:
+                             for hypo in hypothesis:
+                                 try:
+                                     self.metrics['bleu-4'].append(
+                                         bleu(
+                                             [reference.split() for reference in references],
+                                             hypo.split(),
+                                             weights=(0.25, 0.25, 0.25, 0.25)
+                                         )
+                                     )
+                                 except Exception:
+                                     logger.warning("Skip bleu-4 in example: {}".format(inputs))
+
+                                 try:
+                                     self.metrics['bleu-3'].append(
+                                         bleu(
+                                             [reference.split() for reference in references],
+                                             hypo.split(),
+                                             weights=(1 / 3,) * 3
+                                         )
+                                     )
+                                 except Exception:
+                                     logger.warning("Skip bleu-3 in example: {}".format(inputs))
+
+                                 try:
+                                     self.metrics['bleu-2'].append(
+                                         bleu(
+                                             [reference.split() for reference in references],
+                                             hypo.split(),
+                                             weights=(0.5, 0.5)
+                                         )
+                                     )
+                                 except Exception:
+                                     logger.warning("Skip bleu-2 in example: {}".format(inputs))
+
+                                 try:
+                                     self.metrics['bleu-1'].append(
+                                         bleu(
+                                             [reference.split() for reference in references],
+                                             hypo.split(),
+                                             weights=(1.0,)
+                                         )
+                                     )
+                                 except Exception:
+                                     logger.warning("Skip bleu-1 in example: {}".format(inputs))
+
+                                 try:
+                                     self.metrics['METEOR'].append(meteor(references, hypo))
+                                 except Exception:
+                                     logger.warning("Skip METEOR in example: {}".format(inputs))
+
+                                 try:
+                                     self.metrics['ROUGE-L'].append(rouge(references, hypo))
+                                 except Exception:
+                                     logger.warning("Skip ROUGE-L in example: {}".format(inputs))
+
+                             # Self-BLEU is computed once over the whole hypothesis set per example.
+                             try:
+                                 self.metrics['self-BLEU-2'].append(self.self_bleu(hypothesis))
+                             except Exception:
+                                 logger.warning("Skip self-bleu-2 in example: {}.".format(inputs))
+
+         self.log_metrics(task, self.metrics)
+
+     def log_metrics(self, task, metrics):
+         logger.info("Task: {}".format(str(task)))
+         for k, v in metrics.items():
+             logger.info("{}: {}".format(k, str(np.mean(v))))
+
+         logger.info("*******************************************************")
+         logger.info("*******************************************************")
+         logger.info("*******************************************************")
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--inductor", type=str, default='rule')
+     # argparse's type=bool treats any non-empty string as True, so plain flags are used instead.
+     parser.add_argument("--group_beam", action="store_true")
+     parser.add_argument("--mlm_training", action="store_true")
+     parser.add_argument("--bart_training", action="store_true")
+     parser.add_argument("--if_then", action="store_true")
+     parser.add_argument("--task", type=str, default='openrule155')
+
+     args = parser.parse_args()
+
+     print_config(args)
+     evaluator = RelationExtractionEvaluator(args)
+     evaluator.evaluate(args.task)
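
For reference, a minimal sketch of how the evaluator above can be driven from Python rather than the CLI. The flag values are an illustrative assumption (they mirror the parser defaults, with the continued-pretrained Orion generators switched on), and a CUDA device plus the data files under data/ are assumed:

    import argparse
    from evaluation import RelationExtractionEvaluator

    # Illustrative namespace; field names match the argparse flags above.
    args = argparse.Namespace(inductor='rule', group_beam=True, mlm_training=True,
                              bart_training=True, if_then=False, task='openrule155')
    evaluator = RelationExtractionEvaluator(args)
    evaluator.evaluate(args.task)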
expbert.py ADDED
@@ -0,0 +1,282 @@
+ import argparse
+ import logging
+ import os
+ import random
+ from datetime import datetime
+
+ import numpy as np
+ import torch
+ from sklearn.metrics import accuracy_score, f1_score
+ from torch import nn
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm import tqdm
+ from transformers import (AutoConfig, AutoModel,
+                           AutoModelForSequenceClassification, AutoTokenizer,
+                           BertForSequenceClassification, BertModel)
+
+ if not os.path.exists('logs/'):
+     os.mkdir('logs/')
+
+ logging.basicConfig(
+     filename='logs/expbert-{}.log'.format(str(datetime.now())),
+     format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+     datefmt='%m/%d/%Y %H:%M:%S',
+     level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ TASK2PATH = {
+     "disease-train": "data/disease/train.txt",
+     "disease-test": "data/disease/test.txt",
+     "spouse-train": "data/spouse/train.txt",
+     "spouse-test": "data/spouse/test.txt",
+ }
+
+ ANNOTATED_EXP = {
+     "spouse": "data/exp/expbert_spouse_explanation.txt",
+     "disease": "data/exp/expbert_disease_explanation.txt",
+ }
+
+ GENERATED_EXP = {
+     "spouse": "data/exp/orion_spouse_explanation.txt",
+     "disease": "data/exp/orion_disease_explanation.txt",
+ }
+
+
+ def set_random_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def print_config(config):
+     config = vars(config)
+     logger.info("**************** MODEL CONFIGURATION ****************")
+     for key in sorted(config.keys()):
+         val = config[key]
+         keystr = "{}".format(key) + (" " * (25 - len(key)))
+         logger.info("{} --> {}".format(keystr, val))
+     logger.info("**************** MODEL CONFIGURATION ****************")
+
+
+ class ExpBERT(nn.Module):
+     def __init__(self, args, exp_num):
+         super(ExpBERT, self).__init__()
+         self.args = args
+         self.exp_num = exp_num
+         self.config = AutoConfig.from_pretrained(args.model)
+         self.model = AutoModel.from_pretrained(args.model, config=self.config)
+         self.dropout = nn.Dropout(p=0.1)
+         # One [CLS] vector per explanation, concatenated into a single feature vector.
+         self.linear = nn.Linear(self.config.hidden_size * exp_num, 2)
+
+         self.criterion = nn.CrossEntropyLoss()
+
+     def forward(self, inputs):
+         for k, v in inputs["encoding"].items():
+             inputs["encoding"][k] = v.cuda()
+         pooler_output = self.model(**inputs["encoding"]).last_hidden_state[:, 0, :].reshape(1, self.exp_num * self.config.hidden_size)
+         pooler_output = self.dropout(pooler_output)
+         logits = self.linear(pooler_output)
+
+         loss = self.criterion(logits, torch.LongTensor([inputs["label"]]).cuda())
+         prediction = torch.argmax(logits)
+
+         return {
+             "loss": loss,
+             "prediction": prediction,
+         }
+
+
+ class REDataset(Dataset):
+     def __init__(self, path, exp, tokenizer):
+         super(REDataset, self).__init__()
+         self.tokenizer = tokenizer
+         self.exp = exp
+         self.sentences = []
+         self.labels = []
+         self.entities = []
+         with open(path, "r", encoding="utf-8") as file:
+             data = file.readlines()
+             for example in data:
+                 sentence, entity1, entity2, example_id, label = example.strip().split("\t")
+                 self.sentences.append(sentence)
+                 # Labels in the data files are "1" / "-1"; map them to 1 / 0.
+                 if int(label) == 1:
+                     self.labels.append(1)
+                 elif int(label) == -1:
+                     self.labels.append(0)
+
+                 self.entities.append([entity1, entity2])
+
+         logger.info("Number of Example in {}: {}".format(path, str(len(self.labels))))
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, index):
+         return {
+             "sentence": self.sentences[index],
+             "entity": self.entities[index],
+             "label": self.labels[index],
+         }
+
+     def collate_fn(self, batch):
+         # Pair each sentence with every explanation; entity placeholders in an
+         # explanation are filled with the example's entities before encoding.
+         outputs = []
+         for ex in batch:
+             temp = []
+             for exp in self.exp:
+                 if "{e1}" in exp or "{e2}" in exp:
+                     exp = exp.replace("{e1}", ex["entity"][0]).replace("{e2}", ex["entity"][1])
+                 else:
+                     for entity in ex["entity"]:
+                         index = exp.index('<mask>')
+                         exp = exp[:index] + entity + exp[index + len('<mask>'):]
+                 temp.append(exp)
+             outputs.append(
+                 {
+                     "encoding": self.tokenizer(
+                         [ex["sentence"]] * len(temp), temp,
+                         add_special_tokens=True,
+                         padding="longest",
+                         truncation=True,
+                         max_length=156,
+                         return_tensors="pt",
+                         return_attention_mask=True,
+                         return_token_type_ids=True,
+                     ),
+                     "label": ex["label"],
+                 }
+             )
+         return outputs
+
+     def collate_fn_(self, batch):
+         # Baseline collate (no explanations): encode the sentences alone.
+         texts = []
+         labels = []
+         for ex in batch:
+             texts.append(ex["sentence"])
+             labels.append(ex["label"])
+
+         outputs = self.tokenizer(
+             texts,
+             add_special_tokens=True,
+             padding="longest",
+             truncation=True,
+             max_length=156,
+             return_tensors="pt",
+             return_attention_mask=True,
+             return_token_type_ids=True,
+         )
+
+         outputs["labels"] = torch.LongTensor(labels)
+
+         return outputs
+
+
+ class Trainer(object):
+     def __init__(self, args):
+         self.args = args
+         print_config(args)
+         self.tokenizer = AutoTokenizer.from_pretrained(self.args.model)
+
+         TASK2EXP = GENERATED_EXP if args.generated_rules else ANNOTATED_EXP
+         with open(TASK2EXP[args.task], "r", encoding="utf-8") as file:
+             exp = file.readlines()
+
+         self.train_dataset = REDataset(TASK2PATH['{}-train'.format(args.task)], exp, self.tokenizer)
+         self.test_dataset = REDataset(TASK2PATH['{}-test'.format(args.task)], exp, self.tokenizer)
+         self.model = AutoModelForSequenceClassification.from_pretrained(args.model).cuda() if self.args.no_exp else ExpBERT(args, len(exp)).cuda()
+
+         self.train_loader = DataLoader(
+             self.train_dataset,
+             batch_size=args.batch_size,
+             shuffle=args.shuffle,
+             collate_fn=self.train_dataset.collate_fn_ if self.args.no_exp else self.train_dataset.collate_fn,
+         )
+
+         self.test_loader = DataLoader(
+             self.test_dataset,
+             batch_size=args.batch_size,
+             shuffle=args.shuffle,
+             collate_fn=self.test_dataset.collate_fn_ if self.args.no_exp else self.test_dataset.collate_fn,
+         )
+
+         self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.learning_rate)
+
+     def compute_metrics(self, labels, predictions):
+         accuracy = accuracy_score(y_pred=predictions, y_true=labels)
+         f1 = f1_score(y_pred=predictions, y_true=labels)
+
+         return accuracy, f1
+
+     def train(self):
+         self.test(-1)
+         for e in range(self.args.epochs):
+             # test() switches the model to eval mode, so re-enable dropout before each epoch.
+             self.model.train()
+             with tqdm(total=len(self.train_loader)) as pbar:
+                 for step, examples in enumerate(self.train_loader):
+                     self.model.zero_grad()
+                     if self.args.no_exp:
+                         for k, v in examples.items():
+                             examples[k] = v.cuda()
+                         outputs = self.model(**examples)
+                         outputs.loss.backward()
+                     else:
+                         # Each "example" holds one sentence paired with all explanations;
+                         # average the losses so the batch size does not scale the gradient.
+                         for ex in examples:
+                             outputs = self.model(ex)
+                             (outputs["loss"] / len(examples)).backward()
+
+                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+                     self.optimizer.step()
+                     pbar.update(1)
+
+             self.test(e)
+
+     def test(self, epoch):
+         self.model.eval()
+         with torch.no_grad():
+             with tqdm(total=len(self.test_loader)) as pbar:
+                 loss = []
+                 labels = []
+                 predictions = []
+                 for step, examples in enumerate(self.test_loader):
+                     if self.args.no_exp:
+                         for k, v in examples.items():
+                             examples[k] = v.cuda()
+                         outputs = self.model(**examples)
+                         loss.append(outputs.loss.item())
+                         labels.extend(examples["labels"].tolist())
+                         predictions.extend(torch.argmax(outputs.logits, dim=1).tolist())
+                     else:
+                         for ex in examples:
+                             labels.append(ex['label'])
+                             outputs = self.model(ex)
+                             loss.append(outputs["loss"].item())
+                             predictions.append(outputs['prediction'].tolist())
+
+                     pbar.update(1)
+                 # Argument order matches the (labels, predictions) signature above.
+                 accuracy, f1 = self.compute_metrics(labels, predictions)
+                 logger.info("[EPOCH {}] Accuracy: {} | F1-Score: {}. (Number of Data {})".format(epoch, accuracy, f1, len(predictions)))
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--task", type=str, default="spouse")
+     parser.add_argument("--model", type=str, default="bert-base-uncased")
+     parser.add_argument("--batch_size", type=int, default=32)
+     parser.add_argument("--learning_rate", type=float, default=2e-5)
+     # argparse's type=bool treats any non-empty string as True, so plain flags are used instead.
+     parser.add_argument("--shuffle", action="store_true")
+     parser.add_argument("--epochs", type=int, default=5)
+     parser.add_argument("--no_exp", action="store_true")
+     parser.add_argument("--generated_rules", action="store_true")
+
+     args = parser.parse_args()
+
+     for seed in range(42, 47):
+         set_random_seed(seed)
+         trainer = Trainer(args)
+         trainer.train()
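
A matching sketch for driving the trainer above programmatically. The namespace values are an illustrative assumption mirroring the parser defaults; --no_exp would switch to the plain sequence-classification baseline and --generated_rules would swap in the Orion-generated explanations. A CUDA device and the data/exp files are assumed:

    import argparse
    from expbert import Trainer, set_random_seed

    # Illustrative namespace; field names match the argparse flags above.
    args = argparse.Namespace(task='spouse', model='bert-base-uncased', batch_size=32,
                              learning_rate=2e-5, shuffle=False, epochs=5,
                              no_exp=False, generated_rules=False)
    set_random_seed(42)
    Trainer(args).train()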
inductor.py ADDED
@@ -0,0 +1,401 @@
+ import re
+ from copy import deepcopy
+
+ import torch
+ import torch.nn.functional as F
+ from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
+                           BartForConditionalGeneration, BartTokenizer)
+
+ from src.bart_with_group_beam import BartForConditionalGeneration_GroupBeam
+ from src.utils import (construct_template, filter_words,
+                        formalize_tA, post_process_template)
+
+ ORION_HYPO_GENERATOR = 'chenxran/orion-hypothesis-generator'
+ ORION_INS_GENERATOR = 'chenxran/orion-instance-generator'
+
+ RELATIONS = [
+     "Causes",
+     "HasProperty",
+     "MadeUpOf",
+     "isAfter",
+     "isBefore",
+     "xReact",
+     "xWant",
+     "xReason",
+     "xAttr",
+     "Desires",
+ ]
+
+
+ class BartInductor(object):
+     def __init__(
+         self,
+         group_beam=True,
+         continue_pretrain_instance_generator=True,
+         continue_pretrain_hypo_generator=True,
+         if_then=False,
+     ):
+         self.if_then = if_then
+         self.orion_instance_generator_path = 'facebook/bart-large' if not continue_pretrain_instance_generator else ORION_INS_GENERATOR
+         self.orion_hypothesis_generator_path = 'facebook/bart-large' if not continue_pretrain_hypo_generator else ORION_HYPO_GENERATOR
+
+         if group_beam:
+             self.orion_hypothesis_generator = BartForConditionalGeneration_GroupBeam.from_pretrained(self.orion_hypothesis_generator_path).cuda().eval().half()
+         else:
+             self.orion_hypothesis_generator = BartForConditionalGeneration.from_pretrained(self.orion_hypothesis_generator_path).cuda().eval().half()
+
+         self.orion_instance_generator = BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).cuda().eval().half()
+
+         self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
+         self.word_length = 2
+
+         self.stop_sub_list = ['he', 'she', 'this', 'that', 'and', 'it', 'which', 'who', 'whose', 'there', 'they', '.', 'its', 'one',
+                               'i', ',', 'the', 'nobody', 'his', 'her', 'also', 'only', 'currently', 'here', '()', 'what', 'where',
+                               'why', 'a', 'some', '"', ')', '(', 'now', 'everyone', 'everybody', 'their', 'often', 'usually', 'you',
+                               '-', '?', ';', 'in', 'on', 'each', 'both', 'him', 'typically', 'mostly', 'sometimes', 'normally',
+                               'always', 'usually', 'still', 'today', 'was', 'were', 'but', 'although', 'current', 'all', 'have',
+                               'has', 'later', 'with', 'most', 'nowadays', 'then', 'every', 'when', 'someone', 'anyone', 'somebody',
+                               'anybody', 'any', 'being', 'get', 'getting', 'thus', 'under', 'even', 'for', 'can', 'rarely', 'never',
+                               'may', 'generally', 'other', 'another', 'too', 'first', 'second', 'third', 'mainly', 'primarily',
+                               'having', 'have', 'has']
+
+         # Also block the capitalized variant of every alphabetic stop word.
+         self.stop_size = len(self.stop_sub_list)
+         for i in range(self.stop_size):
+             if self.stop_sub_list[i][0].isalpha():
+                 temp = self.stop_sub_list[i][0].upper() + self.stop_sub_list[i][1:]
+                 self.stop_sub_list.append(temp)
+
+         self.bad_words_ids = [self.tokenizer.encode(bad_word)[1:-1] for bad_word in ['also', ' also']]
+         stop_index = self.tokenizer(self.stop_sub_list, max_length=4, padding=True)
+         stop_index = torch.tensor(stop_index['input_ids'])[:, 1]
+         stop_weight = torch.zeros(1, self.tokenizer.vocab_size).cuda()
+         stop_weight[0, stop_index] -= 100
+         self.stop_weight = stop_weight[0, :]
+
+     def clean(self, text):
+         segments = text.split('<mask>')
+         if len(segments) == 3 and segments[2].startswith('.'):
+             return '<mask>'.join(segments[:2]) + '<mask>.'
+         else:
+             return text
+
+     def generate(self, inputs, k=10, topk=10):
+         with torch.no_grad():
+             tB_probs = self.generate_rule(inputs, k)
+             ret = [t[0].replace('<ent0>', '<mask>').replace('<ent1>', '<mask>') for t in tB_probs]
+
+             new_ret = []
+             for temp in ret:
+                 temp = self.clean(temp.strip())
+                 if len(new_ret) < topk and temp not in new_ret:
+                     new_ret.append(temp)
+
+             return new_ret
+
+     def explore_mask(self, tA, k, tokens, prob, required_token, probs):
+         # Recursively fill one <mask> at a time, keeping the top-k candidates at each step.
+         if required_token == 0:
+             return [[tokens, prob, probs]]
+         if required_token <= self.word_length:
+             k = min(k, 2)
+         ret = []
+         generated_ids = self.tokenizer(tA, max_length=128, padding='longest', return_tensors='pt')
+         for key in generated_ids.keys():
+             generated_ids[key] = generated_ids[key].cuda()
+         mask_index = torch.where(generated_ids["input_ids"][0] == self.tokenizer.mask_token_id)
+         generated_ret = self.orion_instance_generator(**generated_ids)
+         logits = generated_ret[0]
+         softmax = F.softmax(logits, dim=-1)
+         # Penalize stop words at the first mask position before taking the top-k.
+         mask_word = softmax[0, mask_index[0][0], :] + self.stop_weight
+         top_k = torch.topk(mask_word, k, dim=0)
+         for i in range(top_k[1].size(0)):
+             token_s = top_k[1][i]
+             prob_s = top_k[0][i].item()
+             token_this = self.tokenizer.decode([token_s]).strip()
+             if not token_this[0].isalpha() or len(token_this) <= 2:
+                 continue
+             index_s = tA.index(self.tokenizer.mask_token)
+             tAs = tA[:index_s] + token_this + tA[index_s + len(self.tokenizer.mask_token):]
+             tokens_this = [t for t in tokens]
+             tokens_this.append(token_this)
+             probs_new = deepcopy(probs)
+             probs_new.append(prob_s)
+             ret.extend(self.explore_mask(tAs, 1, tokens_this, prob_s * prob, required_token - 1, probs_new))
+         return ret
+
+     def extract_words_for_tA_bart(self, tA, k=6, print_it=False):
+         spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
+         generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].cuda()
+         generated_ret = self.orion_instance_generator.generate(
+             generated_ids,
+             num_beams=max(120, k),
+             max_length=generated_ids.size(1) + 15,
+             num_return_sequences=max(120, k),
+             output_scores=True,
+             return_dict_in_generate=True)
+         summary_ids = generated_ret['sequences']
+         probs = F.softmax(generated_ret['sequences_scores'], dim=0)
+         txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids]
+         ret = []
+
+         for i, txt in enumerate(txts):
+             if tA.endswith('.'):
+                 if txt.endswith('.'):
+                     txt = txt[:-1].strip()
+                 txt += '.'
+             word_incomplete = False
+             prob = probs[i].item()
+             words_i = []
+
+             # Recover the instance words by aligning the generated text with the
+             # template spans on either side of each <mask>.
+             start_index = 0
+             for j in range(len(spans) - 1):
+                 span1 = spans[j]
+                 span2 = spans[j + 1]
+                 if (span1 in txt.lower()[start_index:]) and (span2 in txt.lower()[start_index:]):
+                     index1 = txt.lower().index(span1, start_index) + len(span1)
+                     if span2 == '':
+                         if txt[-1] == '.':
+                             index2 = len(txt) - 1
+                         else:
+                             index2 = len(txt)
+                     else:
+                         index2 = txt.lower().index(span2, start_index)
+
+                     words_i.append(txt[index1:index2].strip())
+                     start_index = index2
+                 else:
+                     word_incomplete = True
+             if word_incomplete:
+                 continue
+
+             ret.append([words_i, prob])
+         return sorted(ret, key=lambda x: x[1], reverse=True)[:k]
+
+     def extract_words_for_tA(self, tA, k=6):
+         word_mask_str = ' '.join([self.tokenizer.mask_token] * self.word_length)
+         tA = tA.replace('<mask>', word_mask_str)
+         mask_count = tA.count(self.tokenizer.mask_token)
+         mask_probs = self.explore_mask(tA, k * 20, [], 1.0, mask_count, [])
+         ret = []
+         visited_mask_txt = {}
+         for mask, prob, probs in mask_probs:
+             mask_txt = ' '.join(mask).lower()
+             if mask_txt in visited_mask_txt:
+                 continue
+             visited_mask_txt[mask_txt] = 1
+             words = []
+             probs_words = []
+             for i in range(0, mask_count, self.word_length):
+                 words.append(' '.join(mask[i: i + self.word_length]))
+                 prob_word = 1.0
+                 for j in range(i, i + self.word_length):
+                     prob_word *= probs[j]
+                 probs_words.append(prob_word)
+             ret.append([words, prob, probs_words])
+         return sorted(ret, key=lambda x: x[1], reverse=True)[:k]
+
+     def extract_templateBs_batch(self, words_prob, tA, k, print_it=False):
+         # Sort instances by tokenized length so each generation batch shares a length.
+         words_prob_sorted = []
+         for (words, probA, *_) in words_prob:
+             tokenized_word = self.tokenizer(words[0])
+             words_prob_sorted.append([words, probA, len(tokenized_word['input_ids'])])
+         words_prob_sorted.sort(key=lambda x: x[2])
+
+         batch_size = 8
+         templates = []
+         index_words = {}
+         ret = {}
+         num_beams = k
+         for enum, (words, probA, *_) in enumerate(words_prob_sorted):
+             template = construct_template(words, tA, self.if_then)
+             templates.extend(template)
+             for t in template:
+                 index_words[len(index_words)] = '\t'.join(words)
+             if (len(templates) == batch_size) or enum == len(words_prob_sorted) - 1 or (words_prob_sorted[enum + 1][2] != words_prob_sorted[enum][2]):
+                 generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].cuda()
+                 generated_ret = self.orion_hypothesis_generator.generate(
+                     generated_ids,
+                     num_beams=num_beams,
+                     num_beam_groups=num_beams,
+                     max_length=28,
+                     num_return_sequences=num_beams,
+                     min_length=3,
+                     diversity_penalty=1.0,
+                     early_stopping=True,
+                     bad_words_ids=self.bad_words_ids,
+                     output_scores=True,
+                     return_dict_in_generate=True,
+                     # custom kwarg consumed by the group-beam subclass
+                     decoder_ori_input_ids=generated_ids,
+                     top_p=0.95,
+                 )
+                 summary_ids = generated_ret['sequences'].reshape((len(templates), num_beams, -1))
+                 probs = F.softmax(generated_ret['sequences_scores'].reshape((len(templates), num_beams)), dim=1)
+                 for ii in range(summary_ids.size(0)):
+                     txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids[ii]]
+                     ii_template = []
+                     words_ii = index_words[ii].split('\t')
+                     for i, txt in enumerate(txts):
+                         prob = probs[ii][i].item() * probA
+
+                         txt = txt.lower()
+                         txt = post_process_template(txt)
+
+                         # Zero out hypotheses that drop the instance words; otherwise
+                         # replace the words with entity placeholders.
+                         words_ii_matched = [word.lower() for word in words_ii]
+                         if words_ii_matched is None:
+                             prob = 0.0
+                         else:
+                             for j, word in enumerate(words_ii_matched):
+                                 if word not in txt:
+                                     prob = 0.0
+                                 else:
+                                     txt = txt.replace(word, '<ent{}>'.format(j), 1)
+
+                         if txt.count(' ') + 1 <= 3:
+                             continue
+
+                         ii_template.append([txt, prob])
+                     # Aggregate the probability mass of identical templates.
+                     for template, prob in ii_template:
+                         if template not in ret:
+                             ret[template] = 0.0
+                         ret[template] += prob
+                 templates.clear()
+                 index_words.clear()
+
+         return ret
+
+     def generate_rule(self, tA, k=10, print_it=False):
+         tA = formalize_tA(tA)
+         if 'bart' in str(self.orion_instance_generator.__class__).lower():
+             words_prob = self.extract_words_for_tA_bart(tA, k, print_it=print_it)
+             words_prob = filter_words(words_prob)[:k]
+         else:
+             words_prob = self.extract_words_for_tA(tA, k)
+             words_prob = filter_words(words_prob)[:k]
+
+         tB_prob = self.extract_templateBs_batch(words_prob, tA, k, print_it=print_it)
+
+         ret = []
+         for k1 in tB_prob:
+             ret.append([k1, tB_prob[k1]])
+         ret = sorted(ret, key=lambda x: x[1], reverse=True)[:k]
+         if self.if_then:
+             for i, temp in enumerate(ret):
+                 sentence = temp[0]
+                 if "then" in sentence:
+                     sentence = sentence.split("then")[-1]
+                 else:
+                     sentence = sentence.replace("if", "")
+                 ret[i][0] = sentence
+         return ret
+
+
+ class CometInductor(object):
+     def __init__(self):
+         self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").cuda().eval()
+         self.tokenizer = AutoTokenizer.from_pretrained("adamlin/comet-atomic_2020_BART")
+         self.task = "summarization"
+         self.use_task_specific_params()
+         self.decoder_start_token_id = None
+
+     def drop_repeat(self, old_list):
+         new_list = []
+         for item in old_list:
+             if item not in new_list:
+                 new_list.append(item)
+
+         return new_list
+
+     def chunks(self, lst, n):
+         """Yield successive n-sized chunks from lst."""
+         for i in range(0, len(lst), n):
+             yield lst[i: i + n]
+
+     def use_task_specific_params(self):
+         """Update config with summarization specific params."""
+         task_specific_params = self.model.config.task_specific_params
+
+         if task_specific_params is not None:
+             pars = task_specific_params.get(self.task, {})
+             self.model.config.update(pars)
+
+     def trim_batch(self, input_ids, pad_token_id, attention_mask=None):
+         """Remove columns that are populated exclusively by pad_token_id."""
+         keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
+         if attention_mask is None:
+             return input_ids[:, keep_column_mask]
+         else:
+             return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
+
+     def generate(self, inputs, k, topk):
+         outputs = []
+         words = ['PersonX', 'PersonY']
+         for i, _ in enumerate(re.findall("<mask>", inputs)):
+             index = inputs.index('<mask>')
+             inputs = inputs[:index] + words[i] + inputs[index + len('<mask>'):]
+
+         for relation in RELATIONS:
+             # Build each query from the original premise; mutating `inputs` in the
+             # loop would stack earlier relations onto later queries.
+             query = "{} {} [GEN]".format(inputs[:-1], relation)
+             gen = self.generate_(query, num_generate=10)
+             switch = 0
+             for output in gen[0]:
+                 output = output.strip()
+                 if re.search("PersonX|X", output) and re.search("PersonY|Y", output):
+                     temp = re.sub("PersonX|X|PersonY|Y", "<mask>", output.strip())
+                     if temp.endswith("."):
+                         outputs.append(temp)
+                     else:
+                         outputs.append(temp + ".")
+                     switch = 1
+                     break
+
+             if switch == 0:
+                 output = gen[0][0]
+                 temp = re.sub("PersonX|X|PersonY|Y", "<mask>", output.strip())
+                 if temp.endswith("."):
+                     outputs.append(temp)
+                 else:
+                     outputs.append(temp + ".")
+
+         outputs = [output.replace('PersonX', '<mask>').replace('PersonY', '<mask>') for output in outputs]
+         return outputs
+
+     def generate_(self, queries, decode_method="beam", num_generate=5):
+         with torch.no_grad():
+             decs = []
+             batch = self.tokenizer(queries, return_tensors="pt", padding="longest")
+             input_ids, attention_mask = self.trim_batch(**batch, pad_token_id=self.tokenizer.pad_token_id)
+
+             summaries = self.model.generate(
+                 input_ids=input_ids.cuda(),
+                 attention_mask=attention_mask.cuda(),
+                 decoder_start_token_id=self.decoder_start_token_id,
+                 num_beams=num_generate,
+                 num_return_sequences=num_generate,
+             )
+
+             dec = self.tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+             decs.append(dec)
+
+             return decs
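
A minimal sketch of how BartInductor is meant to be driven. The premise string below is an illustrative assumption (any sentence with two <mask> entity slots works), and a CUDA device is required since the generators are moved to GPU in __init__:

    from inductor import BartInductor

    inductor = BartInductor()
    # Premise with two entity slots; returns up to `topk` de-duplicated rule templates.
    rules = inductor.generate('<mask> is the capital of <mask>.', k=10, topk=10)
    for rule in rules:
        print(rule)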
src/__pycache__/bart_with_group_beam.cpython-38.pyc ADDED
Binary file (17.7 kB).
 
src/bart_with_group_beam.py ADDED
@@ -0,0 +1,608 @@
1
+ from transformers.models.bart import BartForConditionalGeneration
2
+ import torch
3
+ from transformers.generation_beam_search import BeamScorer
4
+ from abc import ABC, abstractmethod
5
+ from collections import UserDict
6
+ from typing import Optional, Tuple, Union, Dict, Any
7
+ from transformers.generation_logits_process import LogitsProcessorList
8
+ from transformers.generation_utils import BeamSearchEncoderDecoderOutput,BeamSearchDecoderOnlyOutput
9
+ from torch.nn import functional as F
10
+ from transformers.file_utils import ModelOutput
11
+ import torch.nn
12
+
13
+ BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
14
+
15
+
16
+ class BartForConditionalGeneration_GroupBeam(BartForConditionalGeneration):
17
+
18
+
19
+ def beam_search(
20
+ self,
21
+ input_ids: torch.LongTensor,
22
+ beam_scorer: BeamScorer,
23
+ logits_processor: Optional[LogitsProcessorList] = None,
24
+ max_length: Optional[int] = None,
25
+ pad_token_id: Optional[int] = None,
26
+ eos_token_id: Optional[int] = None,
27
+ output_attentions: Optional[bool] = None,
28
+ output_hidden_states: Optional[bool] = None,
29
+ output_scores: Optional[bool] = None,
30
+ return_dict_in_generate: Optional[bool] = None,
31
+ **model_kwargs,
32
+ ) -> Union[BeamSearchOutput, torch.LongTensor]:
33
+ r"""
34
+ Generates sequences for models with a language modeling head using beam search decoding.
35
+
36
+ Parameters:
37
+
38
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
39
+ The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
40
+ :obj:`torch.LongTensor` of shape :obj:`(1,)`.
41
+ beam_scorer (:obj:`BeamScorer`):
42
+ An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
43
+ constructed, stored and sorted during generation. For more information, the documentation of
44
+ :class:`~transformers.BeamScorer` should be read.
45
+ logits_processor (:obj:`LogitsProcessorList`, `optional`):
46
+ An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
47
+ :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
48
+ head applied at each generation step.
49
+ max_length (:obj:`int`, `optional`, defaults to 20):
50
+ The maximum length of the sequence to be generated.
51
+ pad_token_id (:obj:`int`, `optional`):
52
+ The id of the `padding` token.
53
+ eos_token_id (:obj:`int`, `optional`):
54
+ The id of the `end-of-sequence` token.
55
+ output_attentions (:obj:`bool`, `optional`, defaults to `False`):
56
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
57
+ returned tensors for more details.
58
+ output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
59
+ Whether or not to return trhe hidden states of all layers. See ``hidden_states`` under returned tensors
60
+ for more details.
61
+ output_scores (:obj:`bool`, `optional`, defaults to `False`):
62
+ Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
63
+ return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
64
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
65
+ model_kwargs:
66
+ Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
67
+ model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
68
+
69
+ Return:
70
+ :class:`~transformers.generation_utilsBeamSearchDecoderOnlyOutput`,
71
+ :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or obj:`torch.LongTensor`: A
72
+ :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
73
+ :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
74
+ ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
75
+ :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
76
+ ``model.config.is_encoder_decoder=True``.
77
+
78
+
79
+ Examples::
80
+
81
+ >>> from transformers import (
82
+ ... AutoTokenizer,
83
+ ... AutoModelForSeq2SeqLM,
84
+ ... LogitsProcessorList,
85
+ ... MinLengthLogitsProcessor,
86
+ ... BeamSearchScorer,
87
+ ... )
88
+ >>> import torch
89
+
90
+ >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
91
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
92
+
93
+ >>> encoder_input_str = "translate English to German: How old are you?"
94
+ >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
95
+
96
+
97
+ >>> # lets run beam search using 3 beams
98
+ >>> num_beams = 3
99
+ >>> # define decoder start token ids
100
+ >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
101
+ >>> input_ids = input_ids * model.config.decoder_start_token_id
102
+
103
+ >>> # add encoder_outputs to model keyword arguments
104
+ >>> model_kwargs = {
105
+ ... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
106
+ ... }
107
+
108
+ >>> # instantiate beam scorer
109
+ >>> beam_scorer = BeamSearchScorer(
110
+ ... batch_size=1,
111
+ ... max_length=model.config.max_length,
112
+ ... num_beams=num_beams,
113
+ ... device=model.device,
114
+ ... )
115
+
116
+ >>> # instantiate logits processors
117
+ >>> logits_processor = LogitsProcessorList([
118
+ ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
119
+ ... ])
120
+
121
+ >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
122
+
123
+ >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
124
+ """
125
+
126
+ # init values
127
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
128
+ max_length = max_length if max_length is not None else self.config.max_length
129
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
130
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
131
+ output_scores = output_scores if output_scores is not None else self.config.output_scores
132
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
133
+ output_hidden_states = (
134
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
135
+ )
136
+ return_dict_in_generate = (
137
+ return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
138
+ )
139
+
140
+ # init attention / hidden states / scores tuples
141
+ scores = () if (return_dict_in_generate and output_scores) else None
142
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
143
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
144
+
145
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
146
+ if return_dict_in_generate and self.config.is_encoder_decoder:
147
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
148
+ encoder_hidden_states = (
149
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
150
+ )
151
+
152
+ batch_size = len(beam_scorer._beam_hyps)
153
+ num_beams = beam_scorer.num_beams
154
+
155
+ batch_beam_size, cur_len = input_ids.shape
156
+
157
+ assert (
158
+ num_beams * batch_size == batch_beam_size
159
+ ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
160
+
161
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
162
+ beam_scores[:, 1:] = -1e9
163
+ beam_scores = beam_scores.view((batch_size * num_beams,))
164
+
165
+ while cur_len < max_length:
166
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
167
+
168
+ outputs = self(
169
+ **model_inputs,
170
+ return_dict=True,
171
+ output_attentions=output_attentions,
172
+ output_hidden_states=output_hidden_states,
173
+ )
174
+ next_token_logits = outputs.logits[:, -1, :]
175
+
176
+ # adjust tokens for Bart, *e.g.*
177
+ next_token_logits = self.adjust_logits_during_generation(
178
+ next_token_logits, cur_len=cur_len, max_length=max_length
179
+ )
180
+
181
+ next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
182
+
183
+ next_token_scores = logits_processor(input_ids, next_token_scores)
184
+ next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
185
+
186
+ # Store scores, attentions and hidden_states when required
187
+ if return_dict_in_generate:
188
+ if output_scores:
189
+ scores += (next_token_scores,)
190
+ if output_attentions:
191
+ decoder_attentions += (
192
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
193
+ )
194
+
195
+ if output_hidden_states:
196
+ decoder_hidden_states += (
197
+ (outputs.decoder_hidden_states,)
198
+ if self.config.is_encoder_decoder
199
+ else (outputs.hidden_states,)
200
+ )
201
+
202
+ # reshape for beam search
203
+ vocab_size = next_token_scores.shape[-1]
204
+ next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
205
+ #m = torch.nn.LayerNorm(num_beams * vocab_size)
206
+ #next_token_scores = m(next_token_scores)
207
+
208
+ next_token_scores_group = torch.sum(next_token_scores,dim=0,keepdim=True).expand(batch_size,-1) / batch_size
209
+
210
+ for i in range(next_token_scores.size(0)):
211
+ '''tmin = torch.min(next_token_scores_group[i])
212
+ for j in range(1,len(model_kwargs['decoder_ori_input_ids'][i])):
213
+ next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin'''
214
+ for t in model_kwargs['decoder_ori_input_ids'][i]:
215
+ for j in range(num_beams):
216
+ #if t not in input_ids[i] or t==1:
217
+ next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]
218
+
219
+ next_token_scores, next_tokens = torch.topk(
220
+ next_token_scores_group, 2 * num_beams, dim=1, largest=True, sorted=True)
221
+
222
+ '''next_token_scores_group = next_token_scores_group.expand(batch_size,-1)
223
+ next_tokens_group = next_tokens_group.expand(batch_size,-1)
224
+
225
+ next_token_scores, next_tokens = torch.topk(
226
+ next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
227
+ )
228
+
229
+ for i in range(next_token_scores.size(0)):
230
+ j1 = 0
231
+ for j in range(next_token_scores.size(1)):
232
+ if next_tokens[i][j] not in model_kwargs['decoder_ori_input_ids'][i]:
233
+ next_tokens[i][j] = next_tokens_group[i][j1]
234
+ j1 += 1
235
+ next_token_scores = next_token_scores_group
236
+
237
+ del next_token_scores_group, next_tokens_group'''
238
+
239
+ next_indices = next_tokens // vocab_size
240
+ next_tokens = next_tokens % vocab_size
241
+
242
+ # stateless
243
+ beam_outputs = beam_scorer.process(
244
+ input_ids,
245
+ next_token_scores,
246
+ next_tokens,
247
+ next_indices,
248
+ pad_token_id=pad_token_id,
249
+ eos_token_id=eos_token_id,
250
+ )
251
+ beam_scores = beam_outputs["next_beam_scores"]
252
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
253
+ beam_idx = beam_outputs["next_beam_indices"]
254
+
255
+ input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
256
+
257
+ cur_len = cur_len + 1
258
+
259
+ model_kwargs = self._update_model_kwargs_for_generation(
260
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
261
+ )
262
+ if model_kwargs["past"] is not None:
263
+ model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
264
+
265
+ if beam_scorer.is_done:
266
+ break
267
+
268
+ sequence_outputs = beam_scorer.finalize(
269
+ input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
270
+ )
271
+
272
+ if return_dict_in_generate:
273
+ if not output_scores:
274
+ sequence_outputs["sequence_scores"] = None
275
+ if self.config.is_encoder_decoder:
276
+ return BeamSearchEncoderDecoderOutput(
277
+ sequences=sequence_outputs["sequences"],
278
+ sequences_scores=sequence_outputs["sequence_scores"],
279
+ scores=scores,
280
+ encoder_attentions=encoder_attentions,
281
+ encoder_hidden_states=encoder_hidden_states,
282
+ decoder_attentions=decoder_attentions,
283
+ decoder_hidden_states=decoder_hidden_states,
284
+ )
285
+ else:
286
+ return BeamSearchDecoderOnlyOutput(
287
+ sequences=sequence_outputs["sequences"],
288
+ sequences_scores=sequence_outputs["sequence_scores"],
289
+ scores=scores,
290
+ attentions=decoder_attentions,
291
+ hidden_states=decoder_hidden_states,
292
+ )
293
+ else:
294
+ return sequence_outputs["sequences"]
295
+
296
+ def group_beam_search(
297
+ self,
298
+ input_ids: torch.LongTensor,
299
+ beam_scorer: BeamScorer,
300
+ logits_processor: Optional[LogitsProcessorList] = None,
301
+ max_length: Optional[int] = None,
302
+ pad_token_id: Optional[int] = None,
303
+ eos_token_id: Optional[int] = None,
304
+ output_attentions: Optional[bool] = None,
305
+ output_hidden_states: Optional[bool] = None,
306
+ output_scores: Optional[bool] = None,
307
+ return_dict_in_generate: Optional[bool] = None,
308
+ **model_kwargs,
309
+ ):
310
+ r"""
311
+ Generates sequences for models with a language modeling head using beam search decoding.
312
+
313
+ Parameters:
314
+
315
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
316
+ The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
317
+ :obj:`torch.LongTensor` of shape :obj:`(1,)`.
318
+ beam_scorer (:obj:`BeamScorer`):
319
+ An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
320
+ constructed, stored and sorted during generation. For more information, the documentation of
321
+ :class:`~transformers.BeamScorer` should be read.
322
+ logits_processor (:obj:`LogitsProcessorList`, `optional`):
323
+ An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
324
+ :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
325
+ head applied at each generation step.
326
+ max_length (:obj:`int`, `optional`, defaults to 20):
327
+ The maximum length of the sequence to be generated.
328
+ pad_token_id (:obj:`int`, `optional`):
329
+ The id of the `padding` token.
330
+ eos_token_id (:obj:`int`, `optional`):
331
+ The id of the `end-of-sequence` token.
332
+ output_attentions (:obj:`bool`, `optional`, defaults to `False`):
333
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
334
+ returned tensors for more details.
335
+ output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
336
+ Whether or not to return trhe hidden states of all layers. See ``hidden_states`` under returned tensors
337
+ for more details.
338
+ output_scores (:obj:`bool`, `optional`, defaults to `False`):
339
+ Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
340
+ return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
341
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
342
+ model_kwargs:
343
+ Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If
344
+ model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
345
+
346
+ Return:
347
+ :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`,
348
+ :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or obj:`torch.LongTensor`: A
349
+ :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
350
+ :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
351
+ :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
352
+ ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
353
+ :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
354
+ ``model.config.is_encoder_decoder=True``.
355
+
356
+ Examples::
357
+
358
+ >>> from transformers import (
359
+ ... AutoTokenizer,
360
+ ... AutoModelForSeq2SeqLM,
361
+ ... LogitsProcessorList,
362
+ ... MinLengthLogitsProcessor,
363
+ ... HammingDiversityLogitsProcessor,
364
+ ... BeamSearchScorer,
365
+ ... )
366
+ >>> import torch
367
+
368
+ >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
369
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
370
+
371
+ >>> encoder_input_str = "translate English to German: How old are you?"
372
+ >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
373
+
374
+
375
+ >>> # lets run diverse beam search using 6 beams
376
+ >>> num_beams = 6
377
+ >>> # define decoder start token ids
378
+ >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
379
+ >>> input_ids = input_ids * model.config.decoder_start_token_id
380
+
381
+ >>> # add encoder_outputs to model keyword arguments
382
+ >>> model_kwargs = {
383
+ ... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
384
+ ... }
385
+
386
+ >>> # instantiate beam scorer
387
+ >>> beam_scorer = BeamSearchScorer(
388
+ ... batch_size=1,
389
+ ... max_length=model.config.max_length,
390
+ ... num_beams=num_beams,
391
+ ... device=model.device,
392
+ ... num_beam_groups=3
393
+ ... )
394
+
395
+ >>> # instantiate logits processors
396
+ >>> logits_processor = LogitsProcessorList([
397
+ ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
398
+ ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
399
+ ... ])
400
+
401
+ >>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
402
+
403
+ >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
404
+ """
405
+
406
+ # init values
407
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
408
+ max_length = max_length if max_length is not None else self.config.max_length
409
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
410
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
411
+ output_scores = output_scores if output_scores is not None else self.config.output_scores
412
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
413
+ output_hidden_states = (
414
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
415
+ )
416
+ return_dict_in_generate = (
417
+ return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
418
+ )
419
+
420
+ # init attention / hidden states / scores tuples
421
+ scores = () if (return_dict_in_generate and output_scores) else None
422
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
423
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
424
+
425
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
426
+ if return_dict_in_generate and self.config.is_encoder_decoder:
427
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
428
+ encoder_hidden_states = (
429
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
430
+ )
431
+
432
+ batch_size = len(beam_scorer._beam_hyps)
433
+ num_beams = beam_scorer.num_beams
434
+ num_beam_groups = beam_scorer.num_beam_groups
435
+ num_sub_beams = num_beams // num_beam_groups
436
+ device = input_ids.device
437
+
438
+ batch_beam_size, cur_len = input_ids.shape
439
+
440
+ assert (
441
+ num_beams * batch_size == batch_beam_size
442
+ ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
443
+
444
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
445
+ # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
446
+ # the same group don't produce same tokens everytime.
447
+ beam_scores[:, ::num_sub_beams] = 0
448
+ beam_scores = beam_scores.view((batch_size * num_beams,))
449
+
+ while cur_len < max_length:
+     # predicted tokens in cur_len step
+     current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
+
+     # indices which will form the beams in the next time step
+     reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
+
+     # do one decoder step on all beams of all sentences in batch
+     model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+     outputs = self(
+         **model_inputs,
+         return_dict=True,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+     )
+
+     for beam_group_idx in range(num_beam_groups):
+         group_start_idx = beam_group_idx * num_sub_beams
+         group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
+         group_size = group_end_idx - group_start_idx
+
+         # indices of beams of current group among all sentences in batch
+         batch_group_indices = []
+
+         if output_scores:
+             processed_score = torch.zeros_like(outputs.logits[:, -1, :]).half()  # .float()
+
+         for batch_idx in range(batch_size):
+             batch_group_indices.extend(
+                 [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
+             )
+         group_input_ids = input_ids[batch_group_indices]
+
+         # select outputs of beams of current group only
+         next_token_logits = outputs.logits[batch_group_indices, -1, :]
+
+         # adjust tokens for Bart, *e.g.*
+         next_token_logits = self.adjust_logits_during_generation(
+             next_token_logits, cur_len=cur_len, max_length=max_length
+         )
+
+         next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)
+         vocab_size = next_token_scores.shape[-1]
+
+         next_token_scores = logits_processor(
+             group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
+         )
+         next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
+             next_token_scores
+         )
+
+         if output_scores:
+             processed_score[batch_group_indices] = next_token_scores.half()  # .float()
+
+         # reshape for beam search
+         next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
+         # modification: average the next-token scores over the batch so that every example
+         # votes on a shared continuation, then restore each example's own scores for tokens
+         # drawn from that example's original input (decoder_ori_input_ids).
+         next_token_scores_group = torch.sum(next_token_scores, dim=0, keepdim=True).expand(
+             batch_size, -1
+         ) / batch_size
+
+         for i in range(next_token_scores.size(0)):
+             # tmin = torch.min(next_token_scores_group[i])
+             # for j in range(1, len(model_kwargs['decoder_ori_input_ids'][i])):
+             #     next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin
+             for t in model_kwargs['decoder_ori_input_ids'][i]:
+                 for j in range(group_size):
+                     # if t not in input_ids[i] or t == 1:
+                     next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]
+
+         next_token_scores, next_tokens = torch.topk(
+             next_token_scores_group, 2 * group_size, dim=1, largest=True, sorted=True
+         )
+
+         # original top-k over the un-averaged, per-example scores, kept for reference:
+         # next_token_scores, next_tokens = torch.topk(
+         #     next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
+         # )
+
+         next_indices = next_tokens // vocab_size
+         next_tokens = next_tokens % vocab_size
+
+         # stateless
+         beam_outputs = beam_scorer.process(
+             group_input_ids,
+             next_token_scores,
+             next_tokens,
+             next_indices,
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+         )
+         beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
+         beam_next_tokens = beam_outputs["next_beam_tokens"]
+         beam_idx = beam_outputs["next_beam_indices"]
+
+         input_ids[batch_group_indices] = group_input_ids[beam_idx]
+         group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+         current_tokens[batch_group_indices] = group_input_ids[:, -1]
+
+         # (beam_idx // group_size) -> batch_idx
+         # (beam_idx % group_size) -> offset of idx inside the group
+         reordering_indices[batch_group_indices] = (
+             num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
+         )
+
+     # Store scores, attentions and hidden_states when required
+     if return_dict_in_generate:
+         if output_scores:
+             scores += (processed_score,)
+         if output_attentions:
+             decoder_attentions += (
+                 (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+             )
+
+         if output_hidden_states:
+             decoder_hidden_states += (
+                 (outputs.decoder_hidden_states,)
+                 if self.config.is_encoder_decoder
+                 else (outputs.hidden_states,)
+             )
+
+     model_kwargs = self._update_model_kwargs_for_generation(
+         outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+     )
+     if model_kwargs["past"] is not None:
+         model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)
+
+     input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
+     cur_len = cur_len + 1
+     if beam_scorer.is_done:
+         break
+
+ sequence_outputs = beam_scorer.finalize(
+     input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=max_length,
+ )
+
+ if return_dict_in_generate:
+     if not output_scores:
+         sequence_outputs["sequence_scores"] = None
+     if self.config.is_encoder_decoder:
+         return BeamSearchEncoderDecoderOutput(
+             sequences=sequence_outputs["sequences"],
+             sequences_scores=sequence_outputs["sequence_scores"],
+             scores=scores,
+             encoder_attentions=encoder_attentions,
+             encoder_hidden_states=encoder_hidden_states,
+             decoder_attentions=decoder_attentions,
+             decoder_hidden_states=decoder_hidden_states,
+         )
+     else:
+         return BeamSearchDecoderOnlyOutput(
+             sequences=sequence_outputs["sequences"],
+             sequences_scores=sequence_outputs["sequence_scores"],
+             scores=scores,
+             attentions=decoder_attentions,
+             hidden_states=decoder_hidden_states,
+         )
+ else:
+     return sequence_outputs["sequences"]
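
For orientation: the docstring above shows the low-level `group_beam_search` call this file modifies. In stock `transformers` the same diverse (group) beam search is reachable through the high-level `generate` API. A minimal sketch, assuming a recent `transformers` release; the model name and parameter values are illustrative and are not this repo's modified decoding loop:

```python
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

input_ids = tokenizer("Bill Gates <mask> Microsoft.", return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    num_beams=6,
    num_beam_groups=3,      # split the 6 beams into 3 groups
    diversity_penalty=5.5,  # Hamming diversity penalty between groups
    num_return_sequences=6,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```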
src/distinct_n/.gitignore ADDED
@@ -0,0 +1,58 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ state.py
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
src/distinct_n/.idea/Distinct-N.iml ADDED
@@ -0,0 +1,11 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$">
+       <sourceFolder url="file://$MODULE_DIR$/distinct_n" isTestSource="false" />
+       <excludeFolder url="file://$MODULE_DIR$/docs" />
+     </content>
+     <orderEntry type="jdk" jdkName="Python 3.6 (Metrics)" jdkType="Python SDK" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
src/distinct_n/.idea/encodings.xml ADDED
@@ -0,0 +1,4 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="Encoding" addBOMForNewFiles="with NO BOM" />
+ </project>
src/distinct_n/.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="JavaScriptSettings">
+     <option name="languageLevel" value="ES6" />
+   </component>
+   <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (tensorflow)" project-jdk-type="Python SDK" />
+ </project>
src/distinct_n/.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/Distinct-N.iml" filepath="$PROJECT_DIR$/.idea/Distinct-N.iml" />
+     </modules>
+   </component>
+ </project>
src/distinct_n/.idea/other.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="PySciProjectComponent">
+     <option name="PY_SCI_VIEW_SUGGESTED" value="true" />
+   </component>
+ </project>
src/distinct_n/.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="$PROJECT_DIR$" vcs="Git" />
+   </component>
+ </project>
src/distinct_n/.idea/webResources.xml ADDED
@@ -0,0 +1,14 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="WebResourcesPaths">
+     <contentEntries>
+       <entry url="file://$PROJECT_DIR$">
+         <entryData>
+           <resourceRoots>
+             <path value="file://$PROJECT_DIR$/testdata" />
+           </resourceRoots>
+         </entryData>
+       </entry>
+     </contentEntries>
+   </component>
+ </project>
src/distinct_n/A Diversity-Promoting Objective Function for Neural Conversation Models.pdf ADDED
Binary file (200 kB).
 
src/distinct_n/LICENSE.txt ADDED
@@ -0,0 +1,202 @@
+
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "[]"
+       replaced with your own identifying information. (Don't include
+       the brackets!) The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright [yyyy] [name of copyright owner]
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
src/distinct_n/README.md ADDED
@@ -0,0 +1,30 @@
+ # Distinct-N
+ Distinct-N, most notably distinct-1 and distinct-2, is a metric that measures the
+ diversity of a sentence. It focuses on the number of *distinct* n-grams in a sentence and thus
+ penalizes sentences with many repeated words. The metric requires no *reference* or *ground truth*
+ sentence and depends entirely on the properties of the (system-generated) sentence itself.
+ It was proposed by Jiwei Li et al. in the paper *A Diversity-Promoting Objective Function for Neural Conversation Models*.
+
+ # Definitions
+ The original paper defined *Distinct-N* as:
+
+ > We report degree of diversity by calculating the number of distinct unigrams and bigrams in generated responses.
+ > The value is scaled by the total number of generated tokens to avoid favoring long sentences.
+
+ which is exactly what we described above.
+
+ # Usage
+ ```bash
+ $ python distinct_metric.py -n N_GRAMS PREDICTION
+ ```
+
+ where `N_GRAMS` is the length of the token sequences to count as unique within one sentence, and
+ `PREDICTION` is the prediction or response your model generates, with one utterance (sentence) per line.
+
+ # Dependencies
+ `python>=3.6.1`
+
+ # References
+ [1] A Diversity-Promoting Objective Function for Neural Conversation Models
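
For concreteness, a minimal sketch of the per-sentence computation the paper describes (plain Python with no dependencies; the function name is ours, not this package's API):

```python
def distinct_n_ratio(tokens, n):
    """Number of distinct n-grams scaled by the total number of tokens, per Li et al."""
    ngrams = {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}
    return len(ngrams) / len(tokens) if tokens else 0.0

print(distinct_n_ratio("the the the the cat".split(), 1))  # 2 distinct / 5 tokens = 0.4
print(distinct_n_ratio("the cat sat on the".split(), 2))   # 4 distinct / 5 tokens = 0.8
```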
src/distinct_n/bin/distinct_metric.py ADDED
@@ -0,0 +1,29 @@
+ import argparse
+ import logging
+
+ from distinct_n import distinct_n_sentence_level
+ from pathlib import Path
+ from agenda.metric_helper import write_score
+
+ NAME = 'distinct_n'
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('hypothesis', help="predicted text file, one example per line")
+     parser.add_argument('-n', dest='n_range', type=int, nargs='+', help="n to use as in distinct-N")
+     parser.add_argument('--output_dir')
+     args = parser.parse_args()
+
+     logging.basicConfig(level=logging.INFO)
+     logging.info('loading hypothesis file...')
+     with open(args.hypothesis) as f:
+         hypothesis = [sentence.split() for sentence in f.readlines()]
+
+     output_dir = Path(args.output_dir)
+     for n in args.n_range:
+         write_score(
+             name=NAME,
+             output=output_dir.joinpath(f'{NAME}_{n}').with_suffix('.json'),
+             params={'n': n},
+             scores=[distinct_n_sentence_level(s, n) for s in hypothesis],
+         )
src/distinct_n/bin/score.sh ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/env bash
+
+ HYPO=/home/cgsdfc/UbuntuDialogueCorpus/ResponseContextPairs/ModelPredictions/VHRED/First_VHRED_BeamSearch_5_GeneratedTestResponses.txt_First.txt
+ DIR=/home/cgsdfc/Result/Test
+
+ python bin/distinct_metric.py --output_dir $DIR $HYPO -n 3
src/distinct_n/distinct_n/metrics.py ADDED
@@ -0,0 +1,33 @@
+ from src.distinct_n.distinct_n.utils import ngrams
+
+ __all__ = ["distinct_n_sentence_level", "distinct_n_corpus_level"]
+
+
+ def distinct_n_sentence_level(sentence, n):
+     """
+     Collect the distinct n-grams of a single sentence.
+     Modified from the original metric, which returned
+     len(distinct_ngrams) / len(sentence); here the distinct n-grams
+     themselves are returned so they can be pooled at the corpus level.
+     :param sentence: a list of words.
+     :param n: int, ngram.
+     :return: list of distinct n-gram tuples.
+     """
+     if len(sentence) == 0:
+         return []  # an empty sentence contributes no n-grams
+     # distinct_ngrams = set(ngrams(sentence, n))
+     return list(set(ngrams(sentence, n)))
+     # original: return len(distinct_ngrams) / len(sentence)
+
+
+ def distinct_n_corpus_level(sentences, n):
+     """
+     Compute distinct-N of a list of sentences (the corpus): the number of
+     distinct n-grams pooled across the corpus, divided by the total number
+     of tokens.
+     :param sentences: a list of sentences, each a list of words.
+     :param n: int, ngram.
+     :return: float, the corpus-level value.
+     """
+     temp = []
+     length = 0
+     for sentence in sentences:
+         length += len(sentence)
+         temp.extend(distinct_n_sentence_level(sentence, n))
+     return len(set(temp)) / length
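
A small sketch of the pooled corpus-level computation above (the import path assumes the repo root is on `PYTHONPATH`):

```python
from src.distinct_n.distinct_n.metrics import distinct_n_corpus_level

corpus = [
    "the cat sat on the mat".split(),
    "the dog sat on the mat".split(),
]
print(distinct_n_corpus_level(corpus, 1))  # 6 distinct unigrams / 12 tokens = 0.5
print(distinct_n_corpus_level(corpus, 2))  # 7 distinct bigrams / 12 tokens ≈ 0.583
```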
src/distinct_n/distinct_n/test.py ADDED
@@ -0,0 +1,32 @@
+ import unittest
+
+ from distinct_n import distinct_n_sentence_level
+ from distinct_n import distinct_n_corpus_level
+
+
+ # NOTE: these tests target the original definition of the metric, in which
+ # distinct_n_sentence_level returns len(distinct_ngrams) / len(sentence) and
+ # distinct_n_corpus_level averages that ratio over sentences; the modified
+ # functions in metrics.py above pool n-grams across the corpus instead.
+ class TestDistinctN(unittest.TestCase):
+     def test_unigram(self):
+         sentence = "the the the the the".split()
+         self.assertAlmostEqual(
+             distinct_n_sentence_level(sentence, 1), 0.2
+         )
+         sentence = "the the the the cat".split()
+         self.assertAlmostEqual(
+             distinct_n_sentence_level(sentence, 1), 0.4
+         )
+
+     def test_bigram(self):
+         sentence = "the cat sat on the".split()
+         self.assertAlmostEqual(
+             distinct_n_sentence_level(sentence, 2), 0.8
+         )
+
+     def test_corpus_level(self):
+         sentences = [
+             'the cat sat on the mat'.split(),
+             'mat the on sat cat the'.split(),
+             'i do not know'.split(),
+             'Sorry but i do not know'.split(),
+         ]
+         self.assertAlmostEqual(0.916666, distinct_n_corpus_level(sentences, 1), delta=1e-5)
+         self.assertAlmostEqual(0.8125, distinct_n_corpus_level(sentences, 2), delta=1e-5)
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """
+ Copied from nltk.ngrams().
+ """
+ from itertools import chain
+
+ __all__ = ["ngrams"]
+
+
+ def pad_sequence(sequence, n, pad_left=False, pad_right=False,
+                  left_pad_symbol=None, right_pad_symbol=None):
+     """
+     Returns a padded sequence of items before ngram extraction.
+
+     >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='<s>', right_pad_symbol='</s>'))
+     ['<s>', 1, 2, 3, 4, 5, '</s>']
+     >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='<s>'))
+     ['<s>', 1, 2, 3, 4, 5]
+     >>> list(pad_sequence([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='</s>'))
+     [1, 2, 3, 4, 5, '</s>']
+
+     :param sequence: the source data to be padded
+     :type sequence: sequence or iter
+     :param n: the degree of the ngrams
+     :type n: int
+     :param pad_left: whether the ngrams should be left-padded
+     :type pad_left: bool
+     :param pad_right: whether the ngrams should be right-padded
+     :type pad_right: bool
+     :param left_pad_symbol: the symbol to use for left padding (default is None)
+     :type left_pad_symbol: any
+     :param right_pad_symbol: the symbol to use for right padding (default is None)
+     :type right_pad_symbol: any
+     :rtype: sequence or iter
+     """
+     sequence = iter(sequence)
+     if pad_left:
+         sequence = chain((left_pad_symbol,) * (n - 1), sequence)
+     if pad_right:
+         sequence = chain(sequence, (right_pad_symbol,) * (n - 1))
+     return sequence
+
+
+ def ngrams(sequence, n, pad_left=False, pad_right=False,
+            left_pad_symbol=None, right_pad_symbol=None):
+     """
+     Return the ngrams generated from a sequence of items, as an iterator.
+     For example:
+
+     >>> from nltk.util import ngrams
+     >>> list(ngrams([1,2,3,4,5], 3))
+     [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
+
+     Wrap with list for a list version of this function. Set pad_left
+     or pad_right to true in order to get additional ngrams:
+
+     >>> list(ngrams([1,2,3,4,5], 2, pad_right=True))
+     [(1, 2), (2, 3), (3, 4), (4, 5), (5, None)]
+     >>> list(ngrams([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='</s>'))
+     [(1, 2), (2, 3), (3, 4), (4, 5), (5, '</s>')]
+     >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='<s>'))
+     [('<s>', 1), (1, 2), (2, 3), (3, 4), (4, 5)]
+     >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='<s>', right_pad_symbol='</s>'))
+     [('<s>', 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, '</s>')]
+
+     :param sequence: the source data to be converted into ngrams
+     :type sequence: sequence or iter
+     :param n: the degree of the ngrams
+     :type n: int
+     :param pad_left: whether the ngrams should be left-padded
+     :type pad_left: bool
+     :param pad_right: whether the ngrams should be right-padded
+     :type pad_right: bool
+     :param left_pad_symbol: the symbol to use for left padding (default is None)
+     :type left_pad_symbol: any
+     :param right_pad_symbol: the symbol to use for right padding (default is None)
+     :type right_pad_symbol: any
+     :rtype: sequence or iter
+     """
+     sequence = pad_sequence(sequence, n, pad_left, pad_right,
+                             left_pad_symbol, right_pad_symbol)
+
+     history = []
+     while n > 1:
+         history.append(next(sequence))
+         n -= 1
+     for item in sequence:
+         history.append(item)
+         yield tuple(history)
+         del history[0]
src/distinct_n/setup.py ADDED
@@ -0,0 +1,29 @@
+ from setuptools import setup
+
+ __version__ = '0.4.0'
+
+ setup(
+     name='Distinct_N',
+     version=__version__,
+     description='Distinct-N metric that measures the degree of diversity of generated responses',
+     url='https://github.com/neural-dialogue-metrics/Distinct-N.git',
+     author='cgsdfc',
+     author_email='cgsdfc@126.com',
+     keywords=[
+         'NL', 'CL', 'MT',
+         'natural language processing',
+         'computational linguistics',
+         'machine translation',
+     ],
+     packages=['distinct_n'],
+     scripts=['bin/distinct_metric.py'],
+     classifiers=[
+         'Intended Audience :: Science/Research',
+         'License :: OSI Approved :: Apache Software License',
+         'Programming Language :: Python :: 3',
+         'Topic :: Text Processing :: Linguistic',
+     ],
+     license='LICENSE.txt',
+     long_description=open('README.md').read(),
+     install_requires=[],
+ )
src/distinct_n/testdata/bigram.txt ADDED
@@ -0,0 +1 @@
+ the cat sat on the mat
src/distinct_n/testdata/unigram.txt ADDED
@@ -0,0 +1 @@
+ the the the the a
src/utils.py ADDED
@@ -0,0 +1,133 @@
+ from ngram import NGram
+
+
+ def post_process_template(tB):
+     if not tB.endswith('.'):
+         tB += '.'
+     return tB
+     # return tB.split('.')[0] + '.'
+
+
+ def construct_template(words, templateA, if_then=False):
+     if len(words) == 2:
+         # templates = ['{} <mask> {}.'.format(words[0], words[1])]
+         templates = [
+             # '{} is <mask> {}.'.format(words[0], words[1]),
+             '{} <mask> {}.'.format(words[0], words[1]),
+         ]
+     elif len(words) == 1:
+         templates = [
+             # '{} is <mask>.'.format(words[0]),
+             '{} <mask>.'.format(words[0])]
+     else:
+         templates = []
+
+     if if_then:
+         for word in words:
+             index = templateA.index('<mask>')
+             templateA = templateA[:index] + word + templateA[index + len('<mask>'):]
+         templates = ['If ' + templateA + ' then ' + template for template in templates]
+
+     return templates
+
+
+ def filter_words(words_prob):
+     word_count = {}
+     token1_count = {}
+     word2_count = {}
+     ret = []
+     for words, prob, *_ in words_prob:
+         filter_this = False
+
+         # penalize candidates with repeated tokens
+         token_count = {}
+         for word in words:
+             for token in word.split(' '):
+                 if token in token_count:
+                     filter_this = True
+                 token_count[token] = 1
+         if filter_this:
+             prob *= 0.5
+
+         # drop candidates whose two words are identical
+         if len(words) == 2 and words[0] == words[1]:
+             continue
+
+         # penalize a repeated first token across candidates
+         token1 = words[0].split(' ')[0]
+         if token1 not in token1_count:
+             token1_count[token1] = 1
+         else:
+             token1_count[token1] += 1
+         prob /= token1_count[token1]
+
+         # penalize repeated words across candidates
+         for word in words:
+             if word not in word_count:
+                 word_count[word] = 0
+             word_count[word] += 1
+             prob /= word_count[word]
+
+         if len(words) == 2:
+             if words[1] not in word2_count:
+                 word2_count[words[1]] = 0
+             word2_count[words[1]] += 1
+             prob /= word2_count[words[1]]
+
+         ret.append([words, prob])
+     return sorted(ret, key=lambda x: x[1], reverse=True)
+
+
+ import math
+ from copy import deepcopy
+
+
+ def convert_for_print(arr):
+     ret = deepcopy(arr)
+     for i in range(len(ret)):
+         ret[i][1] = round(ret[i][1], 7)
+         if len(ret[i]) == 3:
+             for j in range(len(ret[i][2])):
+                 ret[i][2][j] = round(ret[i][2][j], 7)
+     return ret
+
+
+ def formalize_tA(tA):
+     tA = tA.strip()
+     if tA.endswith('.'):
+         tA = tA[:-1].strip() + '.'
+     else:
+         tA += '.'
+     tA = tA.replace(' ,', ',')
+     tA = tA.replace(" '", "'")
+     return tA
+
+
+ ngram_n = 3
+
+
+ def extract_similar_words(txt, words):
+     max_word_length = 0
+     for word in words:
+         if len(word) > max_word_length:
+             max_word_length = len(word)
+
+     txt_ngrams = []
+     for i in range(len(txt)):
+         for j in range(i + ngram_n, min(len(txt), i + max_word_length + 5)):
+             txt_ngrams.append(txt[i:j].lower())
+     n = NGram(txt_ngrams, key=lambda x: x.lower(), N=ngram_n)
+     ret = []
+     for word in words:
+         matched_word = n.find(word.lower(), 0.5)
+         if matched_word is None:
+             return None
+         ret.append(matched_word)
+     return ret
+
+
+ def extract_words(txt, words):
+     for word in words:
+         if word not in txt:
+             return None
+     return [word.lower() for word in words]
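
A quick sanity check of the template helpers above, as a sketch; the import assumes you run from the repo root, and the entity strings are made-up examples:

```python
from src.utils import construct_template, formalize_tA

# two extracted entities yield a single relation template
print(construct_template(['Obama', 'the USA'], '<mask> is the president of <mask>.'))
# ['Obama <mask> the USA.']

# with if_then=True, templateA's masks are filled and prepended as a premise
print(construct_template(['Obama'], '<mask> is a person.', if_then=True))
# ['If Obama is a person. then Obama <mask>.']

# formalize_tA normalizes spacing and guarantees a trailing period
print(formalize_tA("Bill Gates founded Microsoft , right"))
# 'Bill Gates founded Microsoft, right.'
```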