import json
import os
import re
from collections import Counter

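# Answer-file format, inferred from Tag.__init__ below: each line holds one tag as
# five tab-separated fields -- file_id, type, start, end, text. An illustrative
# (made-up) line:
#   file_0001\tDATE\t120\t130\t2023-05-17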
class Tag:
    """One answer tag parsed from a single tab-separated line."""

    def __init__(self, txt_line: str):
        try:
            sep = txt_line.strip().split('\t')
            self.file_id = sep[0]
            self.type = sep[1]
            self.start = sep[2]
            self.end = sep[3]
            self.text = sep[4]
        except IndexError as exc:
            raise ValueError('The input line is not in the expected tab-separated format.') from exc

    def get_type(self):
        return self.type

    def get_file_id(self):
        return self.file_id

    def __eq__(self, other: 'Tag'):
        # Tags match when file id, type, and character span agree; the surface text is ignored.
        return (self.file_id == other.file_id
                and self.type == other.type
                and self.start == other.start
                and self.end == other.end)

    def __repr__(self):
        return f'<{self.__class__.__name__} {self.file_id:10} {self.type:10} s:{self.start:5} e:{self.end:5} {self.text}>\n'

    def __hash__(self):
        return hash((self.file_id, self.type, self.start, self.end))

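# Note: because __eq__ and __hash__ ignore the text field, tags with the same file,
# type, and span are treated as duplicates when the lines are collected into sets below.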
class Evaluation_answer_txt:
    """Score a predicted answer file against a gold answer file with per-type and averaged precision/recall/F1."""

    def __init__(self, gold_answer, pred_answer):
        # gold_answer: path to the gold answer file.
        # pred_answer: path to the predicted answer file, or a binary file-like object.
        self.gold_answer = gold_answer
        self.pred_answer = pred_answer

        self.gold_set = set()
        self.pred_set = set()

        self.type_set = set()
        self.gold_label_counter = Counter()

        self.result_score = {}

    def _lines_to_tag_set(self, lines, set_type):
        # Parse each line into a Tag; report and skip malformed lines.
        tags = []
        for i, line in enumerate(lines):
            try:
                tags.append(Tag(line))
            except ValueError:
                print(f'Error at {set_type} answer line: {i+1}, {line}')
        return set(tags)

    def _set_filter(self, tag_set, tag_type):
        # Keep only the tags of the given type.
        return {tag for tag in tag_set if tag.get_type() == tag_type}

    def _division(self, a, b):
        # Division that returns 0.0 instead of raising on a zero denominator.
        try:
            return a / b
        except ZeroDivisionError:
            return 0.0

    def _f1_score(self, TP=None, FP=None, FN=None):
        if TP is None or FP is None or FN is None:
            raise ValueError('TP, FP, and FN must all be given.')

        precision = self._division(TP, TP + FP)
        recall = self._division(TP, TP + FN)
        f1 = self._division(2 * precision * recall, precision + recall)

        return {'precision': precision, 'recall': recall, 'f1': f1}

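    # Averaging scheme used by eval(), with illustrative counts: for types A
    # (TP=8, FP=2, FN=0 -> P=0.8, R=1.0) and B (TP=1, FP=0, FN=9 -> P=1.0, R=0.1),
    # the micro average pools the counts first (P = 9/11 ~ 0.82, R = 9/18 = 0.5),
    # while the macro average takes the unweighted mean of the per-type scores
    # (P = 0.9, R = 0.55); each F1 is the harmonic mean of the resulting P and R.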
    def eval(self, ignore_no_gold_tag_file=True):
        with open(self.gold_answer, 'r', encoding='utf-8') as f:
            gold_lines = f.readlines()

        # The prediction source may be a file path or an already-open binary file object.
        if isinstance(self.pred_answer, str):
            with open(self.pred_answer, 'r', encoding='utf-8') as f:
                pred_lines = f.readlines()
        else:
            pred_lines = self.pred_answer.readlines()
            pred_lines = [line.decode('utf-8') for line in pred_lines]

        self.gold_set = self._lines_to_tag_set(gold_lines, 'gold')
        self.pred_set = self._lines_to_tag_set(pred_lines, 'pred')

        if ignore_no_gold_tag_file:
            # Drop predictions for files that contain no gold tags at all.
            gold_files = {tag.get_file_id() for tag in self.gold_set}
            self.pred_set = {tag for tag in self.pred_set if tag.get_file_id() in gold_files}

        for tag in self.gold_set:
            self.type_set.add(tag.get_type())
            self.gold_label_counter[tag.get_type()] += 1
        for tag in self.pred_set:
            self.type_set.add(tag.get_type())

        # Exact matches on (file_id, type, start, end) count as true positives.
        TP_set = self.gold_set & self.pred_set
        FP_set = self.pred_set - self.gold_set
        FN_set = self.gold_set - self.pred_set

        # Per-type scores.
        for label in self.type_set:
            filter_TP = self._set_filter(TP_set, label)
            filter_FP = self._set_filter(FP_set, label)
            filter_FN = self._set_filter(FN_set, label)
            self.result_score[label] = self._f1_score(len(filter_TP), len(filter_FP), len(filter_FN))

        # Micro average over all types.
        self.result_score['MICRO_AVERAGE'] = self._f1_score(len(TP_set), len(FP_set), len(FN_set))

        # Macro average: unweighted mean of per-type precision and recall.
        precision_sum = 0
        recall_sum = 0
        for label in self.type_set:
            precision_sum += self.result_score[label]['precision']
            recall_sum += self.result_score[label]['recall']

        precision = self._division(precision_sum, len(self.type_set))
        recall = self._division(recall_sum, len(self.type_set))
        f1 = self._division(2 * precision * recall, precision + recall)
        self.result_score['MACRO_AVERAGE'] = {'precision': precision, 'recall': recall, 'f1': f1}

        # Attach the number of gold tags as support.
        for label in self.type_set:
            self.result_score[label]['support'] = self.gold_label_counter[label]
        self.result_score['MICRO_AVERAGE']['support'] = len(self.gold_set)
        self.result_score['MACRO_AVERAGE']['support'] = len(self.gold_set)

        return self.result_score

if __name__=="__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gold_path = 'dataset/Setting3_test_answer.txt'
|
|
pred_path = '.output/[meta-llama@Llama-2-7b-hf][Setting3][icl]answer.txt'
|
|
|
|
|
|
eval = Evaluation_answer_txt(gold_path, pred_path)
|
|
res = eval.eval()
|
|
print(res) |
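    # A minimal sketch of the alternative input form, assuming the gold file above
    # exists: eval() also accepts a binary file-like object for the predictions,
    # whose lines it decodes as UTF-8. The tag values here are illustrative only.
    import io
    pred_stream = io.BytesIO(b'file_0001\tDATE\t120\t130\t2023-05-17\n')
    stream_res = Evaluation_answer_txt(gold_path, pred_stream).eval()
    print(stream_res['MICRO_AVERAGE'])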