'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-17 19:39:22
Description: Metric calculation class
'''
from collections import Counter
from typing import List, Dict
import numpy as np
from sklearn.metrics import f1_score
from common.utils import InputData, OutputData
class Evaluator(object):
"""Evaluation metric funtions library class
supported metric:
- slot_f1
- intent_acc
- exactly_match_accuracy
- intent_f1 (defult "macro_intent_f1")
- macro_intent_f1
- micro_intent_f1=
"""
@staticmethod
def exactly_match_accuracy(pred_slot: List[List[str or int]],
real_slot: List[List[str or int]],
pred_intent: List[List[str or int] or str or int],
real_intent: List[List[str or int] or str or int]) -> float:
"""Compute the accuracy based on the whole predictions of given sentence, including slot and intent.
(both support str or int index as the representation of slot and intent)
Args:
pred_slot (List[List[str or int]]): predicted sequence of slot list
real_slot (List[List[str or int]]): golden sequence of slot list.
pred_intent (List[List[str or int] or str or int]): golden intent list / golden multi intent list.
real_intent (List[List[str or int] or str or int]): predicted intent list / predicted multi intent list.
Returns:
            float: exact match accuracy (EMA) score
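        Example (illustrative values; the second sample fails on its slots):
            >>> Evaluator.exactly_match_accuracy(
            ...     pred_slot=[["B-city", "O"], ["O", "O"]],
            ...     real_slot=[["B-city", "O"], ["B-date", "O"]],
            ...     pred_intent=["inform", "request"],
            ...     real_intent=["inform", "request"])
            0.5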
"""
total_count, correct_count = 0.0, 0.0
for p_slot, r_slot, p_intent, r_intent in zip(pred_slot, real_slot, pred_intent, real_intent):
            if isinstance(p_intent, list):
                # Multi-intent predictions are compared as sets (order-insensitive).
                p_intent, r_intent = set(p_intent), set(r_intent)
if p_slot == r_slot and p_intent == r_intent:
correct_count += 1.0
total_count += 1.0
return 1.0 * correct_count / total_count
@staticmethod
def intent_accuracy(pred_list: List, real_list: List) -> float:
"""Get intent accuracy measured by predictions and ground-trues. Support both multi intent and single intent.
Args:
pred_list (List): predicted intent list
real_list (List): golden intent list
Returns:
float: intent accuracy score
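        Example (illustrative values; multi-intent lists are compared as sets):
            >>> Evaluator.intent_accuracy(["inform", "request"], ["inform", "deny"])
            0.5
            >>> Evaluator.intent_accuracy([["a", "b"]], [["b", "a"]])
            1.0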
"""
total_count, correct_count = 0.0, 0.0
for p_intent, r_intent in zip(pred_list, real_list):
if isinstance(p_intent, list):
p_intent, r_intent = set(p_intent), set(r_intent)
if p_intent == r_intent:
correct_count += 1.0
total_count += 1.0
return 1.0 * correct_count / total_count
@staticmethod
def intent_f1(pred_list: List[List[int]], real_list: List[List[int]], num_intent: int, average='macro') -> float:
"""Get intent accuracy measured by predictions and ground-trues. Support both multi intent and single intent.
(Only support multi intent now, but you can use [[intent1], [intent2], ...] to compute intent f1 in single intent)
Args:
pred_list (List[List[int]]): predicted multi intent list.
real_list (List[List[int]]): golden multi intent list.
num_intent (int)
average (str): support "micro" and "macro"
Returns:
float: intent accuracy score
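        Example (illustrative values; float(...) just normalizes the numpy scalar):
            >>> float(round(Evaluator.intent_f1([[0], [1, 2]], [[0], [1]], num_intent=3), 2))
            0.67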
"""
return f1_score(Evaluator.__instance2onehot(num_intent, real_list),
Evaluator.__instance2onehot(num_intent, pred_list),
average=average,
zero_division=0)
@staticmethod
    def __multilabel2one_hot(labels, nums):
        # Convert a (possibly nested) list of label indices into a one-hot vector of length `nums`.
        res = [0.] * nums
        if len(labels) == 0:
            return res
        if isinstance(labels[0], list):
            # Nested input: only the first inner list is used.
for label in labels[0]:
res[label] = 1.
return res
for label in labels:
res[label] = 1.
return res
@staticmethod
    def __instance2onehot(num_intent, data):
        # Stack per-sample one-hot intent vectors into an (n_samples, num_intent) array.
        res = []
for intents in data:
res.append(Evaluator.__multilabel2one_hot(intents, num_intent))
return np.array(res)
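    # Illustrative sketch of the two helpers above (hypothetical values):
    #   __multilabel2one_hot([0, 2], 4)     -> [1.0, 0.0, 1.0, 0.0]
    #   __instance2onehot(4, [[0, 2], [1]]) -> array([[1., 0., 1., 0.],
    #                                                 [0., 1., 0., 0.]])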
    @staticmethod
    def __startOfChunk(prevTag, tag, prevTagType, tagType, chunkStart=False):
        # Chunk-start rules adapted from conlleval.pl (IOB/IOBES schemes).
        if (prevTag, tag) in {('B', 'B'), ('I', 'B'), ('O', 'B'), ('O', 'I'),
                              ('E', 'E'), ('E', 'I'), ('O', 'E')}:
            chunkStart = True
        # Any non-O tag whose type differs from the previous type also starts a chunk.
        if tag != 'O' and tag != '.' and prevTagType != tagType:
            chunkStart = True
        return chunkStart
    @staticmethod
    def __endOfChunk(prevTag, tag, prevTagType, tagType, chunkEnd=False):
        # Chunk-end rules adapted from conlleval.pl (IOB/IOBES schemes).
        if (prevTag, tag) in {('B', 'B'), ('B', 'O'), ('I', 'B'), ('I', 'O'),
                              ('E', 'E'), ('E', 'I'), ('E', 'O')}:
            chunkEnd = True
        # Any non-O previous tag whose type differs from the current type also ends a chunk.
        if prevTag != 'O' and prevTag != '.' and prevTagType != tagType:
            chunkEnd = True
        return chunkEnd
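    # Boundary sketch (illustrative): for the tag sequence O B-city I-city O, a
    # chunk starts at B-city (prevTag='O', tag='B') and ends at the trailing O
    # (prevTag='I', tag='O'), yielding exactly one "city" chunk.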
@staticmethod
    def __splitTagType(tag):
        # Split a "B-xxx" style tag into (tag, type); plain tags like "O" get an empty type.
        s = tag.split('-')
        if len(s) > 2 or len(s) == 0:
            raise ValueError('Tag format is wrong: it must look like "B-xxx.xxx" or "O".')
if len(s) == 1:
tag = s[0]
tagType = ""
else:
tag = s[0]
tagType = s[1]
return tag, tagType
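    # e.g. (illustrative): __splitTagType("B-fromloc.city_name") -> ("B", "fromloc.city_name")
    #                      __splitTagType("O") -> ("O", "")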
@staticmethod
def computeF1Score(correct_slots: List[List[str]], pred_slots: List[List[str]]) -> float:
"""compute f1 score is modified from conlleval.pl
Args:
correct_slots (List[List[str]]): golden slot string list
pred_slots (List[List[str]]): predicted slot string list
Returns:
float: slot f1 score
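        Example (illustrative values; one fully matched "city" chunk):
            >>> Evaluator.computeF1Score([["B-city", "I-city", "O"]],
            ...                          [["B-city", "I-city", "O"]])
            1.0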
"""
        correctChunk = {}
        correctChunkCnt = 0.0
        foundCorrect = {}
        foundCorrectCnt = 0.0
        foundPred = {}
        foundPredCnt = 0.0
        # Token-level counts below are kept for parity with conlleval.pl,
        # though only the chunk-level F1 is returned.
        correctTags = 0.0
        tokenCount = 0.0
for correct_slot, pred_slot in zip(correct_slots, pred_slots):
inCorrect = False
lastCorrectTag = 'O'
lastCorrectType = ''
lastPredTag = 'O'
lastPredType = ''
for c, p in zip(correct_slot, pred_slot):
c = str(c)
p = str(p)
correctTag, correctType = Evaluator.__splitTagType(c)
predTag, predType = Evaluator.__splitTagType(p)
                if inCorrect:
                    if Evaluator.__endOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) and \
                            Evaluator.__endOfChunk(lastPredTag, predTag, lastPredType, predType) and \
                            (lastCorrectType == lastPredType):
                        inCorrect = False
                        correctChunkCnt += 1.0
                        correctChunk[lastCorrectType] = correctChunk.get(lastCorrectType, 0.0) + 1.0
                    elif Evaluator.__endOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) != \
                            Evaluator.__endOfChunk(lastPredTag, predTag, lastPredType, predType) or \
                            (correctType != predType):
                        inCorrect = False
                if Evaluator.__startOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) and \
                        Evaluator.__startOfChunk(lastPredTag, predTag, lastPredType, predType) and \
                        (correctType == predType):
                    inCorrect = True
                if Evaluator.__startOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType):
                    foundCorrectCnt += 1.0
                    foundCorrect[correctType] = foundCorrect.get(correctType, 0.0) + 1.0
                if Evaluator.__startOfChunk(lastPredTag, predTag, lastPredType, predType):
                    foundPredCnt += 1.0
                    foundPred[predType] = foundPred.get(predType, 0.0) + 1.0
if correctTag == predTag and correctType == predType:
correctTags += 1.0
tokenCount += 1.0
lastCorrectTag = correctTag
lastCorrectType = correctType
lastPredTag = predTag
lastPredType = predType
            if inCorrect:
                correctChunkCnt += 1.0
                correctChunk[lastCorrectType] = correctChunk.get(lastCorrectType, 0.0) + 1.0
        if foundPredCnt > 0:
            precision = 1.0 * correctChunkCnt / foundPredCnt
        else:
            precision = 0.0
        if foundCorrectCnt > 0:
            recall = 1.0 * correctChunkCnt / foundCorrectCnt
        else:
            recall = 0.0
        if (precision + recall) > 0:
            f1 = (2.0 * precision * recall) / (precision + recall)
        else:
            f1 = 0.0
        return f1
@staticmethod
def max_freq_predict(sample):
"""Max frequency prediction.
"""
predict = []
for items in sample:
predict.append(Counter(items).most_common(1)[0][0])
return predict
@staticmethod
    def __token_map(indexes, token_label_map):
        # Map label tokens to indices; tokens missing from the map become -1.
        return [[token_label_map[idx] if idx in token_label_map else -1 for idx in index] for index in indexes]
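    # e.g. (illustrative): __token_map([["a", "b"]], {"a": 0}) -> [[0, -1]]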
@staticmethod
def compute_all_metric(inps: InputData,
output: OutputData,
intent_label_map: dict = None,
                           metric_list: List = None) -> Dict:
"""Auto compute all metric mentioned in 'metric_list'
Args:
inps (InputData): input golden slot and intent labels
output (OutputData): output predicted slot and intent labels
intent_label_map (dict, Optional): dict like {"intent1": 0, "intent2": 1, ...},which aims to map intent string to index
metric_list (List): support metrics in ["slot_f1", "intent_acc", "intent_f1", "macro_intent_f1", "micro_intent_f1", "EMA"]
Returns:
Dict: all metric mentioned in 'metric_list', like {'EMA': 0.7, ...}
Example:
if compute slot metric:
inps.slot = [["slot1", "slot2", ...], ...]; output.slot_ids=[["slot1", "slot2", ...], ...];
if compute intent metric:
[Multi Intent] inps.intent = [["intent1", "intent2", ...], ...]; output.intent_ids = [["intent1", "intent2", ...], ...]
[Single Intent] inps.intent = ["intent1", ...]; [Single Intent] output.intent_ids = ["intent1", ...]
"""
if not metric_list:
metric_list = ["slot_f1", "intent_acc", "EMA"]
res_dict = {}
use_slot = output.slot_ids is not None and len(output.slot_ids) > 0
        use_intent = output.intent_ids is not None and len(output.intent_ids) > 0
if use_slot and "slot_f1" in metric_list:
res_dict["slot_f1"] = Evaluator.computeF1Score(
output.slot_ids, inps.slot)
if use_intent and "intent_acc" in metric_list:
res_dict["intent_acc"] = Evaluator.intent_accuracy(
output.intent_ids, inps.intent)
        if use_intent and isinstance(output.intent_ids[0], list):
            if "intent_f1" in metric_list:
                res_dict["intent_f1"] = Evaluator.intent_f1(
                    Evaluator.__token_map(output.intent_ids, intent_label_map),
                    Evaluator.__token_map(inps.intent, intent_label_map),
                    len(intent_label_map.keys()))
            elif "macro_intent_f1" in metric_list:
                res_dict["macro_intent_f1"] = Evaluator.intent_f1(
                    Evaluator.__token_map(output.intent_ids, intent_label_map),
                    Evaluator.__token_map(inps.intent, intent_label_map),
                    len(intent_label_map.keys()), average="macro")
            if "micro_intent_f1" in metric_list:
                res_dict["micro_intent_f1"] = Evaluator.intent_f1(
                    Evaluator.__token_map(output.intent_ids, intent_label_map),
                    Evaluator.__token_map(inps.intent, intent_label_map),
                    len(intent_label_map.keys()), average="micro")
if use_slot and use_intent and "EMA" in metric_list:
res_dict["EMA"] = Evaluator.exactly_match_accuracy(output.slot_ids, inps.slot, output.intent_ids,
inps.intent)
return res_dict
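

# A minimal end-to-end sketch (assumes `InputData`/`OutputData` expose writable
# `slot`/`intent` and `slot_ids`/`intent_ids` attributes; the actual constructors
# live in common.utils and may differ):
#
#   inps.slot = [["B-city", "O"]];   output.slot_ids = [["B-city", "O"]]
#   inps.intent = ["inform"];        output.intent_ids = ["inform"]
#   Evaluator.compute_all_metric(inps, output,
#                                metric_list=["slot_f1", "intent_acc", "EMA"])
#   # -> {"slot_f1": 1.0, "intent_acc": 1.0, "EMA": 1.0}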