import argparse
import logging
import os
import re
from ast import literal_eval
from datetime import datetime

import numpy as np
import torch
from nltk import bleu, meteor
from rouge_score.rouge_scorer import RougeScorer
from tqdm import tqdm

from src.distinct_n.distinct_n.metrics import distinct_n_corpus_level as distinct_n
from inductor import BartInductor, CometInductor

FILES = {
    'amie-yago2': 'data/RE-datasets/AMIE-yago2.txt',
    'rules-yago2': 'data/RE-datasets/RuLES-yago2.txt',
    'openrule155': 'data/OpenRule155.txt',
    'fewrel': 'data/RE/fewrel-5.txt',
    'semeval': 'data/RE/semeval-5.txt',
    'TREx': 'data/RE/trex-5.txt',
    'nyt10': 'data/RE/nyt10-5.txt',
    'google-re': 'data/RE/google-re-5.txt',
    'wiki80': 'data/RE/wiki80-5.txt',
}

if not os.path.exists('logs/'):
    os.mkdir('logs/')

logging.basicConfig(
    filename='logs/evaluation-{}.log'.format(str(datetime.now())),
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def print_config(config):
    config = vars(config)
    logger.info("**************** MODEL CONFIGURATION ****************")
    for key in sorted(config.keys()):
        val = config[key]
        keystr = "{}".format(key) + (" " * (25 - len(key)))
        logger.info("{} --> {}".format(keystr, val))
    logger.info("**************** MODEL CONFIGURATION ****************")


scorer = RougeScorer(['rougeL'], use_stemmer=True)


def rouge(references, hypothesis):
    # Best ROUGE-L F-measure over all references; the Score namedtuple is
    # (precision, recall, fmeasure), so index 2 is the F-measure.
    scores = []
    for reference in references:
        scores.append(scorer.score(reference, hypothesis)['rougeL'][2])
    return max(scores)


class RelationExtractionEvaluator(object):
    def __init__(self, args):
        self.args = args
        if self.args.inductor == 'rule':
            self.inductor = BartInductor(
                group_beam=self.args.group_beam,
                continue_pretrain_instance_generator=self.args.mlm_training,
                continue_pretrain_hypo_generator=self.args.bart_training,
                if_then=self.args.if_then,
            )
        elif self.args.inductor == 'comet':
            self.inductor = CometInductor()

    def clean(self, text):
        # A well-formed hypothesis contains exactly two <mask> slots; if the
        # model keeps generating past the closing period, cut the text off
        # right after the second slot.
        segments = text.split('<mask>')
        if len(segments) == 3 and segments[2].startswith('.'):
            return '<mask>'.join(segments[:2]) + '<mask>.'
        else:
            return text
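    # Illustration of `clean` (example strings invented for documentation,
    # not taken from any dataset): a generation that runs on past the
    # closing period is truncated to the rule pattern itself, e.g.
    #   clean("<mask> is the capital of <mask>. it is a city")
    #     -> "<mask> is the capital of <mask>."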
    def clean_references(self, texts):
        # Normalize references that end with " ." to end with "." instead.
        for i, text in enumerate(texts):
            if text.endswith(" ."):
                texts[i] = text.replace(" .", ".")
        return texts

    def self_bleu(self, hypothesis):
        # Mean BLEU-2 of each hypothesis against the remaining ones; lower
        # means a more diverse set. Note the hypotheses are passed as raw
        # strings, so nltk's n-grams are computed over characters here.
        bleus = []
        for i in range(len(hypothesis)):
            bleus.append(bleu(
                hypothesis[:i] + hypothesis[i + 1:],
                hypothesis[i],
                weights=(0.5, 0.5)))
        return np.mean(bleus)

    def evaluate(self, task):
        with torch.no_grad():
            self.metrics = {
                "bleu-4": [],
                "bleu-3": [],
                "bleu-2": [],
                "bleu-1": [],
                "METEOR": [],
                "ROUGE-L": [],
                "self-BLEU-2": [],
            }
            with open(FILES[task], 'r', encoding='utf-8') as file:
                data = file.readlines()
            with tqdm(total=len(data)) as pbar:
                for row in data:
                    pbar.update(1)
                    row = row.strip().split('\t')
                    inputs, head, tail, relations = row[0], row[1], row[2], row[3]
                    inputs = inputs.strip()

                    if relations.startswith('[') and relations.endswith(']'):
                        # The field is a Python list literal of relation
                        # patterns; literal_eval is a safe drop-in for eval.
                        inputs = re.sub("<A>|<B>", "", inputs)
                        references = [
                            relation.replace('<A>', '').replace('<B>', '').lower().strip()
                            for relation in literal_eval(relations)
                        ]
                    else:
                        references = [relations.replace('[X]', '').replace('[Y]', '').lower().strip()]
                    references = self.clean_references(references)

                    hypothesis = self.inductor.generate(inputs, k=10, topk=10)

                    logger.info("***********Input************")
                    logger.info(inputs)
                    logger.info("*********Hypothesis*********")
                    for i, hypo in enumerate(hypothesis):
                        hypothesis[i] = self.clean(hypo.lower().strip())
                        logger.info(hypo)
                    logger.info("****************************")
                    logger.info("*********References*********")
                    logger.info(references)
                    logger.info("****************************")

                    if len(hypothesis) == 0:
                        for k in self.metrics.keys():
                            if k != 'self-BLEU-2':
                                self.metrics[k].append(0.)
                    else:
                        for hypo in hypothesis:
                            try:
                                self.metrics['bleu-4'].append(
                                    bleu(
                                        [reference.split() for reference in references],
                                        hypo.split(),
                                        weights=(0.25, 0.25, 0.25, 0.25),
                                    )
                                )
                            except Exception:
                                logger.warning("Skip bleu-4 in example: {}".format(inputs))

                            try:
                                self.metrics['bleu-3'].append(
                                    bleu(
                                        [reference.split() for reference in references],
                                        hypo.split(),
                                        weights=(1 / 3,) * 3,
                                    )
                                )
                            except Exception:
                                logger.warning("Skip bleu-3 in example: {}".format(inputs))

                            try:
                                self.metrics['bleu-2'].append(
                                    bleu(
                                        [reference.split() for reference in references],
                                        hypo.split(),
                                        weights=(0.5, 0.5),
                                    )
                                )
                            except Exception:
                                logger.warning("Skip bleu-2 in example: {}".format(inputs))

                            try:
                                self.metrics['bleu-1'].append(
                                    bleu(
                                        [reference.split() for reference in references],
                                        hypo.split(),
                                        weights=(1.0,),
                                    )
                                )
                            except Exception:
                                logger.warning("Skip bleu-1 in example: {}".format(inputs))

                            try:
                                # Some nltk versions require pre-tokenized
                                # input here; failures are logged and skipped.
                                self.metrics['METEOR'].append(meteor(references, hypo))
                            except Exception:
                                logger.warning("Skip METEOR in example: {}".format(inputs))

                            try:
                                self.metrics['ROUGE-L'].append(rouge(references, hypo))
                            except Exception:
                                logger.warning("Skip ROUGE-L in example: {}".format(inputs))

                        try:
                            # Self-BLEU is a property of the whole hypothesis
                            # set, so it is recorded once per example.
                            self.metrics['self-BLEU-2'].append(self.self_bleu(hypothesis))
                        except Exception:
                            logger.warning("Skip self-BLEU-2 in example: {}.".format(inputs))

        self.print(task, self.metrics)

    def print(self, task, metrics):
        logger.info("Task: {}".format(str(task)))
        for k, v in metrics.items():
            logger.info("{}: {}".format(k, str(np.mean(v))))
        logger.info("*******************************************************")
        logger.info("*******************************************************")
        logger.info("*******************************************************")
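
# Expected dataset row format (tab-separated; inferred from the parsing in
# `evaluate`, not from any dataset documentation):
#   <input sentence>\t<head entity>\t<tail entity>\t<relations>
# where <relations> is either a Python list literal of "... <A> ... <B> ..."
# relation patterns (e.g. OpenRule155) or a single "[X] ... [Y]" template.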
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--inductor", type=str, default='rule')
    # argparse's `type=bool` converts any non-empty string (including
    # "False") to True, so the boolean options are plain store_true flags.
    parser.add_argument("--group_beam", action='store_true')
    parser.add_argument("--mlm_training", action='store_true')
    parser.add_argument("--bart_training", action='store_true')
    parser.add_argument("--if_then", action='store_true')
    parser.add_argument("--task", type=str, default='openrule155')
    args = parser.parse_args()
    print_config(args)

    evaluator = RelationExtractionEvaluator(args)
    evaluator.evaluate(args.task)
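
# Example invocation (script name assumed, matching the log-file prefix above):
#   python evaluation.py --inductor rule --task openrule155 --group_beam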