Upload 25 files

- evaluation.py +254 -0
- expbert.py +282 -0
- inductor.py +401 -0
- src/__pycache__/bart_with_group_beam.cpython-38.pyc +0 -0
- src/bart_with_group_beam.py +608 -0
- src/distinct_n/.gitignore +58 -0
- src/distinct_n/.idea/Distinct-N.iml +11 -0
- src/distinct_n/.idea/encodings.xml +4 -0
- src/distinct_n/.idea/misc.xml +7 -0
- src/distinct_n/.idea/modules.xml +8 -0
- src/distinct_n/.idea/other.xml +6 -0
- src/distinct_n/.idea/vcs.xml +6 -0
- src/distinct_n/.idea/webResources.xml +14 -0
- src/distinct_n/A Diversity-Promoting Objective Function for Neural Conversation Models.pdf +0 -0
- src/distinct_n/LICENSE.txt +202 -0
- src/distinct_n/README.md +30 -0
- src/distinct_n/bin/distinct_metric.py +29 -0
- src/distinct_n/bin/score.sh +6 -0
- src/distinct_n/distinct_n/metrics.py +33 -0
- src/distinct_n/distinct_n/test.py +32 -0
- src/distinct_n/distinct_n/utils.py +90 -0
- src/distinct_n/setup.py +29 -0
- src/distinct_n/testdata/bigram.txt +1 -0
- src/distinct_n/testdata/unigram.txt +1 -0
- src/utils.py +133 -0
evaluation.py
ADDED
@@ -0,0 +1,254 @@
import argparse
import logging
import re
from datetime import datetime
import os

import numpy as np
import torch
from nltk import bleu, meteor
from rouge_score.rouge_scorer import RougeScorer
from tqdm import tqdm
from src.distinct_n.distinct_n.metrics import distinct_n_corpus_level as distinct_n

from inductor import BartInductor, CometInductor

FILES = {
    'amie-yago2': 'data/RE-datasets/AMIE-yago2.txt',
    'rules-yago2': 'data/RE-datasets/RuLES-yago2.txt',
    "openrule155": "data/OpenRule155.txt",
    'fewrel': 'data/RE/fewrel-5.txt',
    'semeval': 'data/RE/semeval-5.txt',
    'TREx': 'data/RE/trex-5.txt',
    'nyt10': 'data/RE/nyt10-5.txt',
    'google-re': 'data/RE/google-re-5.txt',
    'wiki80': 'data/RE/wiki80-5.txt',
}


if not os.path.exists('logs/'):
    os.mkdir('logs/')

logging.basicConfig(
    filename='logs/evaluation-{}.log'.format(str(datetime.now())),
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO)
logger = logging.getLogger(__name__)


def print_config(config):
    config = vars(config)
    logger.info("**************** MODEL CONFIGURATION ****************")
    for key in sorted(config.keys()):
        val = config[key]
        keystr = "{}".format(key) + (" " * (25 - len(key)))
        logger.info("{} --> {}".format(keystr, val))
    logger.info("**************** MODEL CONFIGURATION ****************")


scorer = RougeScorer(['rougeL'], use_stemmer=True)


def rouge(references, hypothesis):
    scores = []
    for reference in references:
        scores.append(
            scorer.score(
                reference,
                hypothesis)['rougeL'][2]
        )

    return max(scores)


class RelationExtractionEvaluator(object):
    def __init__(self, args):
        self.args = args
        if self.args.inductor == 'rule':
            self.inductor = BartInductor(
                group_beam=self.args.group_beam,
                continue_pretrain_instance_generator=self.args.mlm_training,
                continue_pretrain_hypo_generator=self.args.bart_training,
                if_then=self.args.if_then,
            )
        elif self.args.inductor == 'comet':
            self.inductor = CometInductor()

    def clean(self, text):
        segments = text.split('<mask>')
        if len(segments) == 3 and segments[2].startswith('.'):
            return '<mask>'.join(segments[:2]) + '<mask>.'
        else:
            return text

    def clean_references(self, texts):
        for i, text in enumerate(texts):
            if text.endswith(" ."):
                texts[i] = text.replace(" .", ".")

        return texts

    def self_bleu(self, hypothesis):
        bleus = []
        for i in range(len(hypothesis)):
            bleus.append(bleu(
                hypothesis[:i] + hypothesis[i + 1:],
                hypothesis[i],
                weights=(0.5, 0.5)))

        ret = np.mean(bleus)
        return ret

    def evaluate(self, task):
        with torch.no_grad():
            self.metrics = {
                "bleu-4": [],
                "bleu-3": [],
                "bleu-2": [],
                "bleu-1": [],
                "METEOR": [],
                "ROUGE-L": [],
                "self-BLEU-2": [],
            }
            with open(FILES[task], 'r', encoding='utf-8') as file:
                data = file.readlines()
                with tqdm(total=len(data)) as pbar:
                    for row in data:
                        pbar.update(1)
                        row = row.strip().split('\t')
                        inputs, head, tail, relations = row[0], row[1], row[2], row[3]
                        inputs = inputs.strip()

                        if relations.startswith('[') and relations.endswith(']'):
                            inputs = re.sub("<A>|<B>", "<mask>", inputs)
                            references = [relation.replace('<A>', '<mask>').replace('<B>', '<mask>').lower().strip() for relation in eval(relations)]
                        else:
                            references = [relations.replace('[X]', '<mask>').replace('[Y]', '<mask>').lower().strip()]
                        references = self.clean_references(references)
                        hypothesis = self.inductor.generate(inputs, k=10, topk=10)

                        logger.info("***********Input************")
                        logger.info(inputs)
                        logger.info("*********Hypothesis*********")
                        for i, hypo in enumerate(hypothesis):
                            hypothesis[i] = self.clean(hypo.lower().strip())
                            logger.info(hypo)

                        logger.info("****************************")
                        logger.info("*********References*********")
                        logger.info(references)
                        logger.info("****************************")

                        if len(hypothesis) == 0:
                            for k in self.metrics.keys():
                                if k != 'self-BLEU-2':
                                    self.metrics[k].append(0.)

                        else:
                            for hypo in hypothesis:
                                try:
                                    self.metrics['bleu-4'].append(
                                        bleu(
                                            [reference.split() for reference in references],
                                            hypo.split(),
                                            weights=(0.25, 0.25, 0.25, 0.25)
                                        )
                                    )
                                except Exception:
                                    logger.warning("Skip bleu-4 in example: {}".format(inputs))

                                try:
                                    self.metrics['bleu-3'].append(
                                        bleu(
                                            [reference.split() for reference in references],
                                            hypo.split(),
                                            weights=(1 / 3, ) * 3
                                        )
                                    )
                                except Exception:
                                    logger.warning("Skip bleu-3 in example: {}".format(inputs))

                                try:
                                    self.metrics['bleu-2'].append(
                                        bleu(
                                            [reference.split() for reference in references],
                                            hypo.split(),
                                            weights=(0.5, 0.5)
                                        )
                                    )
                                except Exception:
                                    logger.warning("Skip bleu-2 in example: {}".format(inputs))

                                try:
                                    self.metrics['bleu-1'].append(
                                        bleu(
                                            [reference.split() for reference in references],
                                            hypo.split(),
                                            weights=(1.0, )
                                        )
                                    )
                                except Exception:
                                    logger.warning("Skip bleu-1 in example: {}".format(inputs))

                                try:
                                    self.metrics['METEOR'].append(
                                        meteor(
                                            references,
                                            hypo,
                                        )
                                    )
                                except Exception:
                                    logger.warning("Skip METEOR in example: {}".format(inputs))

                                try:
                                    self.metrics['ROUGE-L'].append(
                                        rouge(
                                            references,
                                            hypo,
                                        )
                                    )
                                except Exception:
                                    logger.warning("Skip ROUGE-L in example: {}".format(inputs))

                            try:
                                self.metrics['self-BLEU-2'].append(
                                    self.self_bleu(
                                        hypothesis,
                                    )
                                )
                            except Exception:
                                logger.warning("Skip self-bleu-2 in example: {}.".format(inputs))
                        # break

            self.print(task, self.metrics)

    def print(self, task, metrics):
        logger.info("Task: {}".format(str(task)))
        for k, v in metrics.items():
            logger.info("{}: {}".format(k, str(np.mean(v))))

        logger.info("*******************************************************")
        logger.info("*******************************************************")
        logger.info("*******************************************************")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--inductor", type=str, default='rule')
    parser.add_argument("--group_beam", type=bool, default=False)
    parser.add_argument("--mlm_training", type=bool, default=False)
    parser.add_argument("--bart_training", type=bool, default=False)
    parser.add_argument("--if_then", type=bool, default=False)
    parser.add_argument("--task", type=str, default='openrule155')

    args = parser.parse_args()

    print_config(args)
    evaluator = RelationExtractionEvaluator(args)
    evaluator.evaluate(args.task)
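Usage note (editor's sketch, not part of the upload): evaluation.py is driven entirely by the argparse flags above, so the evaluator can also be constructed programmatically. The snippet below assumes a CUDA device and that the dataset paths in FILES exist locally. One caveat worth knowing: because every flag uses type=bool, any non-empty command-line value (including the literal string "False") parses as True, so a flag is disabled only by omitting it.

    # A minimal sketch of driving the evaluator directly (hypothetical values).
    import argparse

    args = argparse.Namespace(
        inductor='rule',       # 'rule' builds a BartInductor, 'comet' a CometInductor
        group_beam=False,      # use the group-beam BART variant
        mlm_training=False,    # True loads the continued-pretrained instance generator
        bart_training=False,   # True loads the continued-pretrained hypothesis generator
        if_then=False,         # emit "if ... then ..." style rules
        task='openrule155',    # key into FILES
    )
    evaluator = RelationExtractionEvaluator(args)
    evaluator.evaluate(args.task)  # per-metric means are written to logs/evaluation-*.log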
expbert.py
ADDED
@@ -0,0 +1,282 @@
import argparse
import logging
import os
import random
from datetime import datetime

import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (AutoConfig, AutoModel,
                          AutoModelForSequenceClassification, AutoTokenizer,
                          BertForSequenceClassification, BertModel)

if not os.path.exists('logs/'):
    os.mkdir('logs/')

logging.basicConfig(
    filename='logs/expbert-{}.log'.format(str(datetime.now())),
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO)
logger = logging.getLogger(__name__)


TASK2PATH = {
    "disease-train": "data/disease/train.txt",
    "disease-test": "data/disease/test.txt",
    "spouse-train": "data/spouse/train.txt",
    "spouse-test": "data/spouse/test.txt",
}

ANNOTATED_EXP = {
    "spouse": "data/exp/expbert_spouse_explanation.txt",
    "disease": "data/exp/expbert_disease_explanation.txt",
}

GENERATED_EXP = {
    "spouse": "data/exp/orion_spouse_explanation.txt",
    "disease": "data/exp/orion_disease_explanation.txt",
}


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def print_config(config):
    config = vars(config)
    logger.info("**************** MODEL CONFIGURATION ****************")
    for key in sorted(config.keys()):
        val = config[key]
        keystr = "{}".format(key) + (" " * (25 - len(key)))
        logger.info("{} --> {}".format(keystr, val))
    logger.info("**************** MODEL CONFIGURATION ****************")


class ExpBERT(nn.Module):
    def __init__(self, args, exp_num):
        super(ExpBERT, self).__init__()
        self.args = args
        self.exp_num = exp_num
        self.config = AutoConfig.from_pretrained(args.model)
        self.model = AutoModel.from_pretrained(args.model, config=self.config)
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(self.config.hidden_size * exp_num, 2)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs):
        for k, v in inputs["encoding"].items():
            inputs["encoding"][k] = v.cuda()
        pooler_output = self.model(**inputs["encoding"]).last_hidden_state[:, 0, :].reshape(1, self.exp_num * self.config.hidden_size)
        pooler_output = self.dropout(pooler_output)
        logits = self.linear(pooler_output)

        loss = self.criterion(logits, torch.LongTensor([inputs["label"]]).cuda())
        prediction = torch.argmax(logits)

        return {
            "loss": loss,
            "prediction": prediction,
        }


class REDataset(Dataset):
    def __init__(self, path, exp, tokenizer):
        super(REDataset, self).__init__()
        self.tokenizer = tokenizer
        self.exp = exp
        self.sentences = []
        self.labels = []
        self.entities = []
        with open(path, "r", encoding="utf-8") as file:
            data = file.readlines()
            for example in data:
                sentence, entity1, entity2, id, label = example.strip().split("\t")
                self.sentences.append(sentence)
                if eval(label) == 1:
                    self.labels.append(1)
                elif eval(label) == -1:
                    self.labels.append(0)

                self.entities.append([entity1, entity2])

        logger.info("Number of Example in {}: {}".format(path, str(len(self.labels))))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return {
            "sentence": self.sentences[index],
            "entity": self.entities[index],
            "label": self.labels[index],
        }

    def collate_fn(self, batch):
        outputs = []
        for ex in batch:
            temp = []
            for exp in self.exp:
                if "{e1}" in exp or "{e2}" in exp:
                    exp = exp.replace("{e1}", ex["entity"][0]).replace("{e2}", ex["entity"][1])
                else:
                    for entity in ex["entity"]:
                        index = exp.index('<mask>')
                        exp = exp[:index] + entity + exp[index + len('<mask>'):]
                temp.append(exp)
            outputs.append(
                {
                    "encoding": self.tokenizer(
                        [ex["sentence"]] * len(temp), temp,
                        add_special_tokens=True,
                        padding="longest",
                        truncation=True,
                        max_length=156,
                        return_tensors="pt",
                        return_attention_mask=True,
                        return_token_type_ids=True,
                    ),
                    "label": ex["label"],
                }
            )
        return outputs

    def collate_fn_(self, batch):
        texts = []
        labels = []
        for ex in batch:
            texts.append(ex["sentence"])
            labels.append(ex["label"])

        outputs = self.tokenizer(
            texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=156,
            return_tensors="pt",
            return_attention_mask=True,
            return_token_type_ids=True,
        )

        outputs["labels"] = torch.LongTensor(labels)

        return outputs


class Trainer(object):
    def __init__(self, args):
        self.args = args
        print_config(args)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model)

        TASK2EXP = GENERATED_EXP if args.generated_rules else ANNOTATED_EXP
        with open(TASK2EXP[args.task], "r", encoding="utf-8") as file:
            exp = file.readlines()

        self.train_dataset = REDataset(TASK2PATH['{}-train'.format(args.task)], exp, self.tokenizer)
        self.test_dataset = REDataset(TASK2PATH['{}-test'.format(args.task)], exp, self.tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(args.model).cuda() if self.args.no_exp else ExpBERT(args, len(exp)).cuda()

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            collate_fn=self.train_dataset.collate_fn_ if self.args.no_exp else self.train_dataset.collate_fn,
        )

        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            collate_fn=self.test_dataset.collate_fn_ if self.args.no_exp else self.test_dataset.collate_fn,
        )

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.learning_rate)

    def compute_metrics(self, labels, predictions):
        accuracy = accuracy_score(y_pred=predictions, y_true=labels)
        f1 = f1_score(y_pred=predictions, y_true=labels)

        return accuracy, f1

    def train(self):
        self.model.train()
        self.test(-1)
        for e in range(self.args.epochs):
            self.model.train()  # test() leaves the model in eval mode; re-enable dropout before each epoch
            with tqdm(total=len(self.train_loader)) as pbar:
                for step, examples in enumerate(self.train_loader):
                    self.model.zero_grad()
                    if self.args.no_exp:
                        for k, v in examples.items():
                            examples[k] = v.cuda()
                        outputs = self.model(**examples)
                        outputs.loss.backward()

                    else:
                        for ex in examples:
                            outputs = self.model(ex)
                            (outputs["loss"] / len(examples)).backward()

                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optimizer.step()
                    pbar.update(1)

            self.test(e)

    def test(self, epoch):
        self.model.eval()
        with torch.no_grad():
            with tqdm(total=len(self.test_loader)) as pbar:
                loss = []
                labels = []
                predictions = []
                for step, examples in enumerate(self.test_loader):
                    if self.args.no_exp:
                        for k, v in examples.items():
                            examples[k] = v.cuda()
                        outputs = self.model(**examples)
                        loss.append(outputs.loss.float())
                        labels.extend(examples["labels"].tolist())
                        predictions.extend(torch.argmax(outputs.logits, dim=1).tolist())

                    else:
                        for ex in examples:
                            labels.append(ex['label'])
                            outputs = self.model(ex)
                            loss.append(outputs["loss"].item())
                            predictions.append(outputs['prediction'].tolist())

                    pbar.update(1)
                accuracy, f1 = self.compute_metrics(predictions, labels)
                logger.info("[EPOCH {}] Accuracy: {} | F1-Score: {}. (Number of Data {})".format(epoch, accuracy, f1, len(predictions)))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="spouse")
    parser.add_argument("--model", type=str, default="bert-base-uncased")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--shuffle", type=bool, default=False)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--no_exp", type=bool, default=False)
    parser.add_argument("--generated_rules", type=bool, default=False)

    args = parser.parse_args()

    for seed in range(42, 47):
        set_random_seed(seed)
        trainer = Trainer(args)
        trainer.train()
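For orientation (editor's note, not part of the upload): ExpBERT pairs the input sentence with each explanation, takes the [CLS] vector of every pair, and concatenates them into a single feature of size hidden_size * exp_num before the 2-way linear head; that is why collate_fn returns one encoding per example rather than a padded batch. Typical invocations under the argparse defaults above might look like the hypothetical sketch below, which assumes the data/ paths exist locally.

    # Hypothetical command lines (paths and checkpoints assumed available):
    #   python expbert.py --task spouse                       # ExpBERT + annotated explanations
    #   python expbert.py --task spouse --generated_rules 1   # ExpBERT + Orion-generated explanations
    #   python expbert.py --task spouse --no_exp 1            # plain BERT sequence classifier baseline
    # Caveat: argparse's type=bool treats any non-empty string as True, so
    # "--no_exp False" still enables the flag; omit a flag to keep its default.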
inductor.py
ADDED
@@ -0,0 +1,401 @@
import re
from copy import deepcopy

import argparse
import torch
import torch.nn.functional as F
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          BartForConditionalGeneration, BartTokenizer,)

from src.bart_with_group_beam import BartForConditionalGeneration_GroupBeam
from src.utils import (construct_template, filter_words,
                       formalize_tA, post_process_template)

ORION_HYPO_GENERATOR = 'chenxran/orion-hypothesis-generator'
ORION_INS_GENERATOR = 'chenxran/orion-instance-generator'

RELATIONS = [
    "Causes",
    "HasProperty",
    "MadeUpOf",
    "isAfter",
    "isBefore",
    "xReact",
    "xWant",
    "xReason",
    "xAttr",
    "Desires",
]


class BartInductor(object):
    def __init__(
        self,
        group_beam=True,
        continue_pretrain_instance_generator=True,
        continue_pretrain_hypo_generator=True,
        if_then=False
    ):
        self.if_then = if_then
        self.orion_instance_generator_path = 'facebook/bart-large' if not continue_pretrain_instance_generator else ORION_INS_GENERATOR
        self.orion_hypothesis_generator_path = 'facebook/bart-large' if not continue_pretrain_hypo_generator else ORION_HYPO_GENERATOR

        if group_beam:
            self.orion_hypothesis_generator = BartForConditionalGeneration_GroupBeam.from_pretrained(self.orion_hypothesis_generator_path).cuda().eval().half()
        else:
            self.orion_hypothesis_generator = BartForConditionalGeneration.from_pretrained(self.orion_hypothesis_generator_path).cuda().eval().half()

        self.orion_instance_generator = BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).cuda().eval().half()

        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
        self.word_length = 2

        self.stop_sub_list = ['he', 'she', 'this', 'that', 'and', 'it', 'which', 'who', 'whose', 'there', 'they', '.', 'its', 'one',
                              'i', ',', 'the', 'nobody', 'his', 'her', 'also', 'only', 'currently', 'here', '()', 'what', 'where',
                              'why', 'a', 'some', '"', ')', '(', 'now', 'everyone', 'everybody', 'their', 'often', 'usually', 'you',
                              '-', '?', ';', 'in', 'on', 'each', 'both', 'him', 'typically', 'mostly', 'sometimes', 'normally',
                              'always', 'usually', 'still', 'today', 'was', 'were', 'but', 'although', 'current', 'all', 'have',
                              'has', 'later', 'with', 'most', 'nowadays', 'then', 'every', 'when', 'someone', 'anyone', 'somebody',
                              'anybody', 'any', 'being', 'get', 'getting', 'thus', 'under', 'even', 'for', 'can', 'rarely', 'never',
                              'may', 'generally', 'other', 'another', 'too', 'first', 'second', 'third', 'mainly', 'primarily',
                              'having', 'have', 'has']

        self.stop_size = len(self.stop_sub_list)
        for i in range(self.stop_size):
            if self.stop_sub_list[i][0].isalpha():
                temp = self.stop_sub_list[i][0].upper() + self.stop_sub_list[i][1:]
                self.stop_sub_list.append(temp)

        self.bad_words_ids = [self.tokenizer.encode(bad_word)[1:-1] for bad_word in ['also', ' also']]
        stop_index = self.tokenizer(self.stop_sub_list, max_length=4, padding=True)
        stop_index = torch.tensor(stop_index['input_ids'])[:, 1]
        stop_weight = torch.zeros(1, self.tokenizer.vocab_size).cuda()
        stop_weight[0, stop_index] -= 100
        self.stop_weight = stop_weight[0, :]

    def clean(self, text):
        segments = text.split('<mask>')
        if len(segments) == 3 and segments[2].startswith('.'):
            return '<mask>'.join(segments[:2]) + '<mask>.'
        else:
            return text

    def generate(self, inputs, k=10, topk=10):
        with torch.no_grad():
            tB_probs = self.generate_rule(inputs, k)
            ret = [t[0].replace('<ent0>', '<mask>').replace('<ent1>', '<mask>') for t in tB_probs]

            new_ret = []
            for temp in ret:
                temp = self.clean(temp.strip())
                if len(new_ret) < topk and temp not in new_ret:
                    new_ret.append(temp)

            return new_ret

    def explore_mask(self, tA, k, tokens, prob, required_token, probs):
        if required_token == 0:
            return [[tokens, prob, probs]]
        if required_token <= self.word_length:
            k = min(k, 2)
        ret = []
        generated_ids = self.tokenizer(tA, max_length=128, padding='longest', return_tensors='pt')  # ["input_ids"].cuda()
        for key in generated_ids.keys():
            generated_ids[key] = generated_ids[key].cuda()
        mask_index = torch.where(generated_ids["input_ids"][0] == self.tokenizer.mask_token_id)
        generated_ret = self.orion_instance_generator(**generated_ids)
        #logits = generated_ret.logits
        logits = generated_ret[0]
        softmax = F.softmax(logits, dim=-1)
        mask_word = softmax[0, mask_index[0][0], :] + self.stop_weight
        top_k = torch.topk(mask_word, k, dim=0)
        for i in range(top_k[1].size(0)):
            token_s = top_k[1][i]
            prob_s = top_k[0][i].item()
            token_this = self.tokenizer.decode([token_s]).strip()
            if token_this[0].isalpha() == False or len(token_this) <= 2:
                continue
            index_s = tA.index(self.tokenizer.mask_token)
            tAs = tA[:index_s] + token_this + tA[index_s + len(self.tokenizer.mask_token):]
            tokens_this = [t for t in tokens]
            tokens_this.append(token_this)
            probs_new = deepcopy(probs)
            probs_new.append(prob_s)
            ret.extend(self.explore_mask(tAs, 1, tokens_this, prob_s * prob, required_token - 1, probs_new))
        return ret

    def extract_words_for_tA_bart(self, tA, k=6, print_it=False):
        spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
        generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].cuda()
        generated_ret = self.orion_instance_generator.generate(generated_ids, num_beams=max(120, k),
                                                               #num_beam_groups=max(120, k),
                                                               max_length=generated_ids.size(1) + 15,
                                                               num_return_sequences=max(120, k),  #min_length=generated_ids.size(1),
                                                               #diversity_penalty=2.0,
                                                               #length_penalty= 0.8,
                                                               #early_stopping=True, bad_words_ids=bad_words_ids, no_repeat_ngram_size=2,
                                                               output_scores=True,
                                                               return_dict_in_generate=True)
        summary_ids = generated_ret['sequences']
        probs = F.softmax(generated_ret['sequences_scores'])
        txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids]
        ret = []

        for i, txt in enumerate(txts):
            if tA.endswith('.'):
                if txt.endswith('.'):
                    txt = txt[:-1].strip()
                txt += '.'
            word_imcomplete = False
            prob = probs[i].item()
            words_i = []

            start_index = 0
            for j in range(len(spans) - 1):
                span1 = spans[j]
                span2 = spans[j + 1]
                if (span1 in txt.lower()[start_index:]) and (span2 in txt.lower()[start_index:]):
                    index1 = txt.lower().index(span1, start_index) + len(span1)
                    if span2 == '':
                        if txt[-1] == '.':
                            index2 = len(txt) - 1
                        else:
                            index2 = len(txt)
                    else:
                        index2 = txt.lower().index(span2, start_index)

                    words_i.append(txt[index1:index2].strip())
                    start_index = index2
                    #if words_i[-1] == '':
                    #    word_imcomplete = True
                else:
                    word_imcomplete = True
            if word_imcomplete:
                # if print_it:
                #     print(txt + '\t' + tA + '\t' + '×')
                continue

            ret.append([words_i, prob])
        return sorted(ret, key=lambda x: x[1], reverse=True)[:k]

    def extract_words_for_tA(self, tA, k=6):
        word_mask_str = ' '.join([self.tokenizer.mask_token] * self.word_length)
        tA = tA.replace('<mask>', word_mask_str)
        mask_count = tA.count(self.tokenizer.mask_token)
        mask_probs = self.explore_mask(tA, k * 20, [], 1.0, mask_count, [])
        ret = []
        visited_mask_txt = {}
        for mask, prob, probs in mask_probs:
            mask_txt = ' '.join(mask).lower()
            if mask_txt in visited_mask_txt:
                continue
            visited_mask_txt[mask_txt] = 1
            words = []
            probs_words = []
            for i in range(0, mask_count, self.word_length):
                words.append(' '.join(mask[i: i + self.word_length]))
                prob_word = 1.0
                for j in range(i, i + self.word_length):
                    prob_word *= probs[j]
                probs_words.append(prob_word)
            ret.append([words, prob, probs_words])
        return sorted(ret, key=lambda x: x[1], reverse=True)[:k]

    def extract_templateBs_batch(self, words_prob, tA, k, print_it=False):
        words_prob_sorted = []
        for (words, probA, *_) in words_prob:
            tokenized_word = self.tokenizer(words[0])
            words_prob_sorted.append([words, probA, len(tokenized_word['input_ids'])])
        words_prob_sorted.sort(key=lambda x: x[2])

        batch_size = 8
        templates = []
        index_words = {}
        ret = {}
        num_beams = k
        for enum, (words, probA, *_) in enumerate(words_prob_sorted):
            template = construct_template(words, tA, self.if_then)
            templates.extend(template)
            for t in template:
                index_words[len(index_words)] = '\t'.join(words)
            # index_words[len(templates)-1] = '\t'.join(words)
            if (len(templates) == batch_size) or enum == len(words_prob_sorted) - 1 or (words_prob_sorted[enum + 1][2] != words_prob_sorted[enum][2]):
                generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].cuda()
                generated_ret = self.orion_hypothesis_generator.generate(generated_ids, num_beams=num_beams,
                                                                         num_beam_groups=num_beams,
                                                                         max_length=28,  #template_length+5,
                                                                         num_return_sequences=num_beams, min_length=3,
                                                                         diversity_penalty=1.0,
                                                                         early_stopping=True,
                                                                         #length_penalty = 0.1,
                                                                         bad_words_ids=self.bad_words_ids,
                                                                         #no_repeat_ngram_size=2,
                                                                         output_scores=True,
                                                                         return_dict_in_generate=True, decoder_ori_input_ids=generated_ids,
                                                                         top_p=0.95,
                                                                         )
                summary_ids = generated_ret['sequences'].reshape((len(templates), num_beams, -1))
                probs = F.softmax(generated_ret['sequences_scores'].reshape((len(templates), num_beams)), dim=1)
                for ii in range(summary_ids.size(0)):
                    txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in
                            summary_ids[ii]]
                    ii_template = []
                    words_ii = index_words[ii].split('\t')
                    for i, txt in enumerate(txts):
                        prob = probs[ii][i].item() * probA

                        txt = txt.lower()
                        txt = post_process_template(txt)

                        words_ii_matched = [word.lower() for word in words_ii]  #extract_similar_words(txt, words_ii)
                        if words_ii_matched is None:
                            prob = 0.0
                        else:
                            for j, word in enumerate(words_ii_matched):
                                if word not in txt:
                                    prob = 0.0
                                else:
                                    txt = txt.replace(word, '<ent{}>'.format(j), 1)

                        if txt.count(' ') + 1 <= 3:
                            continue

                        ii_template.append([txt, prob])
                    # if print_it:
                    #     print(index_words[ii]+'\t'+str(convert_for_print(ii_template)))
                    for template, prob in ii_template:
                        if template not in ret:
                            ret[template] = 0.0
                        ret[template] += prob
                templates.clear()
                index_words.clear()

        return ret

    def generate_rule(self, tA, k=10, print_it=False):
        tA = formalize_tA(tA)
        if 'bart' in str(self.orion_instance_generator.__class__).lower():
            words_prob = self.extract_words_for_tA_bart(tA, k, print_it=print_it)
            words_prob = filter_words(words_prob)[:k]
            # if print_it:
            #     print(convert_for_print(words_prob))
        else:
            words_prob = self.extract_words_for_tA(tA, k)
            words_prob = filter_words(words_prob)[:k]

        tB_prob = self.extract_templateBs_batch(words_prob, tA, k, print_it=print_it)

        ret = []
        for k1 in tB_prob:
            ret.append([k1, tB_prob[k1]])
        ret = sorted(ret, key=lambda x: x[1], reverse=True)[:k]
        if self.if_then:
            for i, temp in enumerate(ret):
                sentence = temp[0]
                if "then" in sentence:
                    sentence = sentence.split("then")[-1]
                else:
                    sentence = sentence.replace("if", "")
                ret[i][0] = sentence
        return ret


class CometInductor(object):
    def __init__(self):
        self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").cuda().eval()  # .half()
        self.tokenizer = AutoTokenizer.from_pretrained("adamlin/comet-atomic_2020_BART")
        self.task = "summarization"
        self.use_task_specific_params()
        self.decoder_start_token_id = None

    def drop_repeat(self, old_list):
        new_list = []
        for item in old_list:
            if item not in new_list:
                new_list.append(item)

        return new_list

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    def use_task_specific_params(self):
        """Update config with summarization specific params."""
        task_specific_params = self.model.config.task_specific_params

        if task_specific_params is not None:
            pars = task_specific_params.get(self.task, {})
            self.model.config.update(pars)

    def trim_batch(
        self, input_ids, pad_token_id, attention_mask=None,
    ):
        """Remove columns that are populated exclusively by pad_token_id"""
        keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
        if attention_mask is None:
            return input_ids[:, keep_column_mask]
        else:
            return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])

    def generate(self, inputs, k, topk):
        outputs = []
        words = ['PersonX', 'PersonY']
        for i, _ in enumerate(re.findall("<mask>", inputs)):
            index = inputs.index('<mask>')
            inputs = inputs[:index] + words[i] + inputs[index + len('<mask>'):]

        for relation in RELATIONS:
            inputs = "{} {} [GEN]".format(inputs[:-1], relation)
            gen = self.generate_(inputs, num_generate=10)
            switch = 0
            for output in gen[0]:
                output = output.strip()
                if re.search("PersonX|X", output) and re.search("PersonY|Y", output):
                    temp = re.sub("PersonX|X|PersonY|Y", "<mask>", output.strip())
                    if temp.endswith("."):
                        outputs.append(temp)
                    else:
                        outputs.append(temp + ".")
                    switch = 1
                    break

            if switch == 0:
                output = gen[0][0]
                temp = re.sub("PersonX|X|PersonY|Y", "<mask>", output.strip())
                if temp.endswith("."):
                    outputs.append(temp)
                else:
                    outputs.append(temp + ".")

        outputs = [output.replace('PersonX', '<mask>').replace('PersonY', '<mask>') for output in outputs]
        return outputs

    def generate_(
        self,
        queries,
        decode_method="beam",
        num_generate=5,
    ):

        with torch.no_grad():
            decs = []
            batch = self.tokenizer(queries, return_tensors="pt", padding="longest")
            input_ids, attention_mask = self.trim_batch(**batch, pad_token_id=self.tokenizer.pad_token_id)

            summaries = self.model.generate(
                input_ids=input_ids.cuda(),
                attention_mask=attention_mask.cuda(),
                decoder_start_token_id=self.decoder_start_token_id,
                num_beams=num_generate,
                num_return_sequences=num_generate,
            )

            dec = self.tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            decs.append(dec)

            return decs
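For orientation (editor's sketch, not part of the upload): BartInductor.generate takes a premise containing two <mask> entity slots, fills the slots with the instance generator, builds templates from each filling, prompts the hypothesis generator with them, and returns up to topk deduplicated hypotheses with the entity markers mapped back to <mask>. A minimal sketch, assuming a CUDA device and that the chenxran/orion-* checkpoints are reachable on the Hugging Face Hub:

    from inductor import BartInductor

    inductor = BartInductor()  # defaults: group beam search + continued-pretrained generators
    rules = inductor.generate('<mask> is the capital of <mask>.', k=10, topk=10)
    for rule in rules:
        print(rule)  # each rule is a hypothesis template with two <mask> slots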
src/__pycache__/bart_with_group_beam.cpython-38.pyc
ADDED
Binary file (17.7 kB)
src/bart_with_group_beam.py
ADDED
@@ -0,0 +1,608 @@
1 |
+
from transformers.models.bart import BartForConditionalGeneration
|
2 |
+
import torch
|
3 |
+
from transformers.generation_beam_search import BeamScorer
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
from collections import UserDict
|
6 |
+
from typing import Optional, Tuple, Union, Dict, Any
|
7 |
+
from transformers.generation_logits_process import LogitsProcessorList
|
8 |
+
from transformers.generation_utils import BeamSearchEncoderDecoderOutput,BeamSearchDecoderOnlyOutput
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from transformers.file_utils import ModelOutput
|
11 |
+
import torch.nn
|
12 |
+
|
13 |
+
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
|
14 |
+
|
15 |
+
|
16 |
+
class BartForConditionalGeneration_GroupBeam(BartForConditionalGeneration):
|
17 |
+
|
18 |
+
|
19 |
+
def beam_search(
|
20 |
+
self,
|
21 |
+
input_ids: torch.LongTensor,
|
22 |
+
beam_scorer: BeamScorer,
|
23 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
24 |
+
max_length: Optional[int] = None,
|
25 |
+
pad_token_id: Optional[int] = None,
|
26 |
+
eos_token_id: Optional[int] = None,
|
27 |
+
output_attentions: Optional[bool] = None,
|
28 |
+
output_hidden_states: Optional[bool] = None,
|
29 |
+
output_scores: Optional[bool] = None,
|
30 |
+
return_dict_in_generate: Optional[bool] = None,
|
31 |
+
**model_kwargs,
|
32 |
+
) -> Union[BeamSearchOutput, torch.LongTensor]:
|
33 |
+
r"""
|
34 |
+
Generates sequences for models with a language modeling head using beam search decoding.
|
35 |
+
|
36 |
+
Parameters:
|
37 |
+
|
38 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
39 |
+
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
|
40 |
+
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
|
41 |
+
beam_scorer (:obj:`BeamScorer`):
|
42 |
+
An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
|
43 |
+
constructed, stored and sorted during generation. For more information, the documentation of
|
44 |
+
:class:`~transformers.BeamScorer` should be read.
|
45 |
+
logits_processor (:obj:`LogitsProcessorList`, `optional`):
|
46 |
+
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
|
47 |
+
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
|
48 |
+
head applied at each generation step.
|
49 |
+
max_length (:obj:`int`, `optional`, defaults to 20):
|
50 |
+
The maximum length of the sequence to be generated.
|
51 |
+
pad_token_id (:obj:`int`, `optional`):
|
52 |
+
The id of the `padding` token.
|
53 |
+
eos_token_id (:obj:`int`, `optional`):
|
54 |
+
The id of the `end-of-sequence` token.
|
55 |
+
output_attentions (:obj:`bool`, `optional`, defaults to `False`):
|
56 |
+
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
57 |
+
returned tensors for more details.
|
58 |
+
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
|
59 |
+
Whether or not to return trhe hidden states of all layers. See ``hidden_states`` under returned tensors
|
60 |
+
for more details.
|
61 |
+
output_scores (:obj:`bool`, `optional`, defaults to `False`):
|
62 |
+
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
|
63 |
+
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
|
64 |
+
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
65 |
+
model_kwargs:
|
66 |
+
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
|
67 |
+
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
|
68 |
+
|
69 |
+
Return:
|
70 |
+
:class:`~transformers.generation_utilsBeamSearchDecoderOnlyOutput`,
|
71 |
+
:class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or obj:`torch.LongTensor`: A
|
72 |
+
:obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
|
73 |
+
:class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
|
74 |
+
``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
|
75 |
+
:class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
|
76 |
+
``model.config.is_encoder_decoder=True``.
|
77 |
+
|
78 |
+
|
79 |
+
Examples::
|
80 |
+
|
81 |
+
>>> from transformers import (
|
82 |
+
... AutoTokenizer,
|
83 |
+
... AutoModelForSeq2SeqLM,
|
84 |
+
... LogitsProcessorList,
|
85 |
+
... MinLengthLogitsProcessor,
|
86 |
+
... BeamSearchScorer,
|
87 |
+
... )
|
88 |
+
>>> import torch
|
89 |
+
|
90 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
91 |
+
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
92 |
+
|
93 |
+
>>> encoder_input_str = "translate English to German: How old are you?"
|
94 |
+
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
95 |
+
|
96 |
+
|
97 |
+
>>> # lets run beam search using 3 beams
|
98 |
+
>>> num_beams = 3
|
99 |
+
>>> # define decoder start token ids
|
100 |
+
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
|
101 |
+
>>> input_ids = input_ids * model.config.decoder_start_token_id
|
102 |
+
|
103 |
+
>>> # add encoder_outputs to model keyword arguments
|
104 |
+
>>> model_kwargs = {
|
105 |
+
... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
|
106 |
+
... }
|
107 |
+
|
108 |
+
>>> # instantiate beam scorer
|
109 |
+
>>> beam_scorer = BeamSearchScorer(
|
110 |
+
... batch_size=1,
|
111 |
+
... max_length=model.config.max_length,
|
112 |
+
... num_beams=num_beams,
|
113 |
+
... device=model.device,
|
114 |
+
... )
|
115 |
+
|
116 |
+
>>> # instantiate logits processors
|
117 |
+
>>> logits_processor = LogitsProcessorList([
|
118 |
+
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
|
119 |
+
... ])
|
120 |
+
|
121 |
+
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
|
122 |
+
|
123 |
+
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
124 |
+
"""
|
125 |
+
|
126 |
+
# init values
|
127 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
128 |
+
max_length = max_length if max_length is not None else self.config.max_length
|
129 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
130 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
131 |
+
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
132 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
133 |
+
output_hidden_states = (
|
134 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
135 |
+
)
|
136 |
+
return_dict_in_generate = (
|
137 |
+
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
138 |
+
)
|
139 |
+
|
140 |
+
# init attention / hidden states / scores tuples
|
141 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
142 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
143 |
+
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
144 |
+
|
145 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
146 |
+
if return_dict_in_generate and self.config.is_encoder_decoder:
|
147 |
+
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
148 |
+
encoder_hidden_states = (
|
149 |
+
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
150 |
+
)
|
151 |
+
|
152 |
+
batch_size = len(beam_scorer._beam_hyps)
|
153 |
+
num_beams = beam_scorer.num_beams
|
154 |
+
|
155 |
+
batch_beam_size, cur_len = input_ids.shape
|
156 |
+
|
157 |
+
assert (
|
158 |
+
num_beams * batch_size == batch_beam_size
|
159 |
+
), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
160 |
+
|
161 |
+
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
162 |
+
beam_scores[:, 1:] = -1e9
|
163 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
164 |
+
|
165 |
+
while cur_len < max_length:
|
166 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
167 |
+
|
168 |
+
outputs = self(
|
169 |
+
**model_inputs,
|
170 |
+
return_dict=True,
|
171 |
+
output_attentions=output_attentions,
|
172 |
+
output_hidden_states=output_hidden_states,
|
173 |
+
)
|
174 |
+
next_token_logits = outputs.logits[:, -1, :]
|
175 |
+
|
176 |
+
# adjust tokens for Bart, *e.g.*
|
177 |
+
next_token_logits = self.adjust_logits_during_generation(
|
178 |
+
next_token_logits, cur_len=cur_len, max_length=max_length
|
179 |
+
)
|
180 |
+
|
181 |
+
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
182 |
+
|
183 |
+
next_token_scores = logits_processor(input_ids, next_token_scores)
|
184 |
+
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
|
185 |
+
|
186 |
+
# Store scores, attentions and hidden_states when required
|
187 |
+
if return_dict_in_generate:
|
188 |
+
if output_scores:
|
189 |
+
scores += (next_token_scores,)
|
190 |
+
if output_attentions:
|
191 |
+
decoder_attentions += (
|
192 |
+
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            # m = torch.nn.LayerNorm(num_beams * vocab_size)
            # next_token_scores = m(next_token_scores)

            # --- modification: share token scores across the batch. Averaging over
            # dim 0 gives every example the same batch-level score for each candidate,
            # so all instances in the batch vote on a common relation template.
            next_token_scores_group = torch.sum(next_token_scores, dim=0, keepdim=True).expand(batch_size, -1) / batch_size

            for i in range(next_token_scores.size(0)):
                '''tmin = torch.min(next_token_scores_group[i])
                for j in range(1, len(model_kwargs['decoder_ori_input_ids'][i])):
                    next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin'''
                # restore the per-example scores for tokens that appear in this
                # example's original decoder input, so its own entities stay
                # example-specific while the shared predicate is voted on jointly
                for t in model_kwargs['decoder_ori_input_ids'][i]:
                    for j in range(num_beams):
                        # if t not in input_ids[i] or t == 1:
                        next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]

            next_token_scores, next_tokens = torch.topk(
                next_token_scores_group, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            '''next_token_scores_group = next_token_scores_group.expand(batch_size, -1)
            next_tokens_group = next_tokens_group.expand(batch_size, -1)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            for i in range(next_token_scores.size(0)):
                j1 = 0
                for j in range(next_token_scores.size(1)):
                    if next_tokens[i][j] not in model_kwargs['decoder_ori_input_ids'][i]:
                        next_tokens[i][j] = next_tokens_group[i][j1]
                        j1 += 1
            next_token_scores = next_token_scores_group

            del next_token_scores_group, next_tokens_group'''

            next_indices = next_tokens // vocab_size
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            cur_len = cur_len + 1

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if beam_scorer.is_done:
                break

        sequence_outputs = beam_scorer.finalize(
            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]
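
    # ------------------------------------------------------------------
    # Note on the score-sharing step above (an explanatory sketch, not
    # part of the upstream `transformers` API). With scores of shape
    # (batch_size, num_beams * vocab_size),
    #
    #     shared = scores.sum(dim=0, keepdim=True).expand(batch_size, -1) / batch_size
    #
    # gives every example the same batch-averaged score for every
    # candidate token, so the beams of all examples select a common
    # surface form; the loop over model_kwargs['decoder_ori_input_ids']
    # then writes each example's own scores back for the tokens of its
    # original input, letting entities stay example-specific. The same
    # trick is reused per beam group in `group_beam_search` below.
    # ------------------------------------------------------------------
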
    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ):
        r"""
        Generates sequences for models with a language modeling head using beam search decoding.

        Parameters:

            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
                :obj:`torch.LongTensor` of shape :obj:`(1,)`.
            beam_scorer (:obj:`BeamScorer`):
                A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
                constructed, stored and sorted during generation. For more information, the documentation of
                :class:`~transformers.BeamScorer` should be read.
            logits_processor (:obj:`LogitsProcessorList`, `optional`):
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
                head applied at each generation step.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            pad_token_id (:obj:`int`, `optional`):
                The id of the `padding` token.
            eos_token_id (:obj:`int`, `optional`):
                The id of the `end-of-sequence` token.
            output_attentions (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
                returned tensors for more details.
            output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
                for more details.
            output_scores (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
            return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
            model_kwargs:
                Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If
                model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

        Return:
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`,
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A
            :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
            ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
            ``model.config.is_encoder_decoder=True``.

        Examples::

            >>> from transformers import (
            ...    AutoTokenizer,
            ...    AutoModelForSeq2SeqLM,
            ...    LogitsProcessorList,
            ...    MinLengthLogitsProcessor,
            ...    HammingDiversityLogitsProcessor,
            ...    BeamSearchScorer,
            ... )
            >>> import torch

            >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
            >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

            >>> encoder_input_str = "translate English to German: How old are you?"
            >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

            >>> # lets run diverse beam search using 6 beams
            >>> num_beams = 6
            >>> # define decoder start token ids
            >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
            >>> input_ids = input_ids * model.config.decoder_start_token_id

            >>> # add encoder_outputs to model keyword arguments
            >>> model_kwargs = {
            ...     "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
            ... }

            >>> # instantiate beam scorer
            >>> beam_scorer = BeamSearchScorer(
            ...     batch_size=1,
            ...     max_length=model.config.max_length,
            ...     num_beams=num_beams,
            ...     device=model.device,
            ...     num_beam_groups=3
            ... )

            >>> # instantiate logits processors
            >>> logits_processor = LogitsProcessorList([
            ...     HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
            ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
            ... ])

            >>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

            >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
        """

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape

        assert (
            num_beams * batch_size == batch_beam_size
        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise the score of the first beam of each group with 0 and the rest with -1e9, so that beams in
        # the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while cur_len < max_length:
            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                if output_scores:
                    processed_score = torch.zeros_like(outputs.logits[:, -1, :]).half()  # .float()

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs.logits[batch_group_indices, -1, :]

                # adjust tokens for Bart, *e.g.*
                next_token_logits = self.adjust_logits_during_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length
                )

                next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)
                vocab_size = next_token_scores.shape[-1]

                next_token_scores = logits_processor(
                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
                    next_token_scores
                )

                if output_scores:
                    processed_score[batch_group_indices] = next_token_scores.half()  # .float()

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                # --- modification: share token scores across the batch, as in
                # `beam_search` above, then restore per-example scores for tokens
                # from each example's original decoder input.
                next_token_scores_group = torch.sum(
                    next_token_scores, dim=0, keepdim=True
                ).expand(batch_size, -1) / batch_size

                for i in range(next_token_scores.size(0)):
                    '''tmin = torch.min(next_token_scores_group[i])
                    for j in range(1, len(model_kwargs['decoder_ori_input_ids'][i])):
                        next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin'''
                    for t in model_kwargs['decoder_ori_input_ids'][i]:
                        for j in range(group_size):
                            # if t not in input_ids[i] or t == 1:
                            next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores_group, 2 * group_size, dim=1, largest=True, sorted=True
                )

                # --- end modification; the original top-k was:
                # next_token_scores, next_tokens = torch.topk(
                #     next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                # )

                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size

                # stateless
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
                )

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (processed_score,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1
            if beam_scorer.is_done:
                break

        sequence_outputs = beam_scorer.finalize(
            input_ids, beam_scores, next_tokens, next_indices,
            pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]
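

# Illustrative calling convention (an assumption inferred from the loops
# above, not verified against the rest of this repo): both `beam_search`
# and `group_beam_search` read model_kwargs['decoder_ori_input_ids'], a
# per-example sequence of token ids taken from that example's original
# decoder input, e.g. roughly:
#
#     ori_ids = tokenizer(input_texts, add_special_tokens=True).input_ids
#     outputs = model.generate(encoder_ids, num_beams=10,
#                              decoder_ori_input_ids=ori_ids)
#
# Tokens listed there keep their per-example scores, while every other
# vocabulary entry is scored by the batch-averaged distribution.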
src/distinct_n/.gitignore
ADDED
@@ -0,0 +1,58 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
state.py

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/
src/distinct_n/.idea/Distinct-N.iml
ADDED
@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$">
      <sourceFolder url="file://$MODULE_DIR$/distinct_n" isTestSource="false" />
      <excludeFolder url="file://$MODULE_DIR$/docs" />
    </content>
    <orderEntry type="jdk" jdkName="Python 3.6 (Metrics)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>
src/distinct_n/.idea/encodings.xml
ADDED
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Encoding" addBOMForNewFiles="with NO BOM" />
</project>
src/distinct_n/.idea/misc.xml
ADDED
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="JavaScriptSettings">
    <option name="languageLevel" value="ES6" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (tensorflow)" project-jdk-type="Python SDK" />
</project>
src/distinct_n/.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/Distinct-N.iml" filepath="$PROJECT_DIR$/.idea/Distinct-N.iml" />
    </modules>
  </component>
</project>
src/distinct_n/.idea/other.xml
ADDED
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PySciProjectComponent">
    <option name="PY_SCI_VIEW_SUGGESTED" value="true" />
  </component>
</project>
src/distinct_n/.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>
src/distinct_n/.idea/webResources.xml
ADDED
@@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="WebResourcesPaths">
    <contentEntries>
      <entry url="file://$PROJECT_DIR$">
        <entryData>
          <resourceRoots>
            <path value="file://$PROJECT_DIR$/testdata" />
          </resourceRoots>
        </entryData>
      </entry>
    </contentEntries>
  </component>
</project>
src/distinct_n/A Diversity-Promoting Objective Function for Neural Conversation Models.pdf
ADDED
Binary file (200 kB).
src/distinct_n/LICENSE.txt
ADDED
@@ -0,0 +1,202 @@

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
src/distinct_n/README.md
ADDED
@@ -0,0 +1,30 @@
# Distinct-N
Distinct-N, most notably distinct-1 and distinct-2, is a metric that measures the
diversity of a sentence. It focuses on the number of *distinct* n-grams in a sentence and thus
penalizes sentences with many repeated words. The metric requires no *reference* or *ground truth*
sentence and depends only on properties of the (system-generated) sentence itself.
It was proposed by Jiwei Li et al. in the paper *A Diversity-Promoting Objective Function for Neural Conversation Models*.

# Definitions
The original paper defined *Distinct-N* as follows:

> We report degree of diversity by calculating the number of distinct unigrams and bigrams in generated responses.
> The value is scaled by total number of generated tokens to avoid favoring long sentences.

which is exactly the property described above.

# Usage
```bash
$ python distinct_metric.py -n N_GRAMS PREDICTION
```

where `N_GRAMS` is the length of the token sequences to count as unique within one sentence, and
`PREDICTION` is the prediction or response your model generates, with one utterance (sentence) per line.
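
# Example
A minimal sketch of the ratio-based definition above (a hypothetical stand-alone helper, not part of this package):

```python
def distinct_n(tokens, n):
    """Number of distinct n-grams scaled by the total token count."""
    if len(tokens) < n:
        return 0.0
    unique_ngrams = {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}
    return len(unique_ngrams) / len(tokens)


print(distinct_n("the the the the the".split(), 1))  # 0.2
print(distinct_n("the cat sat on the".split(), 2))   # 0.8
```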

# Dependencies
`python>=3.6.1`

# References
[1] A Diversity-Promoting Objective Function for Neural Conversation Models
src/distinct_n/bin/distinct_metric.py
ADDED
@@ -0,0 +1,29 @@
import argparse
import logging

from distinct_n import distinct_n_sentence_level
from pathlib import Path
from agenda.metric_helper import write_score  # external helper, not bundled with this package

NAME = 'distinct_n'

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('hypothesis', help="predicted text file, one example per line")
    parser.add_argument('-n', dest='n_range', type=int, nargs='+', help="n to use as in distinct-N")
    parser.add_argument('--output_dir')
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logging.info('loading hypothesis file...')
    with open(args.hypothesis) as f:
        hypothesis = [sentence.split() for sentence in f.readlines()]

    output_dir = Path(args.output_dir)
    for n in args.n_range:
        write_score(
            name=NAME,
            output=output_dir.joinpath(f'{NAME}_{n}').with_suffix('.json'),
            params={'n': n},
            scores=[distinct_n_sentence_level(s, n) for s in hypothesis],
        )
src/distinct_n/bin/score.sh
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/env bash

HYPO=/home/cgsdfc/UbuntuDialogueCorpus/ResponseContextPairs/ModelPredictions/VHRED/First_VHRED_BeamSearch_5_GeneratedTestResponses.txt_First.txt
DIR=/home/cgsdfc/Result/Test

python bin/distinct_metric.py --output_dir $DIR $HYPO -n 3
src/distinct_n/distinct_n/metrics.py
ADDED
@@ -0,0 +1,33 @@
from src.distinct_n.distinct_n.utils import ngrams

__all__ = ["distinct_n_sentence_level", "distinct_n_corpus_level"]


def distinct_n_sentence_level(sentence, n):
    """
    Collect the distinct n-grams of a single sentence.
    :param sentence: a list of words.
    :param n: int, ngram order.
    :return: the list of distinct n-grams. (Modified from the original,
             which returned the ratio len(distinct_ngrams) / len(sentence).)
    """
    if len(sentence) == 0:
        return []  # Prevent a zero division downstream
    # distinct_ngrams = set(ngrams(sentence, n))
    # print(ngrams(sentence, n))
    return list(set(ngrams(sentence, n)))
    # return len(distinct_ngrams) / len(sentence)


def distinct_n_corpus_level(sentences, n):
    """
    Compute distinct-N of a list of sentences (the corpus): the number of
    n-grams that are unique across the whole corpus, scaled by the total
    number of tokens.
    :param sentences: a list of sentences, each a list of words.
    :param n: int, ngram order.
    :return: float, the corpus-level metric value.
    """
    distinct = []
    length = 0
    for sentence in sentences:
        length += len(sentence)
        distinct.extend(distinct_n_sentence_level(sentence, n))
    return len(set(distinct)) / length
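

# Minimal usage sketch (illustrative; assumes the module is imported from the
# repo root so the `src.distinct_n` package path above resolves). With the
# modification above, sentence-level calls return the n-gram lists that the
# corpus-level metric pools together.
if __name__ == "__main__":
    corpus = [
        "the cat sat on the mat".split(),
        "i do not know".split(),
    ]
    print(distinct_n_corpus_level(corpus, 2))  # 8 unique bigrams / 10 tokens = 0.8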
src/distinct_n/distinct_n/test.py
ADDED
@@ -0,0 +1,32 @@
import unittest

from distinct_n import distinct_n_sentence_level
from distinct_n import distinct_n_corpus_level


# NOTE: these tests assert the values of the original ratio-based metric;
# with distinct_n_sentence_level modified to return the n-gram list and the
# corpus-level metric pooling n-grams across sentences, the assertions below
# no longer hold as written.
class TestDistinctN(unittest.TestCase):
    def test_unigram(self):
        sentence = "the the the the the".split()
        self.assertAlmostEqual(
            distinct_n_sentence_level(sentence, 1), 0.2
        )
        sentence = "the the the the cat".split()
        self.assertAlmostEqual(
            distinct_n_sentence_level(sentence, 1), 0.4
        )

    def test_bigram(self):
        sentence = "the cat sat on the".split()
        self.assertAlmostEqual(
            distinct_n_sentence_level(sentence, 2), 0.8
        )

    def test_corpus_level(self):
        sentences = [
            'the cat sat on the mat'.split(),
            'mat the on sat cat the'.split(),
            'i do not know'.split(),
            'Sorry but i do not know'.split(),
        ]
        self.assertAlmostEqual(0.916666, distinct_n_corpus_level(sentences, 1), delta=1e-5)
        self.assertAlmostEqual(0.8125, distinct_n_corpus_level(sentences, 2), delta=1e-5)
src/distinct_n/distinct_n/utils.py
ADDED
@@ -0,0 +1,90 @@
"""
Copied from nltk.ngrams().
"""
from itertools import chain

__all__ = ["ngrams"]


def pad_sequence(sequence, n, pad_left=False, pad_right=False,
                 left_pad_symbol=None, right_pad_symbol=None):
    """
    Returns a padded sequence of items before ngram extraction.

    >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='<s>', right_pad_symbol='</s>'))
    ['<s>', 1, 2, 3, 4, 5, '</s>']
    >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='<s>'))
    ['<s>', 1, 2, 3, 4, 5]
    >>> list(pad_sequence([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='</s>'))
    [1, 2, 3, 4, 5, '</s>']

    :param sequence: the source data to be padded
    :type sequence: sequence or iter
    :param n: the degree of the ngrams
    :type n: int
    :param pad_left: whether the ngrams should be left-padded
    :type pad_left: bool
    :param pad_right: whether the ngrams should be right-padded
    :type pad_right: bool
    :param left_pad_symbol: the symbol to use for left padding (default is None)
    :type left_pad_symbol: any
    :param right_pad_symbol: the symbol to use for right padding (default is None)
    :type right_pad_symbol: any
    :rtype: sequence or iter
    """
    sequence = iter(sequence)
    if pad_left:
        sequence = chain((left_pad_symbol,) * (n - 1), sequence)
    if pad_right:
        sequence = chain(sequence, (right_pad_symbol,) * (n - 1))
    return sequence


def ngrams(sequence, n, pad_left=False, pad_right=False,
           left_pad_symbol=None, right_pad_symbol=None):
    """
    Return the ngrams generated from a sequence of items, as an iterator.
    For example:

    >>> from nltk.util import ngrams
    >>> list(ngrams([1,2,3,4,5], 3))
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    Wrap with list for a list version of this function. Set pad_left
    or pad_right to true in order to get additional ngrams:

    >>> list(ngrams([1,2,3,4,5], 2, pad_right=True))
    [(1, 2), (2, 3), (3, 4), (4, 5), (5, None)]
    >>> list(ngrams([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='</s>'))
    [(1, 2), (2, 3), (3, 4), (4, 5), (5, '</s>')]
    >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='<s>'))
    [('<s>', 1), (1, 2), (2, 3), (3, 4), (4, 5)]
    >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='<s>', right_pad_symbol='</s>'))
    [('<s>', 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, '</s>')]


    :param sequence: the source data to be converted into ngrams
    :type sequence: sequence or iter
    :param n: the degree of the ngrams
    :type n: int
    :param pad_left: whether the ngrams should be left-padded
    :type pad_left: bool
    :param pad_right: whether the ngrams should be right-padded
    :type pad_right: bool
    :param left_pad_symbol: the symbol to use for left padding (default is None)
    :type left_pad_symbol: any
    :param right_pad_symbol: the symbol to use for right padding (default is None)
    :type right_pad_symbol: any
    :rtype: sequence or iter
    """
    sequence = pad_sequence(sequence, n, pad_left, pad_right,
                            left_pad_symbol, right_pad_symbol)

    history = []
    while n > 1:
        history.append(next(sequence))
        n -= 1
    for item in sequence:
        history.append(item)
        yield tuple(history)
        del history[0]
src/distinct_n/setup.py
ADDED
@@ -0,0 +1,29 @@
from setuptools import setup

__version__ = '0.4.0'

setup(
    name='Distinct_N',
    version=__version__,
    description='Distinct-N metric that measures the degree of diversity of generated responses',
    url='https://github.com/neural-dialogue-metrics/Distinct-N.git',
    author='cgsdfc',
    author_email='cgsdfc@126.com',
    keywords=[
        'NL', 'CL', 'MT',
        'natural language processing',
        'computational linguistics',
        'machine translation',
    ],
    packages=['distinct_n'],
    scripts=['bin/distinct_metric.py'],
    classifiers=[
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache-v2',
        'Programming Language :: Python :: 3',
        'Topic :: Text Processing :: Linguistic',
    ],
    license='LICENSE.txt',
    long_description=open('README.md').read(),
    install_requires=[],
)
src/distinct_n/testdata/bigram.txt
ADDED
@@ -0,0 +1 @@
the cat sat on the mat
src/distinct_n/testdata/unigram.txt
ADDED
@@ -0,0 +1 @@
the the the the a
src/utils.py
ADDED
@@ -0,0 +1,133 @@
import math
from copy import deepcopy

from ngram import NGram


def post_process_template(tB):
    # make sure every template ends with a period
    if not tB.endswith('.'):
        tB += '.'
    return tB
    # return tB.split('.')[0] + '.'


def construct_template(words, templateA, if_then=False):
    # build <mask> templates around the extracted words
    if len(words) == 2:
        # templates = ['{} <mask> {}.'.format(words[0], words[1])]
        templates = [
            # '{} is <mask> {}.'.format(words[0], words[1]),
            '{} <mask> {}.'.format(words[0], words[1]),
        ]
    elif len(words) == 1:
        templates = [
            # '{} is <mask>.'.format(words[0]),
            '{} <mask>.'.format(words[0])]
    else:
        templates = []

    if if_then:
        # instantiate templateA's <mask> slots with the extracted words,
        # then prefix each candidate template with the filled-in premise
        for word in words:
            index = templateA.index('<mask>')
            templateA = templateA[:index] + word + templateA[index + len('<mask>'):]
        templates = ['If ' + templateA + ' then ' + template for template in templates]

    return templates


def filter_words(words_prob):
    word_count = {}
    token1_count = {}
    word2_count = {}
    ret = []
    for words, prob, *_ in words_prob:
        filter_this = False

        # penalize candidates containing a repeated token
        token_count = {}
        for word in words:
            for token in word.split(' '):
                if token in token_count:
                    filter_this = True
                token_count[token] = 1
        if filter_this:
            prob *= 0.5

        # drop candidates whose two words are identical
        if len(words) == 2 and words[0] == words[1]:
            continue

        # down-weight candidates that share the same first token
        token1 = words[0].split(' ')[0]
        if token1 not in token1_count:
            token1_count[token1] = 1
        else:
            token1_count[token1] += 1
        prob /= token1_count[token1]

        # down-weight words repeated across candidates
        for word in words:
            if word not in word_count:
                word_count[word] = 0
            word_count[word] += 1
            prob /= word_count[word]

        if len(words) == 2:
            if words[1] not in word2_count:
                word2_count[words[1]] = 0
            word2_count[words[1]] += 1
            prob /= word2_count[words[1]]

        ret.append([words, prob])
    return sorted(ret, key=lambda x: x[1], reverse=True)


def convert_for_print(arr):
    ret = deepcopy(arr)
    for i in range(len(ret)):
        ret[i][1] = round(ret[i][1], 7)
        if len(ret[i]) == 3:
            for j in range(len(ret[i][2])):
                ret[i][2][j] = round(ret[i][2][j], 7)
    return ret


def formalize_tA(tA):
    tA = tA.strip()
    if tA.endswith('.'):
        tA = tA[:-1].strip() + '.'
    else:
        tA += '.'
    tA = tA.replace(' ,', ',')
    tA = tA.replace(" '", "'")
    return tA


ngram_n = 3


def extract_similar_words(txt, words):
    # fuzzy-match each word against character n-grams of the text
    max_word_length = 0
    for word in words:
        if len(word) > max_word_length:
            max_word_length = len(word)

    txt_ngrams = []
    for i in range(len(txt)):
        for j in range(i + ngram_n, min(len(txt), i + max_word_length + 5)):
            txt_ngrams.append(txt[i:j].lower())
    n = NGram(txt_ngrams, key=lambda x: x.lower(), N=ngram_n)
    ret = []
    for word in words:
        matched_word = n.find(word.lower(), 0.5)
        if matched_word is None:
            return None
        ret.append(matched_word)
    return ret


def extract_words(txt, words):
    # exact containment check; returns None if any word is missing
    for word in words:
        if word not in txt:
            return None
    return [word.lower() for word in words]
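

# Minimal usage sketch (illustrative; assumes the `ngram` package required by
# the import at the top is installed):
if __name__ == "__main__":
    premise = "<mask> was born in <mask>"
    print(construct_template(["Monet", "Paris"], premise, if_then=True))
    # -> ['If Monet was born in Paris then Monet <mask> Paris.']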