'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-17 19:39:22
Description: Metric calculation class
'''
from collections import Counter
from typing import List, Dict

import numpy as np
from sklearn.metrics import f1_score

from common.utils import InputData, OutputData


class Evaluator(object):
    """Evaluation metric functions library class

    supported metric:
    - slot_f1
    - intent_acc
    - exactly_match_accuracy
    - intent_f1 (default "macro_intent_f1")
    - macro_intent_f1
    - micro_intent_f1
    """

    @staticmethod
    def exactly_match_accuracy(pred_slot: List[List[str or int]],
                               real_slot: List[List[str or int]],
                               pred_intent: List[List[str or int] or str or int],
                               real_intent: List[List[str or int] or str or int]) -> float:
        """Compute the accuracy based on the whole prediction of a given sentence, including slot and intent.
        (both support str or int index as the representation of slot and intent)

        Args:
            pred_slot (List[List[str or int]]): predicted sequence of slot list.
            real_slot (List[List[str or int]]): golden sequence of slot list.
            pred_intent (List[List[str or int] or str or int]): predicted intent list / predicted multi intent list.
            real_intent (List[List[str or int] or str or int]): golden intent list / golden multi intent list.

        Returns:
            float: exactly match accuracy score
        """
        total_count, correct_count = 0.0, 0.0
        for p_slot, r_slot, p_intent, r_intent in zip(pred_slot, real_slot, pred_intent, real_intent):
            if isinstance(p_intent, list):
                p_intent, r_intent = set(p_intent), set(r_intent)
            if p_slot == r_slot and p_intent == r_intent:
                correct_count += 1.0
            total_count += 1.0
        return 1.0 * correct_count / total_count

    @staticmethod
    def intent_accuracy(pred_list: List, real_list: List) -> float:
        """Get intent accuracy measured by predictions and ground truths. Support both multi intent and single intent.

        Args:
            pred_list (List): predicted intent list.
            real_list (List): golden intent list.

        Returns:
            float: intent accuracy score
        """
        total_count, correct_count = 0.0, 0.0
        for p_intent, r_intent in zip(pred_list, real_list):
            if isinstance(p_intent, list):
                p_intent, r_intent = set(p_intent), set(r_intent)
            if p_intent == r_intent:
                correct_count += 1.0
            total_count += 1.0
        return 1.0 * correct_count / total_count

    @staticmethod
    def intent_f1(pred_list: List[List[int]], real_list: List[List[int]], num_intent: int, average='macro') -> float:
        """Get intent f1 score measured by predictions and ground truths.
        (Only support multi intent now, but you can use [[intent1], [intent2], ...] to compute intent f1 in single intent)

        Args:
            pred_list (List[List[int]]): predicted multi intent list.
            real_list (List[List[int]]): golden multi intent list.
            num_intent (int): total number of intent labels.
            average (str): support "micro" and "macro".

        Returns:
            float: intent f1 score
        """
        return f1_score(Evaluator.__instance2onehot(num_intent, real_list),
                        Evaluator.__instance2onehot(num_intent, pred_list),
                        average=average,
                        zero_division=0)

    @staticmethod
    def __multilabel2one_hot(labels, nums):
        """Convert a (possibly nested) list of label indexes into a multi-hot vector of length `nums`."""
        res = [0.] * nums
        if len(labels) == 0:
            return res
        if isinstance(labels[0], list):
            for label in labels[0]:
                res[label] = 1.
            return res
        for label in labels:
            res[label] = 1.
        return res

    @staticmethod
    def __instance2onehot(num_intent, data):
        """Convert a batch of intent index lists into a multi-hot matrix of shape (batch_size, num_intent)."""
        res = []
        for intents in data:
            res.append(Evaluator.__multilabel2one_hot(intents, num_intent))
        return np.array(res)

    @staticmethod
    def __startOfChunk(prevTag, tag, prevTagType, tagType, chunkStart=False):
        """Check whether the current tag starts a new chunk (conlleval-style boundary rules)."""
        if prevTag == 'B' and tag == 'B':
            chunkStart = True
        if prevTag == 'I' and tag == 'B':
            chunkStart = True
        if prevTag == 'O' and tag == 'B':
            chunkStart = True
        if prevTag == 'O' and tag == 'I':
            chunkStart = True
        if prevTag == 'E' and tag == 'E':
            chunkStart = True
        if prevTag == 'E' and tag == 'I':
            chunkStart = True
        if prevTag == 'O' and tag == 'E':
            chunkStart = True
        if tag != 'O' and tag != '.' and prevTagType != tagType:
            chunkStart = True
        return chunkStart

    @staticmethod
    def __endOfChunk(prevTag, tag, prevTagType, tagType, chunkEnd=False):
        """Check whether the previous tag ends a chunk (conlleval-style boundary rules)."""
        if prevTag == 'B' and tag == 'B':
            chunkEnd = True
        if prevTag == 'B' and tag == 'O':
            chunkEnd = True
        if prevTag == 'I' and tag == 'B':
            chunkEnd = True
        if prevTag == 'I' and tag == 'O':
            chunkEnd = True
        if prevTag == 'E' and tag == 'E':
            chunkEnd = True
        if prevTag == 'E' and tag == 'I':
            chunkEnd = True
        if prevTag == 'E' and tag == 'O':
            chunkEnd = True
        if prevTag != 'O' and prevTag != '.' and prevTagType != tagType:
            chunkEnd = True
        return chunkEnd

    @staticmethod
    def __splitTagType(tag):
        """Split a slot tag such as "B-xxx" into its boundary tag ("B") and slot type ("xxx")."""
        s = tag.split('-')
        if len(s) > 2 or len(s) == 0:
            raise ValueError('tag format wrong. it must be B-xxx.xxx')
        if len(s) == 1:
            tag = s[0]
            tagType = ""
        else:
            tag = s[0]
            tagType = s[1]
        return tag, tagType

    @staticmethod
    def computeF1Score(correct_slots: List[List[str]], pred_slots: List[List[str]]) -> float:
        """Compute the chunk-level slot f1 score (adapted from conlleval.pl).

        Args:
            correct_slots (List[List[str]]): golden slot string list.
            pred_slots (List[List[str]]): predicted slot string list.

        Returns:
            float: slot f1 score
        """
        correctChunk = {}
        correctChunkCnt = 0.0
        foundCorrect = {}
        foundCorrectCnt = 0.0
        foundPred = {}
        foundPredCnt = 0.0
        correctTags = 0.0
        tokenCount = 0.0
        for correct_slot, pred_slot in zip(correct_slots, pred_slots):
            inCorrect = False
            lastCorrectTag = 'O'
            lastCorrectType = ''
            lastPredTag = 'O'
            lastPredType = ''
            for c, p in zip(correct_slot, pred_slot):
                c = str(c)
                p = str(p)
                correctTag, correctType = Evaluator.__splitTagType(c)
                predTag, predType = Evaluator.__splitTagType(p)

                if inCorrect == True:
                    if Evaluator.__endOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True and \
                            Evaluator.__endOfChunk(lastPredTag, predTag, lastPredType, predType) == True and \
                            (lastCorrectType == lastPredType):
                        inCorrect = False
                        correctChunkCnt += 1.0
                        if lastCorrectType in correctChunk:
                            correctChunk[lastCorrectType] += 1.0
                        else:
                            correctChunk[lastCorrectType] = 1.0
                    elif Evaluator.__endOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) != \
                            Evaluator.__endOfChunk(lastPredTag, predTag, lastPredType, predType) or \
                            (correctType != predType):
                        inCorrect = False

                if Evaluator.__startOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True and \
                        Evaluator.__startOfChunk(lastPredTag, predTag, lastPredType, predType) == True and \
                        (correctType == predType):
                    inCorrect = True

                if Evaluator.__startOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True:
                    foundCorrectCnt += 1
                    if correctType in foundCorrect:
                        foundCorrect[correctType] += 1.0
                    else:
                        foundCorrect[correctType] = 1.0

                if Evaluator.__startOfChunk(lastPredTag, predTag, lastPredType, predType) == True:
                    foundPredCnt += 1.0
                    if predType in foundPred:
                        foundPred[predType] += 1.0
                    else:
                        foundPred[predType] = 1.0

                if correctTag == predTag and correctType == predType:
                    correctTags += 1.0

                tokenCount += 1.0

                lastCorrectTag = correctTag
                lastCorrectType = correctType
                lastPredTag = predTag
                lastPredType = predType

            if inCorrect == True:
                correctChunkCnt += 1.0
                if lastCorrectType in correctChunk:
                    correctChunk[lastCorrectType] += 1.0
                else:
                    correctChunk[lastCorrectType] = 1.0

        if foundPredCnt > 0:
            precision = 1.0 * correctChunkCnt / foundPredCnt
        else:
            precision = 0
        if foundCorrectCnt > 0:
            recall = 1.0 * correctChunkCnt / foundCorrectCnt
        else:
            recall = 0
        if (precision + recall) > 0:
            f1 = (2.0 * precision * recall) / (precision + recall)
        else:
            f1 = 0
        return f1

    @staticmethod
    def max_freq_predict(sample):
        """Max frequency prediction: for each instance, return its most frequent candidate label."""
        predict = []
        for items in sample:
            predict.append(Counter(items).most_common(1)[0][0])
        return predict

    @staticmethod
    def __token_map(indexes, token_label_map):
        """Map label tokens to indexes via `token_label_map`; unknown tokens are mapped to -1."""
        return [[token_label_map[idx] if idx in token_label_map else -1 for idx in index] for index in indexes]

    @staticmethod
    def compute_all_metric(inps: InputData,
                           output: OutputData,
                           intent_label_map: dict = None,
                           metric_list: List = None) -> Dict:
        """Auto compute all metrics mentioned in `metric_list`.

        Args:
            inps (InputData): input golden slot and intent labels.
            output (OutputData): output predicted slot and intent labels.
            intent_label_map (dict, Optional): dict like {"intent1": 0, "intent2": 1, ...}, which maps intent strings to indexes.
            metric_list (List): supported metrics are ["slot_f1", "intent_acc", "intent_f1", "macro_intent_f1", "micro_intent_f1", "EMA"].

        Returns:
            Dict: all metrics mentioned in `metric_list`, like {'EMA': 0.7, ...}

        Example:
            if compute slot metric:
                inps.slot = [["slot1", "slot2", ...], ...]; output.slot_ids = [["slot1", "slot2", ...], ...]
            if compute intent metric:
                [Multi Intent] inps.intent = [["intent1", "intent2", ...], ...]; output.intent_ids = [["intent1", "intent2", ...], ...]
                [Single Intent] inps.intent = ["intent1", ...]; output.intent_ids = ["intent1", ...]
        """
        if not metric_list:
            metric_list = ["slot_f1", "intent_acc", "EMA"]
        res_dict = {}
        use_slot = output.slot_ids is not None and len(output.slot_ids) > 0
        use_intent = output.intent_ids is not None and len(output.intent_ids) > 0
        if use_slot and "slot_f1" in metric_list:
            res_dict["slot_f1"] = Evaluator.computeF1Score(output.slot_ids, inps.slot)
        if use_intent and "intent_acc" in metric_list:
            res_dict["intent_acc"] = Evaluator.intent_accuracy(output.intent_ids, inps.intent)
            if isinstance(output.intent_ids[0], list):
                if "intent_f1" in metric_list:
                    res_dict["intent_f1"] = Evaluator.intent_f1(
                        Evaluator.__token_map(output.intent_ids, intent_label_map),
                        Evaluator.__token_map(inps.intent, intent_label_map),
                        len(intent_label_map.keys()))
                elif "macro_intent_f1" in metric_list:
                    res_dict["macro_intent_f1"] = Evaluator.intent_f1(
                        Evaluator.__token_map(output.intent_ids, intent_label_map),
                        Evaluator.__token_map(inps.intent, intent_label_map),
                        len(intent_label_map.keys()), average="macro")
                if "micro_intent_f1" in metric_list:
                    res_dict["micro_intent_f1"] = Evaluator.intent_f1(
                        Evaluator.__token_map(output.intent_ids, intent_label_map),
                        Evaluator.__token_map(inps.intent, intent_label_map),
                        len(intent_label_map.keys()), average="micro")
        if use_slot and use_intent and "EMA" in metric_list:
            res_dict["EMA"] = Evaluator.exactly_match_accuracy(output.slot_ids, inps.slot,
                                                               output.intent_ids, inps.intent)
        return res_dict
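

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): the toy
# slot/intent lists below are hypothetical data, used solely to show how the
# static metric functions can be called directly on plain Python lists.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Toy BIO slot sequences for two utterances (hypothetical data).
    golden_slots = [["B-city", "I-city", "O"], ["O", "B-date"]]
    predicted_slots = [["B-city", "I-city", "O"], ["O", "O"]]

    # Toy multi-intent labels per utterance (hypothetical data).
    golden_intents = [["BookFlight"], ["BookFlight", "GetWeather"]]
    predicted_intents = [["BookFlight"], ["BookFlight"]]

    # Chunk-level slot f1 (conlleval-style), called as (golden, predicted).
    print("slot_f1:", Evaluator.computeF1Score(golden_slots, predicted_slots))

    # Intent accuracy compares multi-intent labels as sets.
    print("intent_acc:", Evaluator.intent_accuracy(predicted_intents, golden_intents))

    # Exact-match accuracy requires both slots and intents to be fully correct.
    print("EMA:", Evaluator.exactly_match_accuracy(predicted_slots, golden_slots,
                                                   predicted_intents, golden_intents))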