# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sentiment-quadruple matching metric for the Hugging Face ``evaluate`` library.

The module parses strings such as ``"food | good | food#taste | pos & ..."``
into (target, opinion, aspect, polarity) quadruples and scores predictions
against references in two ways: exact-match F1 and an optimal-assignment
match whose per-field cost mixes exact comparison and ROUGE-L.
"""

import copy
import re
from typing import Any, Callable, Dict, List, Union

import datasets
import evaluate
import numpy as np
from rouge_chinese import Rouge
from scipy.optimize import linear_sum_assignment

# TODO: Add BibTeX citation
_CITATION = """\
@InProceedings{huggingface:module,
title = {quad match score},
authors={huggingface, Inc.},
year={2020}
}
"""

# TODO: Add description of the module here
_DESCRIPTION = """\
evaluate sentiment quadruples.
评估生成模型的情感四元组
"""

# TODO: Add description of the arguments of the module here
# NOTE(review): the keys shown in the example output below differ from the keys
# returned by QuadMatch._compute ('score of optimal match of weight ...',
# 'f1 of optimal match of weight ...', 'f1 of exact match') - the example looks
# stale; regenerate it against a real run before relying on it.
_KWARGS_DESCRIPTION = """
Calculates how good are predictions given some references, using certain scores
Args:
    predictions: list of predictions to score. Each predictions
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
Returns:
    score: sentiment quadruple match score
Examples:
    Examples should be written in doctest format, and should illustrate how
    to use the function.

    >>> import evaluate
    >>> module = evaluate.load("yuyijiong/quad_match_score")
    >>> predictions=["food | good | food#taste | pos"]
    >>> references=["food | good | food#taste | pos & service | bad | service#general | neg"]
    >>> result=module.compute(predictions=predictions, references=references)
    >>> print(result)
    result={'ave match score of weight (1, 1, 1, 1)': 0.375, 'f1 score of exact match': 0.0, 'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}
"""

# Field names of one quadruple, in positional order (index 0..3).
_QUAD_KEYS = ("target_text", "opinion_text", "aspect", "polarity")
# Placeholder stored for a field that a parsed tuple does not provide.
_NONE_PLACEHOLDER = 'None'
# Fields compared with ROUGE-L; the remaining fields are compared exactly.
_TEXT_FEATURES = ("target_text", "opinion_text")
# Matches a single CJK unified ideograph.
_CHINESE_CHAR = re.compile('[\u4e00-\u9fa5]')


def _parse_tuple_index(tuple_len: Union[int, list, str]) -> List[int]:
    """Turn a ``tuple_len`` argument into the list of quadruple field indices.

    An int ``n`` selects the first ``n`` fields, a list is used as given and a
    digit string such as ``'013'`` becomes ``[0, 1, 3]``.

    Raises:
        TypeError: if ``tuple_len`` is none of int, list or str.
    """
    if isinstance(tuple_len, int):
        return list(range(tuple_len))
    if isinstance(tuple_len, list):
        return tuple_len
    if isinstance(tuple_len, str):
        return [int(i) for i in tuple_len]
    raise TypeError('tuple_len参数错误')


def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
    """Return the average ROUGE-L F1 between predicted and reference texts.

    Args:
        text_pred_list: predicted texts.
        text_true_list: reference texts, aligned with ``text_pred_list``.

    Returns:
        The ROUGE-L F1 averaged over all pairs, or ``0.0`` when the input is
        empty or the first prediction / reference is blank.
    """
    assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
    if not text_pred_list:
        return 0.0
    # A blank text on either side has no overlap with anything. The blank
    # reference is guarded too so it never reaches Rouge (presumably rejected
    # there, like a blank prediction - TODO confirm).
    if not text_pred_list[0].strip() or not text_true_list[0].strip():
        return 0.0
    # Rouge tokenises on whitespace, so Chinese text must be split into single
    # characters first. Either side containing Chinese triggers the split, so a
    # Latin-only prediction is still compared character-wise with a Chinese
    # reference.
    if _CHINESE_CHAR.search(text_pred_list[0]) or _CHINESE_CHAR.search(text_true_list[0]):
        text_pred_list = [' '.join(text_pred) for text_pred in text_pred_list]
        text_true_list = [' '.join(text_true) for text_true in text_true_list]
    scores = Rouge().get_scores(text_pred_list, text_true_list, avg=True)
    return scores['rouge-l']['f']


class CommentUnitsSim:
    """The sentiment quadruples of one comment.

    ``self.data`` is a list of dicts with the keys ``target_text``,
    ``opinion_text``, ``aspect`` and ``polarity``. The legacy keys ``target``
    and ``opinion`` are renamed on construction.
    """

    # Polarity labels accepted by check_polarity().
    _VALID_POLARITIES = ('positive', 'negative', 'neutral', 'pos', 'neg', 'neu', '积极', '消极', '中性')

    def __init__(self, data: List[Dict[str, str]], data_source: Any = None, abnormal=False, language=None):
        """Store a private copy of ``data``.

        Args:
            data: list of quadruple dicts.
            data_source: the raw object the quadruples were parsed from.
            abnormal: whether a format problem was seen while parsing.
            language: ``'zh'`` or ``'en'``; detected from the target/opinion
                text when ``None``.
        """
        self.data_source = data_source
        self.abnormal = abnormal
        # Work on a copy so the caller's dicts are never modified.
        data = copy.deepcopy(data)
        for quad_dict in data:
            if 'target' in quad_dict:
                quad_dict['target_text'] = quad_dict.pop('target')
            if 'opinion' in quad_dict:
                quad_dict['opinion_text'] = quad_dict.pop('opinion')
        self.data = data
        # The 'None' placeholder maps to itself so that tuples parsed without a
        # polarity field can still be converted and rendered.
        self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性',
                               'pos': '积极', 'neg': '消极', 'neu': '中性',
                               '积极': '积极', '消极': '消极', '中性': '中性',
                               _NONE_PLACEHOLDER: _NONE_PLACEHOLDER}
        self.polarity_zh2en = {'积极': 'pos', '消极': 'neg', '中性': 'neu',
                               'pos': 'pos', 'neg': 'neg', 'neu': 'neu',
                               'positive': 'pos', 'negative': 'neg', 'neutral': 'neu',
                               _NONE_PLACEHOLDER: _NONE_PLACEHOLDER}
        if language is not None:
            self.language = language
        else:
            self.language = 'zh' if self.check_zh() else 'en'
        self.none_sign = 'null'

    @property
    def num(self):
        """Number of quadruples held."""
        return len(self.data)

    def _column(self, key: str) -> list:
        """Return the values stored under ``key`` for every quadruple."""
        return [quad_dict[key] for quad_dict in self.data]

    def check_zh(self):
        """Return True if any target or opinion text contains a Chinese character."""
        return any(
            _CHINESE_CHAR.search(quad_dict['target_text']) or _CHINESE_CHAR.search(quad_dict['opinion_text'])
            for quad_dict in self.data
        )

    def check_polarity(self):
        """Return True if every polarity is a known label.

        On the first unknown label ``self.abnormal`` is set and False is
        returned.
        """
        for quad_dict in self.data:
            if quad_dict['polarity'] not in self._VALID_POLARITIES:
                self.abnormal = True
                return False
        return True

    def convert_polarity_en2zh(self):
        """Rewrite every polarity to its Chinese label in place; returns self.

        Raises:
            KeyError: if a polarity is not a key of ``polarity_en2zh``.
        """
        for quad_dict in self.data:
            quad_dict['polarity'] = self.polarity_en2zh[quad_dict['polarity']]
        return self

    def convert_polarity_zh2en(self):
        """Rewrite every polarity to its short English label in place; returns self.

        Raises:
            KeyError: if a polarity is not a key of ``polarity_zh2en``.
        """
        for quad_dict in self.data:
            quad_dict['polarity'] = self.polarity_zh2en[quad_dict['polarity']]
        return self

    def del_duplicate(self):
        """Drop repeated quadruples, keeping first occurrences; returns self."""
        unique = []
        for quad_dict in self.data:
            # Dicts are unhashable, hence the list membership test.
            if quad_dict not in unique:
                unique.append(quad_dict)
        self.data = unique
        return self

    def check_target_opinion_null(self):
        """Return True if some quadruple has both target and opinion equal to the null sign."""
        return any(
            quad_dict['target_text'] == self.none_sign and quad_dict['opinion_text'] == self.none_sign
            for quad_dict in self.data
        )

    def check_any_null(self):
        """Return True if some quadruple has target or opinion equal to the null sign."""
        return any(
            quad_dict['target_text'] == self.none_sign or quad_dict['opinion_text'] == self.none_sign
            for quad_dict in self.data
        )

    @staticmethod
    def _pad_separators(text: str, *separators: str) -> str:
        """Insert a space after every separator mark that lacks one.

        A mark at the very end of ``text`` is left untouched. Longer marks are
        tried first so that one mark being a prefix of another is handled.
        """
        marks = sorted({sep.strip() for sep in separators if sep.strip()}, key=len, reverse=True)
        if not marks:
            return text
        pattern = '(' + '|'.join(re.escape(mark) for mark in marks) + ')(?=[^ ])'
        return re.sub(pattern, r'\1 ', text)

    @classmethod
    def from_str(cls, quadruple_str: str, tuple_len: Union[int, list, str] = 4, format_code=0,
                 sep_token1=' & ', sep_token2=' | '):
        """Parse a string of tuples into a ``CommentUnitsSim``.

        Args:
            quadruple_str: tuples joined by ``sep_token1``, each tuple's fields
                joined by ``sep_token2``.
            tuple_len: which quadruple fields each tuple carries (see
                ``_parse_tuple_index``).
            format_code: only ``0`` is supported.
            sep_token1: separator between tuples.
            sep_token2: separator between the fields of one tuple.

        Returns:
            A new instance; fields the tuples do not carry are set to
            ``'None'`` and ``abnormal`` is True when a tuple had the wrong
            number of fields.

        Raises:
            TypeError: for an unsupported ``tuple_len`` type.
            ValueError: for an unsupported ``format_code``.
        """
        tuple_index = _parse_tuple_index(tuple_len)
        if format_code != 0:
            raise ValueError('answer_format参数错误')

        # Make sure each separator mark is followed by a space. Done in one
        # regex pass: an index loop over a string that grows while it is being
        # patched stops short of the last separators.
        quadruple_str = cls._pad_separators(quadruple_str, sep_token1, sep_token2)

        quadruple_keys = [_QUAD_KEYS[i] for i in tuple_index]
        expected = len(quadruple_keys)
        data = []
        abnormal = False
        for quadruple in quadruple_str.split(sep_token1):
            quadruple_split = [unit.strip() for unit in quadruple.split(sep_token2)]
            if len(quadruple_split) > expected:
                print('quadruple格式错误,过多元素', quadruple_str)
                abnormal = True
                # Too many fields: keep the leading ones.
                quadruple_split = quadruple_split[:expected]
            elif len(quadruple_split) < expected:
                print('quadruple格式错误,过少元素', quadruple_str)
                abnormal = True
                # Too few fields: pad with 'None' at the front.
                quadruple_split = [_NONE_PLACEHOLDER] * (expected - len(quadruple_split)) + quadruple_split

            q = dict.fromkeys(_QUAD_KEYS, _NONE_PLACEHOLDER)
            q.update(zip(quadruple_keys, quadruple_split))
            # Report (but do not reject) an unexpected polarity label.
            if q['polarity'] not in ['pos', 'neg', 'neu', 'None', '积极', '消极', '中性']:
                print('quadruple格式错误,极性格式不对', quadruple_str)
            data.append(q)

        return cls(data, quadruple_str, abnormal)

    @classmethod
    def from_list(cls, quadruple_list: List[List[str]], **kwargs):
        """Build an instance from ``[target, opinion, aspect, polarity]`` lists.

        ``kwargs`` are forwarded to the constructor.
        """
        data = [dict(zip(_QUAD_KEYS, quadruple[:4])) if len(quadruple) >= 4
                else {"target_text": quadruple[0], "opinion_text": quadruple[1],
                      "aspect": quadruple[2], "polarity": quadruple[3]}
                for quadruple in quadruple_list]
        return cls(data, quadruple_list, **kwargs)

    @classmethod
    def from_list_dict(cls, quadruple_list: List[dict], **kwargs):
        """Build an instance from quadruple dicts, filling missing keys with ``'None'``.

        The input dicts are left untouched; the legacy ``target`` / ``opinion``
        keys are renamed by the constructor on its own copy. ``kwargs`` are
        forwarded to the constructor.
        """
        data = []
        for quadruple in quadruple_list:
            q = dict.fromkeys(_QUAD_KEYS, _NONE_PLACEHOLDER)
            q.update(quadruple)
            data.append(q)
        return cls(data, quadruple_list, **kwargs)

    def to_list(self):
        """Return the quadruples as ``[target, opinion, aspect, polarity]`` lists."""
        return [[quad_dict[key] for key in _QUAD_KEYS] for quad_dict in self.data]

    def get_quadruple_str(self, format_code=0, tuple_len: Union[int, list, str] = 4, sep_token1=' & ',
                          sep_token2=' | '):
        """Render the quadruples as a string of n-tuples.

        Polarities are first converted in place to the labels of
        ``self.language``. When ``tuple_len`` selects only the polarity
        (``[3]``) the merged overall polarity is returned instead of a joined
        string. For ``[2, 3]`` repeated (aspect, polarity) pairs are emitted
        once, in first-seen order.

        Raises:
            TypeError: for an unsupported ``tuple_len`` type.
            ValueError: when a polarity cannot be converted, or for an
                unsupported ``format_code``.
        """
        tuple_index = _parse_tuple_index(tuple_len)

        try:
            if self.language == 'zh':
                self.convert_polarity_en2zh()
            else:
                self.convert_polarity_zh2en()
        except (KeyError, TypeError) as err:
            print('语言参数错误', self.data)
            print(self.language)
            raise ValueError('语言参数错误') from err

        if tuple_index == [3]:
            return self.merge_polarity()

        if format_code != 0:
            raise ValueError('answer_format参数错误')

        new_text_list = [sep_token2.join([row[i] for i in tuple_index]) for row in self.to_list()]
        if tuple_index == [2, 3]:
            # Several quadruples may share an aspect; keep each pair once.
            new_text_list = list(dict.fromkeys(new_text_list))
        return sep_token1.join(new_text_list)

    def compare_same(self, other) -> int:
        """Count the quadruples of ``self`` that also occur in ``other``.

        Each quadruple of ``other`` can be matched at most once, so repeated
        predictions cannot push the count above ``other.num``.
        """
        remaining = list(other.data)
        count = 0
        for quad_dict in self.data:
            if quad_dict in remaining:
                remaining.remove(quad_dict)
                count += 1
        return count

    def check_target_repeat(self):
        """Return True if two quadruples share the same target."""
        target_list = self.get_target_list()
        return len(target_list) != len(set(target_list))

    def check_opinion_repeat(self):
        """Return True if two quadruples share the same opinion."""
        opinion_list = self.get_opinion_list()
        return len(opinion_list) != len(set(opinion_list))

    def check_aspect_repeat(self):
        """Return True if two quadruples share the same aspect."""
        aspect_list = self.get_aspect_list()
        return len(aspect_list) != len(set(aspect_list))

    def get_aspect_list(self):
        """Return the aspect of every quadruple."""
        return self._column('aspect')

    def get_target_list(self):
        """Return the target text of every quadruple."""
        return self._column('target_text')

    def get_opinion_list(self):
        """Return the opinion text of every quadruple."""
        return self._column('opinion_text')

    def get_polarity_list(self):
        """Return the polarity of every quadruple."""
        return self._column('polarity')

    def merge_polarity(self):
        """Merge all polarities into one overall label.

        Mixed positive and negative, or neither, gives neutral; otherwise the
        single polarity present is returned. English short labels are used
        when ``self.language == 'en'``, Chinese labels otherwise.
        """
        polarity_list = self.get_polarity_list()
        if self.language == 'en':
            positive, negative, neutral = 'pos', 'neg', 'neu'
        else:
            positive, negative, neutral = '积极', '消极', '中性'
        has_positive = positive in polarity_list
        has_negative = negative in polarity_list
        if has_positive and not has_negative:
            return positive
        if has_negative and not has_positive:
            return negative
        return neutral

    def check_opinion_in_comment(self, comment_text):
        """Return True if every opinion (except the wildcard ``'*'``) occurs in ``comment_text``."""
        return all(
            quad_dict['opinion_text'] == '*' or quad_dict['opinion_text'] in comment_text
            for quad_dict in self.data
        )

    def check_target_in_comment(self, comment_text):
        """Return True if every target (except the wildcard ``'*'``) occurs in ``comment_text``."""
        return all(
            quad_dict['target_text'] == '*' or quad_dict['target_text'] in comment_text
            for quad_dict in self.data
        )

    @staticmethod
    def get_similarity(units1, units2: 'CommentUnitsSim'):
        """Unimplemented placeholder; always returns None."""
        return None

    def apply(self, func: Callable, field: str):
        """Replace ``field`` of every quadruple with ``func(value)``; returns self."""
        for quad_dict in self.data:
            quad_dict[field] = func(quad_dict[field])
        return self


class CommentUnitsMatch:
    """Matches the quadruples of two ``CommentUnitsSim`` objects and scores the match.

    The cost of pairing two quadruples is a weighted sum of per-field costs:
    0/1 exact comparison for aspect and polarity, ``1 - ROUGE-L F1`` for target
    and opinion.
    """

    def __init__(self, target_weight=0.5, opinion_weight=0.5, aspect_weight=0.5, polarity_weight=0.5,
                 one_match=True):
        """Normalise the four field weights so that they sum to 1.

        Args:
            target_weight, opinion_weight, aspect_weight, polarity_weight:
                relative weights of the four fields.
            one_match: True for one-to-one assignment, False to let every
                quadruple of the longer side pick its cheapest partner.

        Raises:
            ValueError: if the weights sum to zero.
        """
        weight_sum = target_weight + opinion_weight + aspect_weight + polarity_weight
        if weight_sum == 0:
            raise ValueError('quad weights must not all be zero')
        self.target_weight = target_weight / weight_sum
        self.opinion_weight = opinion_weight / weight_sum
        self.aspect_weight = aspect_weight / weight_sum
        self.polarity_weight = polarity_weight / weight_sum
        self.one_match = one_match

    def set_zero(self, feature: str = 'polarity'):
        """Set the weight of ``feature`` to zero.

        Raises:
            ValueError: if ``feature`` names none of the four fields.
        """
        if feature == 'polarity':
            self.polarity_weight = 0
        elif feature == 'aspect':
            self.aspect_weight = 0
        elif 'opinion' in feature:
            self.opinion_weight = 0
        elif 'target' in feature:
            self.target_weight = 0
        else:
            raise ValueError('feature参数错误')

    def re_normalize(self):
        """Rescale the weights to sum to 1; a zero sum leaves them unchanged."""
        weight_sum = self.target_weight + self.opinion_weight + self.aspect_weight + self.polarity_weight
        if weight_sum == 0:
            return
        self.target_weight = self.target_weight / weight_sum
        self.opinion_weight = self.opinion_weight / weight_sum
        self.aspect_weight = self.aspect_weight / weight_sum
        self.polarity_weight = self.polarity_weight / weight_sum

    @staticmethod
    def _feature_missing(units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str) -> bool:
        """Return True if the first quadruple of either side lacks ``feature``.

        A field counts as missing when it is absent, ``None`` or the ``'None'``
        placeholder; an empty side counts as missing too.
        """
        for units in (units1, units2):
            if not units.data:
                return True
            value = units.data[0].get(feature)
            if value is None or value == _NONE_PLACEHOLDER:
                return True
        return False

    def get_cost_matrix(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'polarity'):
        """Return the exact-match cost matrix for ``feature``.

        Entry ``[i][j]`` is 0 when quadruple ``i`` of ``units1`` and quadruple
        ``j`` of ``units2`` agree on ``feature`` and 1 otherwise. A missing
        feature gives an all-zero matrix. The matcher's weights are not
        modified.
        """
        if self._feature_missing(units1, units2, feature):
            return np.zeros((len(units1.data), len(units2.data)))
        return np.array([
            [0 if quad_dict1[feature] == quad_dict2[feature] else 1 for quad_dict2 in units2.data]
            for quad_dict1 in units1.data
        ])

    def get_cost_matrix_rouge(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',
                              feature: str = 'target_text'):
        """Return the ROUGE-L cost matrix for ``feature``.

        Entry ``[i][j]`` is 0 for identical texts and ``1 - ROUGE-L F1``
        otherwise. A missing feature gives an all-zero matrix. The matcher's
        weights are not modified.
        """
        if self._feature_missing(units1, units2, feature):
            return np.zeros((len(units1.data), len(units2.data)))
        return np.array([
            [0 if quad_dict1[feature] == quad_dict2[feature]
             else 1 - get_rougel_f1([quad_dict1[feature]], [quad_dict2[feature]])
             for quad_dict2 in units2.data]
            for quad_dict1 in units1.data
        ])

    def match_units(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim') -> tuple:
        """Match predicted quadruples (``units1``) to reference quadruples (``units2``).

        Returns:
            ``(cost_per_quadruple, TP, FP, FN)``. The cost is the total cost of
            the chosen pairing divided by the larger quadruple count; with
            one-to-one matching every unmatched quadruple adds a cost of 1.
            ``TP`` is the summed score (``1 - cost``) of the matched pairs,
            ``FP = units1.num - TP`` and ``FN = units2.num - TP``.
        """
        max_units_num = max(units1.num, units2.num)
        if units1.num == 0 or units2.num == 0:
            # Nothing can be paired: everything present is unmatched.
            return (1.0 if max_units_num else 0.0), 0, units1.num, units2.num

        # Effective weights are local to this call. Features missing from this
        # pair are dropped here without touching self.*_weight, so one sample
        # with a missing field cannot change how later samples are scored.
        weights = {
            'target_text': self.target_weight,
            'opinion_text': self.opinion_weight,
            'aspect': self.aspect_weight,
            'polarity': self.polarity_weight,
        }
        dropped = [feature for feature in weights if self._feature_missing(units1, units2, feature)]
        if dropped:
            for feature in dropped:
                weights[feature] = 0
            remaining = sum(weights.values())
            if remaining > 0:
                weights = {feature: weight / remaining for feature, weight in weights.items()}

        # Rows are predictions (units1), columns are references (units2).
        shape = (units1.num, units2.num)
        if any(weights.values()):
            cost_matrix = np.zeros(shape)
            for feature, weight in weights.items():
                if not weight:
                    continue
                build = self.get_cost_matrix_rouge if feature in _TEXT_FEATURES else self.get_cost_matrix
                cost_matrix = cost_matrix + weight * build(units1, units2, feature)
        else:
            # No comparable field is left, so no pair can count as a match.
            cost_matrix = np.ones(shape)
        score_matrix = 1 - cost_matrix

        if self.one_match:
            # Hungarian algorithm: one-to-one, min(units1.num, units2.num) pairs.
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
        elif units1.num > units2.num:
            # One-to-many: every quadruple of the longer side picks its best partner.
            row_ind = np.arange(units1.num)
            col_ind = np.argmin(cost_matrix, axis=1)
        else:
            row_ind = np.argmin(cost_matrix, axis=0)
            col_ind = np.arange(units2.num)

        pairs = list(zip(row_ind, col_ind))
        cost = sum(cost_matrix[row][col] for row, col in pairs)
        TP = sum(score_matrix[row][col] for row, col in pairs)
        FP = units1.num - TP
        FN = units2.num - TP

        if self.one_match:
            # Quadruples left unmatched by the one-to-one assignment cost 1 each.
            cost += (max_units_num - len(row_ind))

        cost_per_quadruple = cost / max_units_num
        if cost_per_quadruple > 1 or cost_per_quadruple < 0:
            print('cost错误', cost_per_quadruple, 'pred:', units1.data, 'true:', units2.data)
            print(self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)

        return cost_per_quadruple, TP, FP, FN


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class QuadMatch(evaluate.Metric):
    """Metric scoring predicted sentiment quadruples against references."""

    def _info(self):
        """Describe the metric and the two accepted input schemas.

        References may be either one string per prediction or a sequence of
        alternative strings per prediction.
        """
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence")),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            # Homepage of the module for documentation
            homepage="http://module.homepage",
            # Additional links to the codebase or references
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"]
        )

    def _download_and_prepare(self, dl_manager):
        """Optional: download external resources useful to compute the scores"""
        # TODO: Download external resources if needed
        pass

    def _compute(self, predictions: List[str], references: Union[List[str], List[List[str]]],
                 quad_weights: tuple = (1, 1, 1, 1), **kwargs) -> dict:
        """Compute the optimal-match score, optimal-match F1 and exact-match F1.

        :param predictions: list of predictions of sentiment quads
        :param references: list of references of sentiment quads
        :param quad_weights: weight of target,opinion,aspect,polarity for cost compute
        :param kwargs: parsing options forwarded to ``CommentUnitsSim.from_str``
            (``tuple_len``, ``sep_token1``, ``sep_token2``, ``format_code``), plus
            ``one_match`` (optimal match only) and ``return_dict`` (exact match only)
        :param tuple_len: indicate the format of the quad, see the following mapping
        :param sep_token1: the token to separate quads
        :param sep_token2: the token to separate units of one quad
        :return: dict with the score and F1 of the optimal match and the F1 of the exact match

        # mapping
        id2prompt={'0123':"quadruples (target | opinion | aspect | polarity)",
                   '':"quadruples (target | opinion | aspect | polarity)",
                   '01':'pairs (target | opinion)',
                   '012':'triples (target | opinion | aspect)',
                   '013':'triples (target | opinion | polarity)',
                   '023':'triples (target | aspect | polarity)',
                   '23':'pairs (aspect | polarity)',
                   '03':'pairs (target | polarity)',
                   '13':'pairs (opinion | polarity)',
                   '3':'single (polarity)'}

        # Chinese-language version of the same mapping
        id2prompt_zh={'0123': "四元组(对象 | 观点 | 方面 | 极性)",
                      '':"四元组(对象 | 观点 | 方面 | 极性)",
                      '01':'二元组(对象 | 观点)',
                      '012':'三元组(对象 | 观点 | 方面)',
                      '013':'三元组(对象 | 观点 | 极性)',
                      '023':'三元组(对象 | 方面 | 极性)',
                      '23':'二元组(方面 | 极性)',
                      '03':'二元组(对象 | 极性)',
                      '13':'二元组(观点 | 极性)',
                      '3':'单元素(极性)'}
        """
        # These two options belong to one scorer each; forwarding them to both
        # would hand from_str an unexpected keyword argument.
        one_match = kwargs.pop('one_match', True)
        return_dict = kwargs.pop('return_dict', False)

        f1_of_optimal_match, score_of_optimal_match = self.quad_f1_of_optimal_match(
            predictions, references, quad_weights, one_match=one_match, **kwargs)
        f1 = self.quad_f1_of_exact_match(
            predictions=predictions, references=references, return_dict=return_dict, **kwargs)

        # The score is 1 - average cost.
        return {'score of optimal match of weight ' + str(quad_weights): score_of_optimal_match,
                'f1 of optimal match of weight ' + str(quad_weights): f1_of_optimal_match,
                'f1 of exact match': f1}

    @staticmethod
    def quad_f1_of_exact_match(predictions: List[str], references: Union[List[str], List[List[str]]],
                               return_dict=False, **kwargs) -> Union[Dict[str, float], float]:
        """Micro-averaged F1 over quadruples that match a reference exactly.

        When a sample has several alternative references, the one giving the
        highest per-sample F1 is used.

        Args:
            predictions: predicted quadruple strings (a bare string is treated
                as a single sample).
            references: per prediction, one reference string or a list of
                alternative reference strings.
            return_dict: return precision, recall and F1 instead of F1 only.
            **kwargs: forwarded to ``CommentUnitsSim.from_str``.

        Returns:
            F1 rounded to 4 decimals, or a dict with ``precision``, ``recall``
            and ``f1``. Empty input scores 0.
        """
        if isinstance(predictions, str):
            predictions = [predictions]
            references = [references]
        assert len(predictions) == len(references), "文本数量不一致"
        if len(predictions) == 0:
            return {"precision": 0.0, "recall": 0.0, "f1": 0.0} if return_dict else 0.0

        correct, pred_num, true_num = 0, 0, 0
        for pred, refer in zip(predictions, references):
            pred = CommentUnitsSim.from_str(pred, **kwargs)
            if isinstance(refer, str):
                refer = [refer]
            refer = [CommentUnitsSim.from_str(t, **kwargs) for t in refer]
            # Pick the alternative reference with the highest per-sample F1
            # (the first one on ties).
            correct_list = [pred.compare_same(t) for t in refer]
            f1_list = [2 * hits / (pred.num + t.num) for hits, t in zip(correct_list, refer)]
            best_index = f1_list.index(max(f1_list))
            pred_num += pred.num
            true_num += refer[best_index].num
            correct += correct_list[best_index]

        # The small epsilon keeps the F1 denominator non-zero when nothing matches.
        precision = round(correct / pred_num, 4) + 1e-8
        recall = round(correct / true_num, 4) + 1e-8
        f1 = round(2 * precision * recall / (precision + recall), 4)

        if return_dict:
            return {"precision": precision, "recall": recall, "f1": f1}
        return f1

    @staticmethod
    def quad_f1_of_optimal_match(
            predictions: List[str],
            references: Union[List[str], List[List[str]]],
            quad_weights: tuple = (1, 1, 1, 1),
            one_match=True,
            **kwargs):
        """F1 and average score of the optimal quadruple matching.

        When a sample has several alternative references, the one with the
        lowest matching cost is used.

        Args:
            predictions: predicted quadruple strings (a bare string is treated
                as a single sample).
            references: per prediction, one reference string or a list of
                alternative reference strings.
            quad_weights: weights of target, opinion, aspect and polarity.
            one_match: forwarded to ``CommentUnitsMatch``.
            **kwargs: forwarded to ``CommentUnitsSim.from_str``.

        Returns:
            ``(f1_match, score)`` where ``score`` is 1 minus the average
            per-sample cost. Empty input gives ``(0.0, 0.0)``.
        """
        # Normalise a bare string before comparing lengths; otherwise the
        # check would compare character counts.
        if isinstance(predictions, str):
            predictions = [predictions]
            references = [references]
        assert len(predictions) == len(references)
        if len(predictions) == 0:
            return 0.0, 0.0

        cost = 0
        TP, FP, FN = 0, 0, 0
        matcher = CommentUnitsMatch(*quad_weights, one_match=one_match)

        for pred, refer in zip(predictions, references):
            pred = CommentUnitsSim.from_str(pred, **kwargs)
            if isinstance(refer, str):
                refer = [refer]
            refer = [CommentUnitsSim.from_str(t, **kwargs) for t in refer]
            # Each entry is (cost, TP, FP, FN); keep the cheapest alternative.
            cost_list = [matcher.match_units(pred, t) for t in refer]
            cost_, TP_, FP_, FN_ = cost_list[np.argmin([c[0] for c in cost_list])]
            cost += cost_
            TP += TP_
            FP += FP_
            FN += FN_

        cost = cost / len(predictions)

        # Guard every ratio: with no (partial) match at all the denominators are 0.
        precision_match = TP / (TP + FP) if (TP + FP) else 0.0
        recall_match = TP / (TP + FN) if (TP + FN) else 0.0
        if precision_match + recall_match:
            f1_match = 2 * precision_match * recall_match / (precision_match + recall_match)
        else:
            f1_match = 0.0

        return f1_match, 1 - cost