yuyijiong commited on
Commit
13f64cc
1 Parent(s): 6a9adf9

Upload 2 files

Browse files
Files changed (2) hide show
  1. README.md +6 -6
  2. quad_match_score.py +261 -237
README.md CHANGED
@@ -47,9 +47,9 @@ references=["food | good | food#taste | pos & service | bad | service#general |
47
  result=module.compute(predictions=predictions, references=references)
48
  print(result)
49
 
50
- result={'ave match score of weight (1, 1, 1, 1)': 0.375,
51
- 'f1 score of exact match': 0.0,
52
- 'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}
53
  ```
54
 
55
  ### Inputs
@@ -78,9 +78,9 @@ result={'ave match score of weight (1, 1, 1, 1)': 0.375,
78
 
79
  *最优匹配 f1值、最优匹配样本平均得分、完全匹配 f1值(传统评估) 组成的dict,f1值均在[0,1]之间*
80
 
81
- *例如: {'ave match score of weight (1, 1, 1, 1)': 0.375,
82
- 'f1 score of exact match': 0.0,
83
- 'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}*
84
 
85
 
86
  ## Limitations and Bias
 
47
  result=module.compute(predictions=predictions, references=references)
48
  print(result)
49
 
50
+ result={'f1 of exact match': 0.6667,
51
+ 'f1 of optimal match of weight (1, 1, 1, 1)': 0.6666666666666666,
52
+ 'score of optimal match of weight (1, 1, 1, 1)': 0.5}
53
  ```
54
 
55
  ### Inputs
 
78
 
79
  *最优匹配 f1值、最优匹配样本平均得分、完全匹配 f1值(传统评估) 组成的dict,f1值均在[0,1]之间*
80
 
81
+ *例如:{'f1 of exact match': 0.6667,
82
+ 'f1 of optimal match of weight (1, 1, 1, 1)': 0.6666666666666666,
83
+ 'score of optimal match of weight (1, 1, 1, 1)': 0.5}*
84
 
85
 
86
  ## Limitations and Bias
quad_match_score.py CHANGED
@@ -15,10 +15,9 @@
15
 
16
  import copy
17
  import re
18
- from typing import List, Dict, Union,Callable
19
  import numpy as np
20
 
21
-
22
  import datasets
23
  import evaluate
24
  from rouge_chinese import Rouge
@@ -27,7 +26,7 @@ from scipy.optimize import linear_sum_assignment
27
  # TODO: Add BibTeX citation
28
  _CITATION = """\
29
  @InProceedings{huggingface:module,
30
- title = {A great new module},
31
  authors={huggingface, Inc.},
32
  year={2020}
33
  }
@@ -39,7 +38,6 @@ evaluate sentiment quadruples.
39
  评估生成模型的情感四元组
40
  """
41
 
42
-
43
  # TODO: Add description of the arguments of the module here
44
  _KWARGS_DESCRIPTION = """
45
  Calculates how good are predictions given some references, using certain scores
@@ -55,53 +53,22 @@ Examples:
55
  Examples should be written in doctest format, and should illustrate how
56
  to use the function.
57
 
58
- >>> my_new_module = evaluate.load("my_new_module")
59
- >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
60
- >>> print(results)
61
- {'accuracy': 1.0}
 
 
 
 
 
62
  """
63
 
64
 
65
- def compute_quadruple_f1(y_pred: List[str], y_true: Union[List[str], List[List[str]]],
66
- return_rp=False, **kwargs) -> Dict[str, float]:
67
- assert len(y_pred) == len(y_true)
68
- correct, pred_num, true_num = 0, 0, 0
69
-
70
- for pred, true in zip(y_pred, y_true):
71
-
72
- pred = CommentUnitsSim.from_str(pred, **kwargs)
73
- # 如果true是list,说明有多个正确答案
74
- if isinstance(true, str):
75
- true = CommentUnitsSim.from_str(true, **kwargs)
76
- else:
77
- true = [CommentUnitsSim.from_str(t,**kwargs) for t in true]
78
-
79
- # 如果true是list,说明有多个正确答案,取最高分
80
- if isinstance(true, list):
81
- correct_list = [pred.compare_same(t) for t in true]
82
- correct += max(correct_list) # 获取得分最高的值
83
- correct_index = correct_list.index(max(correct_list)) # 获取得分最高的索引
84
- pred_num += pred.num
85
- true_num += true[correct_index].num
86
- else:
87
- correct += pred.compare_same(true)
88
- pred_num += pred.num
89
- true_num += true.num
90
-
91
- # 以下结果保留4位小数
92
- precision = round(correct / pred_num, 4) + 1e-8
93
- recall = round(correct / true_num, 4) + 1e-8
94
- f1 = round(2 * precision * recall / (precision + recall), 4)
95
-
96
- if return_rp:
97
- return {"precision": precision, "recall": recall, "f1": f1}
98
- else:
99
- return f1
100
-
101
  # 计算rougel的f1值
102
  def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
103
  assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
104
- #如果text_pred_list[0]为空字符串或空格,则返回0
105
  if not text_pred_list[0].strip():
106
  return 0
107
 
@@ -115,12 +82,13 @@ def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float
115
 
116
  return rouge_l_f1
117
 
 
118
  # 记录四元组的函数
119
  class CommentUnitsSim:
120
- def __init__(self, data: List[Dict[str, str]],data_source:any=None,abnormal=False,language=None):
121
- self.data_source=data_source
122
- self.abnormal=abnormal
123
- data=copy.deepcopy(data)
124
  # 如果字典有target,则改名为target_text
125
  for quad_dict in data:
126
  if 'target' in quad_dict:
@@ -131,73 +99,79 @@ class CommentUnitsSim:
131
  del quad_dict['opinion']
132
 
133
  self.data = data
134
- self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性','pos':'积极','neg':'消极','neu':'中性','积极':'积极','消极':'消极','中性':'中性'}
135
- self.polarity_zh2en={'积极':'pos','消极':'neg','中性':'neu','pos':'pos','neg':'neg','neu':'neu','positive':'pos','negative':'neg','neutral':'neu'}
 
 
136
 
137
- self.language=language if language is not None else 'zh' if self.check_zh() else 'en'
138
- self.none_sign='null'
139
 
140
  @property
141
  def num(self):
142
  return len(self.data)
143
 
144
- #检查四元组中是否有中文
145
  def check_zh(self):
146
  for quad_dict in self.data:
147
- if re.search('[\u4e00-\u9fa5]',quad_dict['target_text']) or re.search('[\u4e00-\u9fa5]',quad_dict['opinion_text']):
 
148
  return True
149
  return False
150
 
151
  # 检测极性是否正确
152
  def check_polarity(self):
153
- #若有某个四元组的极性不是positive、negative、neutral,则返回False
154
  for quad_dict in self.data:
155
- if quad_dict['polarity'] not in ['positive', 'negative', 'neutral','pos','neg','neu','积极','消极','中性']:
156
- self.abnormal=True
 
157
  return False
158
 
159
- #将极性由英文转为中文
160
  def convert_polarity_en2zh(self):
161
  for quad_dict in self.data:
162
- quad_dict['polarity']=self.polarity_en2zh[quad_dict['polarity']]
163
  return self
164
 
165
- #将极性由中文转为英文
166
  def convert_polarity_zh2en(self):
167
  for quad_dict in self.data:
168
- quad_dict['polarity']=self.polarity_zh2en[quad_dict['polarity']]
169
  return self
170
 
171
- #检查是否有重复的四元组,若有则删除重复的
172
  def del_duplicate(self):
173
- new_data=[]
174
  for quad_dict in self.data:
175
  if quad_dict not in new_data:
176
  new_data.append(quad_dict)
177
- self.data=new_data
178
  return self
179
 
180
- #检查是否有target和opinion都为null的四元组,若有则返回True
181
  def check_target_opinion_null(self):
182
  for quad_dict in self.data:
183
- if quad_dict['target_text']=='null' and quad_dict['opinion_text']=='null':
184
  return True
185
  return False
186
 
187
- #检查是否有target或opinion为null的四元组,若有则返回True
188
  def check_any_null(self):
189
  for quad_dict in self.data:
190
- if quad_dict['target_text']=='null' or quad_dict['opinion_text']=='null':
191
  return True
192
  return False
193
 
194
  @classmethod
195
- def from_str(cls, quadruple_str: str, tuple_len:Union[int,list,str]=4, format_code=0, sep_token1=' & ', sep_token2=' | '):
 
196
  data = []
197
- abnormal=False
198
- #确保分隔符后面一定是空格
199
- for i in range(len(quadruple_str)-1):
200
- if (quadruple_str[i] == sep_token1.strip() or quadruple_str[i] == sep_token2.strip()) and quadruple_str[i + 1] != ' ':
 
201
  quadruple_str = quadruple_str[:i + 1] + ' ' + quadruple_str[i + 1:]
202
 
203
  # 选择几元组,即创建列表索引,从四元组中抽出n元
@@ -211,27 +185,27 @@ class CommentUnitsSim:
211
  else:
212
  raise Exception('tuple_len参数错误')
213
 
214
-
215
  for quadruple in quadruple_str.split(sep_token1):
216
  if format_code == 0:
217
  # quadruple可能是target|opinion|aspect|polarity,也可能是target|opinion|aspect,也可能是target|opinion,若没有则为“None”
218
- quadruple_split=[unit.strip() for unit in quadruple.split(sep_token2)]
219
- if len(quadruple_split)>len(tuple_index):
220
  print('quadruple格式错误,过多元素', quadruple_str)
221
- abnormal=True
222
- quadruple_split=quadruple_split[0:len(tuple_index)] #过长则截断
223
- elif len(quadruple_split)<len(tuple_index):
224
  print('quadruple格式错误,过少元素', quadruple_str)
225
- abnormal=True
226
- quadruple_split=["None"]*(len(tuple_index)-len(quadruple_split))+quadruple_split #过短则补'None'
 
227
 
228
- quadruple_keys=[["target_text","opinion_text","aspect","polarity"][i] for i in tuple_index]
229
- quadruple_dict=dict(zip(quadruple_keys,quadruple_split))
230
 
231
  q = {"target_text": 'None', "opinion_text": 'None', "aspect": 'None', "polarity": 'None'}
232
  q.update(quadruple_dict)
233
- #检查极性是否合法
234
- if q['polarity'] not in ['pos','neg','neu','None','积极','消极','中性']:
235
  print('quadruple格式错误,极性格式不对', quadruple_str)
236
 
237
  else:
@@ -239,10 +213,10 @@ class CommentUnitsSim:
239
 
240
  data.append(q)
241
 
242
- return CommentUnitsSim(data,quadruple_str,abnormal)
243
 
244
  @classmethod
245
- def from_list(cls, quadruple_list: List[List[str]],**kwargs):
246
  data = []
247
  for quadruple in quadruple_list:
248
  # #format_code='013'代表list只有四元组的第0、1、3个元素,需要扩充为4元组,空缺位置补上None
@@ -253,10 +227,10 @@ class CommentUnitsSim:
253
  {"target_text": quadruple[0], "opinion_text": quadruple[1], "aspect": quadruple[2],
254
  "polarity": quadruple[3]})
255
 
256
- return CommentUnitsSim(data,quadruple_list,**kwargs)
257
 
258
  @classmethod
259
- def from_list_dict(cls, quadruple_list: List[dict],**kwargs):
260
  for quad_dict in quadruple_list:
261
  if 'target' in quad_dict:
262
  quad_dict['target_text'] = quad_dict['target']
@@ -267,22 +241,24 @@ class CommentUnitsSim:
267
 
268
  data = []
269
  for quadruple in quadruple_list:
270
- #如果quadruple缺少某个key,则补上None
271
- q={"target_text":'None',"opinion_text":'None',"aspect":'None',"polarity":'None'}
272
  q.update(quadruple)
273
  data.append(q)
274
 
275
- return CommentUnitsSim(data,quadruple_list,**kwargs)
276
 
277
- #转化为list,即只保留字典的value
278
  def to_list(self):
279
  data = []
280
  for quad_dict in self.data:
281
- data.append([quad_dict['target_text'],quad_dict['opinion_text'],quad_dict['aspect'],quad_dict['polarity']])
 
282
  return data
283
 
284
  # 将data转换为n元组字符串
285
- def get_quadruple_str(self, format_code=0, tuple_len:Union[int,list,str]=4,sep_token1=' & ',sep_token2=' | '):
 
286
  new_text_list = []
287
  # 选择几元组,即创建列表索引,从四元组中抽出n元
288
  if isinstance(tuple_len, int):
@@ -296,18 +272,18 @@ class CommentUnitsSim:
296
  raise Exception('tuple_len参数错误')
297
 
298
  try:
299
- #若语言为中文,则使用中文极性
300
- if self.language=='zh':
301
  self.convert_polarity_en2zh()
302
  else:
303
  self.convert_polarity_zh2en()
304
  except:
305
- print('语言参数��误',self.data)
306
  print(self.language)
307
  raise Exception('语言参数错误')
308
 
309
- #若tuple_index==[3],则返回综合情感极性
310
- if tuple_index==[3]:
311
  return self.merge_polarity()
312
 
313
  for quad_dict in self.data:
@@ -320,7 +296,6 @@ class CommentUnitsSim:
320
  # 提取polarity
321
  polarity = quad_dict['polarity']
322
 
323
-
324
  # 拼接,‘|’分割
325
  if format_code == 0:
326
  # 根据tuple_len拼接
@@ -330,24 +305,24 @@ class CommentUnitsSim:
330
 
331
  new_text_list.append(new_text)
332
 
333
- #如果tuple_index为[2,3],则需要去除new_text_list中重复的元素,不要改变顺序。因为可能有重复的方面
334
- if tuple_index==[2,3]:
335
  res = []
336
  for t in new_text_list:
337
  if t not in res:
338
  res.append(t)
339
- new_text_list=res
340
 
341
- #如果tuple_index为[3],则只保留new_text_list的第一个元素。因为只有一个情感极性
342
- elif tuple_index==[3]:
343
- new_text_list=new_text_list[:1]
344
 
345
  if format_code == 0:
346
  # 根据tuple_len拼接
347
  return sep_token1.join(new_text_list)
348
 
349
  # 与另一个CommentUnits对象对比,检测有几个相同的四元组
350
- def compare_same(self, other)->int:
351
  count = 0
352
  for quad_dict in self.data:
353
  if quad_dict in other.data:
@@ -403,10 +378,10 @@ class CommentUnitsSim:
403
  polarity_list.append(quad_dict['polarity'])
404
  return polarity_list
405
 
406
- #对所有polarity进行综合
407
  def merge_polarity(self):
408
  polarity_list = self.get_polarity_list()
409
- #判断是英文还是中文
410
  if self.language == 'en':
411
  if 'pos' in polarity_list and 'neg' in polarity_list:
412
  return 'neu'
@@ -426,44 +401,47 @@ class CommentUnitsSim:
426
  else:
427
  return '中性'
428
 
429
- #检测是否有不合法opinion
430
  def check_opinion_in_comment(self, comment_text):
431
  for quad_dict in self.data:
432
- if quad_dict['opinion_text'] !='*' and (not quad_dict['opinion_text'] in comment_text):
433
  return False
434
  return True
435
 
436
- #检测是否有不合法target
437
- def check_target_in_comment(self,comment_text):
438
  for quad_dict in self.data:
439
- if quad_dict['target_text'] !='*' and (not quad_dict['target_text'] in comment_text):
440
  return False
441
  return True
442
 
443
- #计算两个四元组的相似度
444
  @staticmethod
445
  def get_similarity(units1, units2: 'CommentUnitsSim'):
446
  pass
447
 
448
- #对自身数据进行操作
449
- def apply(self,func:Callable,field:str):
450
  for quad_dict in self.data:
451
  quad_dict[field] = func(quad_dict[field])
452
  return self
453
 
454
 
455
- #四元组匹配函数
456
  class CommentUnitsMatch:
457
- def __init__(self,target_weight=0.5,opinion_weight=0.5,aspect_weight=0.5,polarity_weight=0.5):
458
- #归一化权重
459
- weight_sum = target_weight+opinion_weight+aspect_weight+polarity_weight
460
- self.target_weight = target_weight/weight_sum
461
- self.opinion_weight = opinion_weight/weight_sum
462
- self.aspect_weight = aspect_weight/weight_sum
463
- self.polarity_weight = polarity_weight/weight_sum
464
-
465
- #特定feature置零
466
- def set_zero(self,feature:str='polarity'):
 
 
 
467
  if feature == 'polarity':
468
  self.polarity_weight = 0
469
  elif feature == 'aspect':
@@ -476,21 +454,20 @@ class CommentUnitsMatch:
476
  raise Exception('feature参数错误')
477
 
478
  def re_normalize(self):
479
- weight_sum = self.target_weight+self.opinion_weight+self.aspect_weight+self.polarity_weight
480
- self.target_weight = self.target_weight/weight_sum
481
- self.opinion_weight = self.opinion_weight/weight_sum
482
- self.aspect_weight = self.aspect_weight/weight_sum
483
- self.polarity_weight = self.polarity_weight/weight_sum
484
-
485
-
486
- #计算cost矩阵
487
- def get_cost_matrix(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',feature:str='polarity'):
488
  pass
489
- #检查此feature是否存在,不存在则返回全0矩阵
490
- if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None\
491
- or units1.data[0].get(feature)=='None' or units2.data[0].get(feature)=='None':
492
- cost_matrix = np.zeros((len(units1.data),len(units2.data)))
493
- #对应feature的weight也为0
494
  self.set_zero(feature)
495
 
496
  # 并再次归一化
@@ -498,7 +475,7 @@ class CommentUnitsMatch:
498
 
499
  return cost_matrix
500
 
501
- #检查两个四元组的极性是否相同,生成cost矩阵,用于匈牙利算法。不相同则cost为1,相同则cost为0
502
  cost_matrix = []
503
  for quad_dict1 in units1.data:
504
  cost_list = []
@@ -509,23 +486,23 @@ class CommentUnitsMatch:
509
  cost_list.append(1)
510
  cost_matrix.append(cost_list)
511
 
512
- #cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
513
  cost_matrix = np.array(cost_matrix)
514
  return cost_matrix
515
 
516
- #计算cost矩阵,使用rouge指标
517
- def get_cost_matrix_rouge(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',feature:str='target_text'):
518
- #检查此feature是否存在,不存在则返回全0矩阵
519
- if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None\
520
- or units1.data[0].get(feature)=='None' or units2.data[0].get(feature)=='None':
521
- cost_matrix = np.zeros((len(units1.data),len(units2.data)))
522
- #对应feature的weight也为0
523
  self.set_zero(feature)
524
  # 并再次归一化
525
  self.re_normalize()
526
  return cost_matrix
527
 
528
- #检查两个四元组的极性是否相同,生成cost矩阵,用于匈牙利算法。相同则cost为0,不相同则cost为1-rougel
529
  cost_matrix = []
530
  for quad_dict1 in units1.data:
531
  cost_list = []
@@ -533,63 +510,71 @@ class CommentUnitsMatch:
533
  if quad_dict1[feature] == quad_dict2[feature]:
534
  cost_list.append(0)
535
  else:
536
- cost_list.append(1-get_rougel_f1([quad_dict1[feature]],[quad_dict2[feature]]))
537
  cost_matrix.append(cost_list)
538
 
539
- #cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
540
  cost_matrix = np.array(cost_matrix)
541
  return cost_matrix
542
 
543
- def match_units(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',one_match=True)->tuple:
544
- #计算极性的cost矩阵,矩阵元素在0-1之间
545
- cost_matrix_polarity = self.get_cost_matrix(units1, units2,feature='polarity')
546
- #计算aspect的cost矩阵
547
- cost_matrix_aspect = self.get_cost_matrix(units1, units2,feature='aspect')
548
- #计算target的cost矩阵
549
- cost_matrix_target = self.get_cost_matrix_rouge(units1, units2,feature='target_text')
550
- #计算opinion的cost矩阵
551
- cost_matrix_opinion = self.get_cost_matrix_rouge(units1, units2,feature='opinion_text')
552
-
553
- #计算总的cost矩阵,矩阵元素在0-1之间。矩阵的行数为units1即pred的数量,列数为units2即true的数量
554
- cost_matrix = self.target_weight*cost_matrix_target + self.opinion_weight*cost_matrix_opinion + \
555
- self.aspect_weight*cost_matrix_aspect + self.polarity_weight*cost_matrix_polarity
556
- score_matrix = 1-cost_matrix
557
- #使用匈牙利算法进行匹配
558
- if one_match:
 
 
 
 
559
  row_ind, col_ind = linear_sum_assignment(cost_matrix)
 
560
  else:
561
- #允许一对多的匹配
562
- row_ind = np.argmin(cost_matrix, axis=0)
563
- col_ind = np.arange(len(units2.data))
 
564
 
565
- max_units_num=max(units1.num,units2.num)
 
 
566
 
567
- #计算这种匹配的cost
568
- cost = 0
569
  for i in range(len(row_ind)):
570
  cost += cost_matrix[row_ind[i]][col_ind[i]]
571
 
572
- #计算这种匹配下的TP\FP\FN
573
  TP = 0
574
  for i in range(len(row_ind)):
575
  TP += score_matrix[row_ind[i]][col_ind[i]]
576
 
577
- #len(row_ind)为pred的数量,TP为匹配上的数量
578
- FP = units1.num-TP
579
- FN = units2.num-TP
580
-
581
 
582
- #匹配不上的四元组,cost为1
583
- cost += (max_units_num-len(row_ind))
 
 
584
 
585
- cost_per_quadruple=cost/max_units_num
586
- if cost_per_quadruple>1 or cost_per_quadruple <0:
 
 
 
587
 
588
- print('cost错误',cost_per_quadruple,'pred:',units1.data,'true:',units2.data)
589
- print(self.target_weight,self.opinion_weight,self.aspect_weight,self.polarity_weight)
590
-
591
- #返回的cost在0-1之间
592
- return cost_per_quadruple,TP,FP,FN
593
 
594
 
595
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
@@ -632,9 +617,9 @@ class QuadMatch(evaluate.Metric):
632
  pass
633
 
634
  def _compute(self,
635
- predictions:List[str],
636
- references: Union[List[str],List[List[str]]],
637
- quad_weights:tuple=(1,1,1,1),
638
  **kwargs) -> dict:
639
  '''
640
 
@@ -673,55 +658,94 @@ class QuadMatch(evaluate.Metric):
673
  '13':'二元组(观点 | 极性)',
674
  '3':'单元素(极性)'}
675
  '''
 
 
 
676
 
677
- assert len(predictions) == len(references)
678
- if isinstance(predictions,str):
679
- predictions=[predictions]
680
- references=[references]
681
 
682
- cost=0
683
- TP,FP,FN=0,0,0
684
- matcher = CommentUnitsMatch(*quad_weights)
685
-
686
- for pred, true in zip(predictions, references):
687
 
688
- pred = CommentUnitsSim.from_str(pred,**kwargs)
 
689
  # 如果true是list,说明有多个正确答案
690
  if isinstance(true, str):
691
  true = CommentUnitsSim.from_str(true, **kwargs)
692
- elif isinstance(true, list):
693
- true=[CommentUnitsSim.from_str(t, **kwargs) for t in true]
694
  else:
695
- print("true的类型不对",true)
696
- continue
697
 
698
- #如果true是list,说明有多个正确答案,取最高分
699
  if isinstance(true, list):
700
- cost_list=[matcher.match_units(pred,t,one_match=True) for t in true]
701
- # 获取得分最高的值的索引,按元组中第一个元素大小排序
702
- cost_,TP_,FP_,FN_ = cost_list[np.argmax([c[0] for c in cost_list])]
703
- cost += cost_
704
- TP+=TP_
705
- FP+=FP_
706
- FN+=FN_
707
-
708
  else:
709
- cost_,TP_,FP_,FN_ = matcher.match_units(pred,true,one_match=True)
710
- cost += cost_
711
- TP+=TP_
712
- FP+=FP_
713
- FN+=FN_
714
-
715
- #平均cost
716
- cost=cost/len(predictions)
717
- #由TP\FP\FN计算最优匹配F1
718
- precision_match=TP/(TP+FP)
719
- recall_match=TP/(TP+FN)
720
- f1_match=2*precision_match*recall_match/(precision_match+recall_match)
721
-
722
- f1=compute_quadruple_f1(y_pred=predictions,y_true=references, **kwargs)
723
-
724
- #取1-cost为得分
725
- return {'ave match score of weight '+str(quad_weights):1-cost,
726
- 'f1 score of optimal match of weight '+str(quad_weights): f1_match,
727
- 'f1 score of exact match':f1}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  import copy
17
  import re
18
+ from typing import List, Dict, Union, Callable
19
  import numpy as np
20
 
 
21
  import datasets
22
  import evaluate
23
  from rouge_chinese import Rouge
 
26
  # TODO: Add BibTeX citation
27
  _CITATION = """\
28
  @InProceedings{huggingface:module,
29
+ title = {quad match score},
30
  authors={huggingface, Inc.},
31
  year={2020}
32
  }
 
38
  评估生成模型的情感四元组
39
  """
40
 
 
41
  # TODO: Add description of the arguments of the module here
42
  _KWARGS_DESCRIPTION = """
43
  Calculates how good are predictions given some references, using certain scores
 
53
  Examples should be written in doctest format, and should illustrate how
54
  to use the function.
55
 
56
+ >>> import evaluate
57
+ >>> module = evaluate.load("yuyijiong/quad_match_score")
58
+ >>> predictions=["food | good | food#taste | pos"]
59
+ >>> references=["food | good | food#taste | pos & service | bad | service#general | neg"]
60
+ >>> result=module.compute(predictions=predictions, references=references)
61
+ >>> print(result)
62
+ result={'ave match score of weight (1, 1, 1, 1)': 0.375,
63
+ 'f1 score of exact match': 0.0,
64
+ 'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}
65
  """
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # 计算rougel的f1值
69
  def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
70
  assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
71
+ # 如果text_pred_list[0]为空字符串或空格,则返回0
72
  if not text_pred_list[0].strip():
73
  return 0
74
 
 
82
 
83
  return rouge_l_f1
84
 
85
+
86
  # 记录四元组的函数
87
  class CommentUnitsSim:
88
+ def __init__(self, data: List[Dict[str, str]], data_source: any = None, abnormal=False, language=None):
89
+ self.data_source = data_source
90
+ self.abnormal = abnormal
91
+ data = copy.deepcopy(data)
92
  # 如果字典有target,则改名为target_text
93
  for quad_dict in data:
94
  if 'target' in quad_dict:
 
99
  del quad_dict['opinion']
100
 
101
  self.data = data
102
+ self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性', 'pos': '积极', 'neg': '消极',
103
+ 'neu': '中性', '积极': '积极', '消极': '消极', '中性': '中性'}
104
+ self.polarity_zh2en = {'积极': 'pos', '消极': 'neg', '中性': 'neu', 'pos': 'pos', 'neg': 'neg', 'neu': 'neu',
105
+ 'positive': 'pos', 'negative': 'neg', 'neutral': 'neu'}
106
 
107
+ self.language = language if language is not None else 'zh' if self.check_zh() else 'en'
108
+ self.none_sign = 'null'
109
 
110
  @property
111
  def num(self):
112
  return len(self.data)
113
 
114
+ # 检查四元组中是否有中文
115
  def check_zh(self):
116
  for quad_dict in self.data:
117
+ if re.search('[\u4e00-\u9fa5]', quad_dict['target_text']) or re.search('[\u4e00-\u9fa5]',
118
+ quad_dict['opinion_text']):
119
  return True
120
  return False
121
 
122
  # 检测极性是否正确
123
  def check_polarity(self):
124
+ # 若有某个四元组的极性不是positive、negative、neutral,则返回False
125
  for quad_dict in self.data:
126
+ if quad_dict['polarity'] not in ['positive', 'negative', 'neutral', 'pos', 'neg', 'neu', '积极', '消极',
127
+ '中性']:
128
+ self.abnormal = True
129
  return False
130
 
131
+ # 将极性由英文转为中文
132
  def convert_polarity_en2zh(self):
133
  for quad_dict in self.data:
134
+ quad_dict['polarity'] = self.polarity_en2zh[quad_dict['polarity']]
135
  return self
136
 
137
+ # 将极性由中文转为英文
138
  def convert_polarity_zh2en(self):
139
  for quad_dict in self.data:
140
+ quad_dict['polarity'] = self.polarity_zh2en[quad_dict['polarity']]
141
  return self
142
 
143
+ # 检查是否有重复的四元组,若有则删除重复的
144
  def del_duplicate(self):
145
+ new_data = []
146
  for quad_dict in self.data:
147
  if quad_dict not in new_data:
148
  new_data.append(quad_dict)
149
+ self.data = new_data
150
  return self
151
 
152
+ # 检查是否有target和opinion都为null的四元组,若有则返回True
153
  def check_target_opinion_null(self):
154
  for quad_dict in self.data:
155
+ if quad_dict['target_text'] == 'null' and quad_dict['opinion_text'] == 'null':
156
  return True
157
  return False
158
 
159
+ # 检查是否有target或opinion为null的四元组,若有则返回True
160
  def check_any_null(self):
161
  for quad_dict in self.data:
162
+ if quad_dict['target_text'] == 'null' or quad_dict['opinion_text'] == 'null':
163
  return True
164
  return False
165
 
166
  @classmethod
167
+ def from_str(cls, quadruple_str: str, tuple_len: Union[int, list, str] = 4, format_code=0, sep_token1=' & ',
168
+ sep_token2=' | '):
169
  data = []
170
+ abnormal = False
171
+ # 确保分隔符后面一定是空格
172
+ for i in range(len(quadruple_str) - 1):
173
+ if (quadruple_str[i] == sep_token1.strip() or quadruple_str[i] == sep_token2.strip()) and quadruple_str[
174
+ i + 1] != ' ':
175
  quadruple_str = quadruple_str[:i + 1] + ' ' + quadruple_str[i + 1:]
176
 
177
  # 选择几元组,即创建列表索引,从四元组中抽出n元
 
185
  else:
186
  raise Exception('tuple_len参数错误')
187
 
 
188
  for quadruple in quadruple_str.split(sep_token1):
189
  if format_code == 0:
190
  # quadruple可能是target|opinion|aspect|polarity,也可能是target|opinion|aspect,也可能是target|opinion,若没有则为“None”
191
+ quadruple_split = [unit.strip() for unit in quadruple.split(sep_token2)]
192
+ if len(quadruple_split) > len(tuple_index):
193
  print('quadruple格式错误,过多元素', quadruple_str)
194
+ abnormal = True
195
+ quadruple_split = quadruple_split[0:len(tuple_index)] # 过长则截断
196
+ elif len(quadruple_split) < len(tuple_index):
197
  print('quadruple格式错误,过少元素', quadruple_str)
198
+ abnormal = True
199
+ quadruple_split = ["None"] * (
200
+ len(tuple_index) - len(quadruple_split)) + quadruple_split # 过短则补'None'
201
 
202
+ quadruple_keys = [["target_text", "opinion_text", "aspect", "polarity"][i] for i in tuple_index]
203
+ quadruple_dict = dict(zip(quadruple_keys, quadruple_split))
204
 
205
  q = {"target_text": 'None', "opinion_text": 'None', "aspect": 'None', "polarity": 'None'}
206
  q.update(quadruple_dict)
207
+ # 检查极性是否合法
208
+ if q['polarity'] not in ['pos', 'neg', 'neu', 'None', '积极', '消极', '中性']:
209
  print('quadruple格式错误,极性格式不对', quadruple_str)
210
 
211
  else:
 
213
 
214
  data.append(q)
215
 
216
+ return CommentUnitsSim(data, quadruple_str, abnormal)
217
 
218
  @classmethod
219
+ def from_list(cls, quadruple_list: List[List[str]], **kwargs):
220
  data = []
221
  for quadruple in quadruple_list:
222
  # #format_code='013'代表list只有四元组的第0、1、3个元素,需要扩充为4元组,空缺位置补上None
 
227
  {"target_text": quadruple[0], "opinion_text": quadruple[1], "aspect": quadruple[2],
228
  "polarity": quadruple[3]})
229
 
230
+ return CommentUnitsSim(data, quadruple_list, **kwargs)
231
 
232
  @classmethod
233
+ def from_list_dict(cls, quadruple_list: List[dict], **kwargs):
234
  for quad_dict in quadruple_list:
235
  if 'target' in quad_dict:
236
  quad_dict['target_text'] = quad_dict['target']
 
241
 
242
  data = []
243
  for quadruple in quadruple_list:
244
+ # 如果quadruple缺少某个key,则补上None
245
+ q = {"target_text": 'None', "opinion_text": 'None', "aspect": 'None', "polarity": 'None'}
246
  q.update(quadruple)
247
  data.append(q)
248
 
249
+ return CommentUnitsSim(data, quadruple_list, **kwargs)
250
 
251
+ # 转化为list,即只保留字典的value
252
  def to_list(self):
253
  data = []
254
  for quad_dict in self.data:
255
+ data.append(
256
+ [quad_dict['target_text'], quad_dict['opinion_text'], quad_dict['aspect'], quad_dict['polarity']])
257
  return data
258
 
259
  # 将data转换为n元组字符串
260
+ def get_quadruple_str(self, format_code=0, tuple_len: Union[int, list, str] = 4, sep_token1=' & ',
261
+ sep_token2=' | '):
262
  new_text_list = []
263
  # 选择几元组,即创建列表索引,从四元组中抽出n元
264
  if isinstance(tuple_len, int):
 
272
  raise Exception('tuple_len参数错误')
273
 
274
  try:
275
+ # 若语言为中文,则使用中文极性
276
+ if self.language == 'zh':
277
  self.convert_polarity_en2zh()
278
  else:
279
  self.convert_polarity_zh2en()
280
  except:
281
+ print('语言参数错误', self.data)
282
  print(self.language)
283
  raise Exception('语言参数错误')
284
 
285
+ # 若tuple_index==[3],则返回综合情感极性
286
+ if tuple_index == [3]:
287
  return self.merge_polarity()
288
 
289
  for quad_dict in self.data:
 
296
  # 提取polarity
297
  polarity = quad_dict['polarity']
298
 
 
299
  # 拼接,‘|’分割
300
  if format_code == 0:
301
  # 根据tuple_len拼接
 
305
 
306
  new_text_list.append(new_text)
307
 
308
+ # 如果tuple_index为[2,3],则需要去除new_text_list中重复的元素,不要改变顺序。因为可能有重复的方面
309
+ if tuple_index == [2, 3]:
310
  res = []
311
  for t in new_text_list:
312
  if t not in res:
313
  res.append(t)
314
+ new_text_list = res
315
 
316
+ # 如果tuple_index为[3],则只保留new_text_list的第一个元素。因为只有一个情感极性
317
+ elif tuple_index == [3]:
318
+ new_text_list = new_text_list[:1]
319
 
320
  if format_code == 0:
321
  # 根据tuple_len拼接
322
  return sep_token1.join(new_text_list)
323
 
324
  # 与另一个CommentUnits对象对比,检测有几个相同的四元组
325
+ def compare_same(self, other) -> int:
326
  count = 0
327
  for quad_dict in self.data:
328
  if quad_dict in other.data:
 
378
  polarity_list.append(quad_dict['polarity'])
379
  return polarity_list
380
 
381
+ # 对所有polarity进行综合
382
  def merge_polarity(self):
383
  polarity_list = self.get_polarity_list()
384
+ # 判断是英文还是中文
385
  if self.language == 'en':
386
  if 'pos' in polarity_list and 'neg' in polarity_list:
387
  return 'neu'
 
401
  else:
402
  return '中性'
403
 
404
+ # 检测是否有不合法opinion
405
  def check_opinion_in_comment(self, comment_text):
406
  for quad_dict in self.data:
407
+ if quad_dict['opinion_text'] != '*' and (not quad_dict['opinion_text'] in comment_text):
408
  return False
409
  return True
410
 
411
+ # 检测是否有不合法target
412
+ def check_target_in_comment(self, comment_text):
413
  for quad_dict in self.data:
414
+ if quad_dict['target_text'] != '*' and (not quad_dict['target_text'] in comment_text):
415
  return False
416
  return True
417
 
418
+ # 计算两个四元组的相似度
419
  @staticmethod
420
  def get_similarity(units1, units2: 'CommentUnitsSim'):
421
  pass
422
 
423
+ # 对自身数据进行操作
424
+ def apply(self, func: Callable, field: str):
425
  for quad_dict in self.data:
426
  quad_dict[field] = func(quad_dict[field])
427
  return self
428
 
429
 
430
+ # 四元组匹配函数
431
  class CommentUnitsMatch:
432
+ def __init__(self, target_weight=0.5, opinion_weight=0.5, aspect_weight=0.5, polarity_weight=0.5, one_match=True):
433
+ # 归一化权重
434
+ weight_sum = target_weight + opinion_weight + aspect_weight + polarity_weight
435
+ self.target_weight = target_weight / weight_sum
436
+ self.opinion_weight = opinion_weight / weight_sum
437
+ self.aspect_weight = aspect_weight / weight_sum
438
+ self.polarity_weight = polarity_weight / weight_sum
439
+
440
+ # 是否一对一匹配
441
+ self.one_match = one_match
442
+
443
+ # 特定feature置零
444
+ def set_zero(self, feature: str = 'polarity'):
445
  if feature == 'polarity':
446
  self.polarity_weight = 0
447
  elif feature == 'aspect':
 
454
  raise Exception('feature参数错误')
455
 
456
  def re_normalize(self):
457
+ weight_sum = self.target_weight + self.opinion_weight + self.aspect_weight + self.polarity_weight
458
+ self.target_weight = self.target_weight / weight_sum
459
+ self.opinion_weight = self.opinion_weight / weight_sum
460
+ self.aspect_weight = self.aspect_weight / weight_sum
461
+ self.polarity_weight = self.polarity_weight / weight_sum
462
+
463
+ # 计算cost矩阵,完全匹配为0,不匹配为1
464
+ def get_cost_matrix(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'polarity'):
 
465
  pass
466
+ # 检查此feature是否存在,不存在则返回全0矩阵
467
+ if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None \
468
+ or units1.data[0].get(feature) == 'None' or units2.data[0].get(feature) == 'None':
469
+ cost_matrix = np.zeros((len(units1.data), len(units2.data)))
470
+ # 对应feature的weight也为0
471
  self.set_zero(feature)
472
 
473
  # 并再次归一化
 
475
 
476
  return cost_matrix
477
 
478
+ # 检查两个四元组的极性是否相同,生成cost矩阵,用于匈牙利算法。不相同则cost为1,相同则cost为0
479
  cost_matrix = []
480
  for quad_dict1 in units1.data:
481
  cost_list = []
 
486
  cost_list.append(1)
487
  cost_matrix.append(cost_list)
488
 
489
+ # cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
490
  cost_matrix = np.array(cost_matrix)
491
  return cost_matrix
492
 
493
+ # 计算cost矩阵,使用rougel指标
494
+ def get_cost_matrix_rouge(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'target_text'):
495
+ # 检查此feature是否存在,不存在则返回全0矩阵
496
+ if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None \
497
+ or units1.data[0].get(feature) == 'None' or units2.data[0].get(feature) == 'None':
498
+ cost_matrix = np.zeros((len(units1.data), len(units2.data)))
499
+ # 对应feature的weight也为0
500
  self.set_zero(feature)
501
  # 并再次归一化
502
  self.re_normalize()
503
  return cost_matrix
504
 
505
+ # 检查两个四元组的极性是否相同,生成cost矩阵,用于匈牙利算法。相同则cost为0,不相同则cost为1-rougel
506
  cost_matrix = []
507
  for quad_dict1 in units1.data:
508
  cost_list = []
 
510
  if quad_dict1[feature] == quad_dict2[feature]:
511
  cost_list.append(0)
512
  else:
513
+ cost_list.append(1 - get_rougel_f1([quad_dict1[feature]], [quad_dict2[feature]]))
514
  cost_matrix.append(cost_list)
515
 
516
+ # cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
517
  cost_matrix = np.array(cost_matrix)
518
  return cost_matrix
519
 
520
+ # 匹配四元组并计算cost
521
+ def match_units(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim') -> tuple:
522
+ # 计算极性的cost矩阵,矩阵元素在0-1之间
523
+ cost_matrix_polarity = self.get_cost_matrix(units1, units2, feature='polarity')
524
+ # 计算aspect的cost矩阵
525
+ cost_matrix_aspect = self.get_cost_matrix(units1, units2, feature='aspect')
526
+ # 计算target的cost矩阵
527
+ cost_matrix_target = self.get_cost_matrix_rouge(units1, units2, feature='target_text')
528
+ # 计算opinion的cost矩阵
529
+ cost_matrix_opinion = self.get_cost_matrix_rouge(units1, units2, feature='opinion_text')
530
+
531
+ # 计算总的cost矩阵,矩阵元素在0-1之间。矩阵的行数为units1即pred的数量,列数为units2即true的数量
532
+ cost_matrix = self.target_weight * cost_matrix_target + self.opinion_weight * cost_matrix_opinion + \
533
+ self.aspect_weight * cost_matrix_aspect + self.polarity_weight * cost_matrix_polarity
534
+ score_matrix = 1 - cost_matrix
535
+
536
+ cost = 0
537
+ # 使用匈牙利算法进行匹配
538
+ if self.one_match:
539
+ # 只允许一对一的匹配,这种情况下row_ind和col_ind的长度一定相等且等于units1和units2的数量中的较小值
540
  row_ind, col_ind = linear_sum_assignment(cost_matrix)
541
+
542
  else:
543
+ # 允许一对多的匹配。这种情况下每个四元组都一定匹配上,这种情况下row_ind和col_ind的长度一定相等且等于units1和units2的数量中的较大值
544
+ if units1.num > units2.num:
545
+ row_ind = np.arange(units1.num)
546
+ col_ind = np.argmin(cost_matrix, axis=1)
547
 
548
+ else:
549
+ row_ind = np.argmin(cost_matrix, axis=0)
550
+ col_ind = np.arange(units2.num)
551
 
552
+ # 计算这种匹配的cost
 
553
  for i in range(len(row_ind)):
554
  cost += cost_matrix[row_ind[i]][col_ind[i]]
555
 
556
+ # 计算这种匹配下的TP\FP\FN
557
  TP = 0
558
  for i in range(len(row_ind)):
559
  TP += score_matrix[row_ind[i]][col_ind[i]]
560
 
561
+ # len(row_ind)为pred的数量,TP为匹配上的数量
562
+ FP = units1.num - TP
563
+ FN = units2.num - TP
 
564
 
565
+ # 如果一对一匹配,会有匹配不上的四元组,这些四元组cost为1
566
+ max_units_num = max(units1.num, units2.num)
567
+ if self.one_match:
568
+ cost += (max_units_num - len(row_ind))
569
 
570
+ # 对cost进行归一化,使其在0-1之间
571
+ cost_per_quadruple = cost / max_units_num
572
+ if cost_per_quadruple > 1 or cost_per_quadruple < 0:
573
+ print('cost错误', cost_per_quadruple, 'pred:', units1.data, 'true:', units2.data)
574
+ print(self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)
575
 
576
+ # 返回的cost在0-1之间
577
+ return cost_per_quadruple, TP, FP, FN
 
 
 
578
 
579
 
580
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 
617
  pass
618
 
619
  def _compute(self,
620
+ predictions: List[str],
621
+ references: Union[List[str], List[List[str]]],
622
+ quad_weights: tuple = (1, 1, 1, 1),
623
  **kwargs) -> dict:
624
  '''
625
 
 
658
  '13':'二元组(观点 | 极性)',
659
  '3':'单元素(极性)'}
660
  '''
661
+ f1_of_optimal_match, score_of_optimal_match = self.quad_f1_of_optimal_match(predictions, references,
662
+ quad_weights, **kwargs)
663
+ f1 = self.quad_f1_of_exact_match(y_pred=predictions, y_true=references, **kwargs)
664
 
665
+ # 取1-cost为得分
666
+ return {'score of optimal match of weight ' + str(quad_weights): score_of_optimal_match,
667
+ 'f1 of optimal match of weight ' + str(quad_weights): f1_of_optimal_match,
668
+ 'f1 of exact match': f1}
669
 
670
+ @staticmethod
671
+ def quad_f1_of_exact_match(y_pred: List[str], y_true: Union[List[str], List[List[str]]],
672
+ return_dict=False, **kwargs) -> Union[Dict[str, float], float]:
673
+ assert len(y_pred) == len(y_true), "文本数量不一致"
674
+ correct, pred_num, true_num = 0, 0, 0
675
 
676
+ for pred, true in zip(y_pred, y_true):
677
+ pred = CommentUnitsSim.from_str(pred, **kwargs)
678
  # 如果true是list,说明有多个正确答案
679
  if isinstance(true, str):
680
  true = CommentUnitsSim.from_str(true, **kwargs)
 
 
681
  else:
682
+ true = [CommentUnitsSim.from_str(t, **kwargs) for t in true]
 
683
 
684
+ # 如果true是list,说明有多个正确答案,取最高分
685
  if isinstance(true, list):
686
+ correct_list = [pred.compare_same(t) for t in true]
687
+ correct += max(correct_list) # 获取得分最高的值
688
+ correct_index = correct_list.index(max(correct_list)) # 获取得分最高的索引
689
+ pred_num += pred.num
690
+ true_num += true[correct_index].num
 
 
 
691
  else:
692
+ correct += pred.compare_same(true)
693
+ pred_num += pred.num
694
+ true_num += true.num
695
+
696
+ # 以下结果保留4位小数
697
+ precision = round(correct / pred_num, 4) + 1e-8
698
+ recall = round(correct / true_num, 4) + 1e-8
699
+ f1 = round(2 * precision * recall / (precision + recall), 4)
700
+
701
+ if return_dict:
702
+ return {"precision": precision, "recall": recall, "f1": f1}
703
+ else:
704
+ return f1
705
+
706
+ # 计算最优匹配f1
707
+ @staticmethod
708
+ def quad_f1_of_optimal_match(
709
+ predictions: List[str],
710
+ references: Union[List[str], List[List[str]]],
711
+ quad_weights: tuple = (1, 1, 1, 1),
712
+ one_match=True,
713
+ **kwargs):
714
+
715
+ assert len(predictions) == len(references)
716
+ if isinstance(predictions, str):
717
+ predictions = [predictions]
718
+ references = [references]
719
+
720
+ cost = 0
721
+ TP, FP, FN = 0, 0, 0
722
+ matcher = CommentUnitsMatch(*quad_weights, one_match=one_match)
723
+
724
+ for pred, refer in zip(predictions, references):
725
+
726
+ pred = CommentUnitsSim.from_str(pred, **kwargs)
727
+ # 将refer转换为list形式
728
+ if isinstance(refer, str):
729
+ refer = [refer]
730
+
731
+ # 将refer中的每个元素转换为CommentUnitsSim
732
+ refer = [CommentUnitsSim.from_str(t, **kwargs) for t in refer]
733
+
734
+ # 如果true是多个正确答案,取最高分
735
+ cost_list = [matcher.match_units(pred, t) for t in refer]
736
+ # 获取得分最高的值的索引,按元组中第一个元素大小排序
737
+ # 计算每一对样本的cost,TP,FP,FN
738
+ cost_, TP_, FP_, FN_ = cost_list[np.argmax([c[0] for c in cost_list])]
739
+ cost += cost_
740
+ TP += TP_
741
+ FP += FP_
742
+ FN += FN_
743
+
744
+ # 平均cost
745
+ cost = cost / len(predictions)
746
+ # 由TP\FP\FN计算最优匹配F1
747
+ precision_match = TP / (TP + FP)
748
+ recall_match = TP / (TP + FN)
749
+ f1_match = 2 * precision_match * recall_match / (precision_match + recall_match)
750
+
751
+ return f1_match, 1 - cost