andreslu committed on
Commit d34098d
1 Parent(s): f34d241

Delete evaluation.py

Files changed (1)
  1. evaluation.py +0 -254
evaluation.py DELETED
@@ -1,254 +0,0 @@
- import argparse
- import ast
- import logging
- import re
- from datetime import datetime
- import os
-
- import numpy as np
- import torch
- from nltk import bleu, meteor
- from rouge_score.rouge_scorer import RougeScorer
- from tqdm import tqdm
- from src.distinct_n.distinct_n.metrics import distinct_n_corpus_level as distinct_n
-
- from inductor import BartInductor, CometInductor
-
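- # Maps each evaluation task name to its data file.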
- FILES = {
-     'amie-yago2': 'data/RE-datasets/AMIE-yago2.txt',
-     'rules-yago2': 'data/RE-datasets/RuLES-yago2.txt',
-     'openrule155': 'data/OpenRule155.txt',
-     'fewrel': 'data/RE/fewrel-5.txt',
-     'semeval': 'data/RE/semeval-5.txt',
-     'TREx': 'data/RE/trex-5.txt',
-     'nyt10': 'data/RE/nyt10-5.txt',
-     'google-re': 'data/RE/google-re-5.txt',
-     'wiki80': 'data/RE/wiki80-5.txt',
- }
-
-
- if not os.path.exists('logs/'):
-     os.mkdir('logs/')
-
- logging.basicConfig(
-     filename='logs/evaluation-{}.log'.format(str(datetime.now())),
-     format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-     datefmt='%m/%d/%Y %H:%M:%S',
-     level=logging.INFO)
- logger = logging.getLogger(__name__)
-
-
- def print_config(config):
-     config = vars(config)
-     logger.info("**************** MODEL CONFIGURATION ****************")
-     for key in sorted(config.keys()):
-         val = config[key]
-         keystr = "{}".format(key) + (" " * (25 - len(key)))
-         logger.info("{} --> {}".format(keystr, val))
-     logger.info("**************** MODEL CONFIGURATION ****************")
-
- scorer = RougeScorer(['rougeL'], use_stemmer=True)
-
- def rouge(references, hypothesis):
-     # Score the hypothesis against every reference; keep the best ROUGE-L F1.
-     scores = []
-     for reference in references:
-         scores.append(scorer.score(reference, hypothesis)['rougeL'].fmeasure)
-     return max(scores)
-
-
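- # Evaluates a rule inductor (BART-based or COMET) on relation-extraction data.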
- class RelationExtractionEvaluator(object):
-     def __init__(self, args):
-         self.args = args
-         if self.args.inductor == 'rule':
-             self.inductor = BartInductor(
-                 group_beam=self.args.group_beam,
-                 continue_pretrain_instance_generator=self.args.mlm_training,
-                 continue_pretrain_hypo_generator=self.args.bart_training,
-                 if_then=self.args.if_then,
-             )
-         elif self.args.inductor == 'comet':
-             self.inductor = CometInductor()
-
-     def clean(self, text):
-         # Truncate anything generated after the second <mask> placeholder.
-         segments = text.split('<mask>')
-         if len(segments) == 3 and segments[2].startswith('.'):
-             return '<mask>'.join(segments[:2]) + '<mask>.'
-         return text
-
-     def clean_references(self, texts):
-         for i, text in enumerate(texts):
-             if text.endswith(" ."):
-                 texts[i] = text.replace(" .", ".")
-         return texts
-
-     def self_bleu(self, hypothesis):
-         # BLEU-2 of each hypothesis against the others; lower self-BLEU
-         # indicates more diverse generations.
-         bleus = []
-         for i in range(len(hypothesis)):
-             bleus.append(bleu(
-                 hypothesis[:i] + hypothesis[i + 1:],
-                 hypothesis[i],
-                 weights=(0.5, 0.5)))
-         return np.mean(bleus)
-
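-     # Generation-and-scoring loop: produce up to 10 hypotheses per input and
-     # score each against the gold relations with BLEU, METEOR, and ROUGE-L.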
-     def evaluate(self, task):
-         with torch.no_grad():
-             self.metrics = {
-                 "bleu-4": [],
-                 "bleu-3": [],
-                 "bleu-2": [],
-                 "bleu-1": [],
-                 "METEOR": [],
-                 "ROUGE-L": [],
-                 "self-BLEU-2": [],
-             }
-             with open(FILES[task], 'r', encoding='utf-8') as file:
-                 data = file.readlines()
-                 with tqdm(total=len(data)) as pbar:
-                     for row in data:
-                         pbar.update(1)
-                         # Row format: input sentence \t head \t tail \t gold relations.
-                         row = row.strip().split('\t')
-                         inputs, head, tail, relations = row[0], row[1], row[2], row[3]
-                         inputs = inputs.strip()
-
-                         if relations.startswith('[') and relations.endswith(']'):
-                             inputs = re.sub("<A>|<B>", "<mask>", inputs)
-                             # The gold relations are stored as a list literal;
-                             # ast.literal_eval is a safer parser here than eval.
-                             references = [relation.replace('<A>', '<mask>').replace('<B>', '<mask>').lower().strip()
-                                           for relation in ast.literal_eval(relations)]
-                         else:
-                             references = [relations.replace('[X]', '<mask>').replace('[Y]', '<mask>').lower().strip()]
-                         references = self.clean_references(references)
-                         hypothesis = self.inductor.generate(inputs, k=10, topk=10)
-
-                         logger.info("***********Input************")
-                         logger.info(inputs)
-                         logger.info("*********Hypothesis*********")
-                         for i, hypo in enumerate(hypothesis):
-                             hypothesis[i] = self.clean(hypo.lower().strip())
-                             logger.info(hypo)
-
-                         logger.info("****************************")
-                         logger.info("*********References*********")
-                         logger.info(references)
-                         logger.info("****************************")
-
-                         if len(hypothesis) == 0:
-                             for k in self.metrics.keys():
-                                 if k != 'self-BLEU-2':
-                                     self.metrics[k].append(0.)
-                         else:
-                             for hypo in hypothesis:
-                                 try:
-                                     self.metrics['bleu-4'].append(
-                                         bleu(
-                                             [reference.split() for reference in references],
-                                             hypo.split(),
-                                             weights=(0.25, 0.25, 0.25, 0.25)
-                                         )
-                                     )
-                                 except Exception:
-                                     logger.warning("Skip bleu-4 in example: {}".format(inputs))
-
-                                 try:
-                                     self.metrics['bleu-3'].append(
-                                         bleu(
-                                             [reference.split() for reference in references],
-                                             hypo.split(),
-                                             weights=(1 / 3, ) * 3
-                                         )
-                                     )
-                                 except Exception:
-                                     logger.warning("Skip bleu-3 in example: {}".format(inputs))
-
-                                 try:
-                                     self.metrics['bleu-2'].append(
-                                         bleu(
-                                             [reference.split() for reference in references],
-                                             hypo.split(),
-                                             weights=(0.5, 0.5)
-                                         )
-                                     )
-                                 except Exception:
-                                     logger.warning("Skip bleu-2 in example: {}".format(inputs))
-
-                                 try:
-                                     self.metrics['bleu-1'].append(
-                                         bleu(
-                                             [reference.split() for reference in references],
-                                             hypo.split(),
-                                             weights=(1.0, )
-                                         )
-                                     )
-                                 except Exception:
-                                     logger.warning("Skip bleu-1 in example: {}".format(inputs))
-
-                                 try:
-                                     self.metrics['METEOR'].append(
-                                         meteor(
-                                             references,
-                                             hypo,
-                                         )
-                                     )
-                                 except Exception:
-                                     logger.warning("Skip METEOR in example: {}".format(inputs))
-
-                                 try:
-                                     self.metrics['ROUGE-L'].append(
-                                         rouge(
-                                             references,
-                                             hypo,
-                                         )
-                                     )
-                                 except Exception:
-                                     logger.warning("Skip ROUGE-L in example: {}".format(inputs))
-
-                             # self-BLEU is a property of the whole hypothesis set,
-                             # so it is computed once per example, not once per hypo.
-                             try:
-                                 self.metrics['self-BLEU-2'].append(
-                                     self.self_bleu(
-                                         hypothesis,
-                                     )
-                                 )
-                             except Exception:
-                                 logger.warning("Skip self-bleu-2 in example: {}.".format(inputs))
-
-             self.print(task, self.metrics)
-
-     def print(self, task, metrics):
-         logger.info("Task: {}".format(str(task)))
-         for k, v in metrics.items():
-             # Report the corpus-level mean of each metric.
-             logger.info("{}: {}".format(k, str(np.mean(v))))
-
-         logger.info("*******************************************************")
-         logger.info("*******************************************************")
-         logger.info("*******************************************************")
-
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--inductor", type=str, default='rule')
-     # argparse's type=bool treats any non-empty string (even "False") as True,
-     # so the boolean switches are declared as store_true flags instead.
-     parser.add_argument("--group_beam", action='store_true')
-     parser.add_argument("--mlm_training", action='store_true')
-     parser.add_argument("--bart_training", action='store_true')
-     parser.add_argument("--if_then", action='store_true')
-     parser.add_argument("--task", type=str, default='openrule155')
-
-     args = parser.parse_args()
-
-     print_config(args)
-     evaluator = RelationExtractionEvaluator(args)
-     evaluator.evaluate(args.task)