# -*- coding:utf-8 -*- """ @Author : Bao @Date : 2020/8/24 @Desc : @Last modified by : Bao @Last modified date : 2020/9/1 """ import os import json import numpy as np from collections import defaultdict import tensorflow as tf from sklearn.metrics import precision_recall_fscore_support try: from .scorer import fever_score except: from scorer import fever_score prefix = os.environ['PJ_HOME'] class FeverScorer: def __init__(self): self.id2label = {2: 'SUPPORTS', 0: 'REFUTES', 1: 'NOT ENOUGH INFO'} self.label2id = {value: key for key, value in self.id2label.items()} def get_scores(self, predicted_file, actual_file=f'{prefix}/data/fever/shared_task_dev.jsonl'): id2results = defaultdict(dict) with tf.io.gfile.GFile(predicted_file) as f: for line in f: js = json.loads(line) guid = js['id'] id2results[guid] = js with tf.io.gfile.GFile(actual_file) as fin: for line in fin: line = json.loads(line) guid = line['id'] evidence = line['evidence'] label = line['label'] id2results[guid]['evidence'] = evidence id2results[guid]['label'] = label results = self.label_score(list(id2results.values())) score, accuracy, precision, recall, f1 = fever_score(list(id2results.values())) results.update({ 'Evidence Precision': precision, 'Evidence Recall': recall, 'Evidence F1': f1, 'FEVER Score': score, 'Label Accuracy': accuracy }) return results def label_score(self, results): truth = np.array([v['label'] for v in results]) prediction = np.array([v['predicted_label'] for v in results]) labels = list(self.label2id.keys()) results = {} p, r, f, _ = precision_recall_fscore_support(truth, prediction, labels=labels) for i, label in enumerate(self.label2id.keys()): results['{} Precision'.format(label)] = p[i] results['{} Recall'.format(label)] = r[i] results['{} F1'.format(label)] = f[i] return results if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument("--predicted_file", '-i', type=str) args = parser.parse_args() scorer = FeverScorer() results = scorer.get_scores(args.predicted_file) print(json.dumps(results, ensure_ascii=False, indent=4))