yuyijiong committed on
Commit
120df80
1 Parent(s): 13f64cc

修复多个refer时f1不正常的bug

Browse files
Files changed (1) hide show
  1. quad_match_score.py +24 -23
quad_match_score.py CHANGED
@@ -660,7 +660,7 @@ class QuadMatch(evaluate.Metric):
660
  '''
661
  f1_of_optimal_match, score_of_optimal_match = self.quad_f1_of_optimal_match(predictions, references,
662
  quad_weights, **kwargs)
663
- f1 = self.quad_f1_of_exact_match(y_pred=predictions, y_true=references, **kwargs)
664
 
665
  # 取1-cost为得分
666
  return {'score of optimal match of weight ' + str(quad_weights): score_of_optimal_match,
@@ -668,30 +668,31 @@ class QuadMatch(evaluate.Metric):
668
  'f1 of exact match': f1}
669
 
670
  @staticmethod
671
- def quad_f1_of_exact_match(y_pred: List[str], y_true: Union[List[str], List[List[str]]],
672
  return_dict=False, **kwargs) -> Union[Dict[str, float], float]:
673
- assert len(y_pred) == len(y_true), "文本数量不一致"
674
  correct, pred_num, true_num = 0, 0, 0
675
 
676
- for pred, true in zip(y_pred, y_true):
677
  pred = CommentUnitsSim.from_str(pred, **kwargs)
678
- # 如果true是list,说明有多个正确答案
679
- if isinstance(true, str):
680
- true = CommentUnitsSim.from_str(true, **kwargs)
681
- else:
682
- true = [CommentUnitsSim.from_str(t, **kwargs) for t in true]
683
-
684
- # 如果true是list,说明有多个正确答案,取最高分
685
- if isinstance(true, list):
686
- correct_list = [pred.compare_same(t) for t in true]
687
- correct += max(correct_list) # 获取得分最高的值
688
- correct_index = correct_list.index(max(correct_list)) # 获取得分最高的索引
689
- pred_num += pred.num
690
- true_num += true[correct_index].num
691
- else:
692
- correct += pred.compare_same(true)
693
- pred_num += pred.num
694
- true_num += true.num
 
695
 
696
  # 以下结果保留4位小数
697
  precision = round(correct / pred_num, 4) + 1e-8
@@ -733,9 +734,9 @@ class QuadMatch(evaluate.Metric):
733
 
734
  # 如果true是多个正确答案,取最高分
735
  cost_list = [matcher.match_units(pred, t) for t in refer]
736
- # 获取得分最高的值的索引,按元组中第一个元素大小排序
737
  # 计算每一对样本的cost,TP,FP,FN
738
- cost_, TP_, FP_, FN_ = cost_list[np.argmax([c[0] for c in cost_list])]
739
  cost += cost_
740
  TP += TP_
741
  FP += FP_
 
660
  '''
661
  f1_of_optimal_match, score_of_optimal_match = self.quad_f1_of_optimal_match(predictions, references,
662
  quad_weights, **kwargs)
663
+ f1 = self.quad_f1_of_exact_match(predictions=predictions, references=references, **kwargs)
664
 
665
  # 取1-cost为得分
666
  return {'score of optimal match of weight ' + str(quad_weights): score_of_optimal_match,
 
668
  'f1 of exact match': f1}
669
 
670
  @staticmethod
671
+ def quad_f1_of_exact_match(predictions: List[str], references: Union[List[str], List[List[str]]],
672
  return_dict=False, **kwargs) -> Union[Dict[str, float], float]:
673
+ assert len(predictions) == len(references), "文本数量不一致"
674
  correct, pred_num, true_num = 0, 0, 0
675
 
676
+ for pred, refer in zip(predictions, references):
677
  pred = CommentUnitsSim.from_str(pred, **kwargs)
678
+ # refer转换为list
679
+ if isinstance(refer, str):
680
+ refer =[refer]
681
+
682
+ # refer转换为CommentUnitsSim
683
+ refer = [CommentUnitsSim.from_str(t, **kwargs) for t in refer]
684
+
685
+ # 如果refer是list,说明有多个正确答案,取最高分的那个
686
+ #计算每个refer的TP的个数
687
+ correct_list = [pred.compare_same(t) for t in refer]
688
+ #计算每个refer的f1
689
+ f1_list=[2 * correct_list[i] / (pred.num + refer[i].num) for i in range(len(refer))]
690
+ # 获取f1得分最高的索引
691
+ best_index = f1_list.index(max(f1_list))
692
+ pred_num += pred.num
693
+ true_num += refer[best_index].num
694
+ correct += correct_list[best_index]
695
+
696
 
697
  # 以下结果保留4位小数
698
  precision = round(correct / pred_num, 4) + 1e-8
 
734
 
735
  # 如果true是多个正确答案,取最高分
736
  cost_list = [matcher.match_units(pred, t) for t in refer]
737
+ # 获取cost最小的值的索引,按元组中第一个元素大小排序
738
  # 计算每一对样本的cost,TP,FP,FN
739
+ cost_, TP_, FP_, FN_ = cost_list[np.argmin([c[0] for c in cost_list])]
740
  cost += cost_
741
  TP += TP_
742
  FP += FP_