Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- README.md +66 -5
- app.py +10 -0
- quad_match_score.py +727 -0
- requirements.txt +3 -0
- tests.py +9 -0
README.md
CHANGED
@@ -1,12 +1,73 @@
|
|
1 |
---
|
2 |
title: Quad Match Score
|
3 |
-
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: Quad Match Score
|
3 |
+
datasets:
|
4 |
+
- SemEval2016 Task5
|
5 |
+
tags:
|
6 |
+
- evaluate
|
7 |
+
- metric
|
8 |
+
description: "Matching-based score for evaluating generated sentiment quadruples (target | opinion | aspect | polarity) against references."
|
9 |
sdk: gradio
|
10 |
+
sdk_version: 3.19.1
|
11 |
app_file: app.py
|
12 |
pinned: false
|
13 |
---
|
14 |
|
15 |
+
# Metric Card for Quad Match Score
|
16 |
+
|
17 |
+
***Module Card Instructions:*** *评估生成模型的情感四元组抽取结果.*
|
18 |
+
|
19 |
+
## Metric Description
|
20 |
+
*评估生成模型的情感四元组抽取结果.*
|
21 |
+
|
22 |
+
## How to Use
|
23 |
+
```python
|
24 |
+
import evaluate
|
25 |
+
|
26 |
+
module = evaluate.load("yuyijiong/quad_match_score")
|
27 |
+
|
28 |
+
predictions=["food | good | food#taste | pos"]
|
29 |
+
references=["food | good | food#taste | pos & service | bad | service#general | neg"]
|
30 |
+
|
31 |
+
module.compute(predictions=predictions, references=references)
|
32 |
+
|
33 |
+
```
|
34 |
+
|
35 |
+
### Inputs
|
36 |
+
*List all input arguments in the format below*
|
37 |
+
- **predictions** *(List[str]): 模型生成的四元组,列表中每个字符串代表一个样本的生成结果.*
|
38 |
+
- **references** *(Union[List[str], List[List[str]]]):
|
39 |
+
人工标注的四元组,列表中每个字符串代表一个样本的标签.如果列表元素为list,代表多个reference,评估时取最高分*
|
40 |
+
- **quad_weights** *(Tuple[float, float, float, float]): 默认为(1,1,1,1),分别代表四个方面(对象、观点、方面、极性)的评估指标的权重*
|
41 |
+
- **tuple_len** *(str): indicate the format of the quad, see the following mapping
|
42 |
+
指示四元组的格式,默认为'0123'。对应关系如下所示*
|
43 |
+
```
|
44 |
+
{'0123': "四元组(对象 | 观点 | 方面 | 极性)",
|
45 |
+
'01':'二元组(对象 | 观点)',
|
46 |
+
'012':'三元组(对象 | 观点 | 方面)',
|
47 |
+
'013':'三元组(对象 | 观点 | 极性)',
|
48 |
+
'023':'三元组(对象 | 方面 | 极性)',
|
49 |
+
'23':'二元组(方面 | 极性)',
|
50 |
+
'03':'二元组(对象 | 极性)',
|
51 |
+
'13':'二元组(观点 | 极性)',
|
52 |
+
'3':'单元素(极性)'}
|
53 |
+
```
|
54 |
+
- **sep_token1**: the token to separate quads 分割不同四元组的token
|
55 |
+
- **sep_token2**: the token to separate units of one quad 四元组中不同元素之间的分隔token
|
56 |
+
|
57 |
+
### Output Values
|
58 |
+
|
59 |
+
*最优匹配 f1值、最优匹配样本平均得分、完全匹配 f1值 组成的dict,f1值均在[0,1]之间*
|
60 |
+
|
61 |
+
*例如: {'ave match score of weight (1, 1, 1, 1)': 0.5,
|
62 |
+
'f1 score of exact match': 0.6667,
|
63 |
+
'f1 score of optimal match of weight (1, 1, 1, 1)': 0.6667}*
|
64 |
+
|
65 |
+
|
66 |
+
## Limitations and Bias
|
67 |
+
*对比传统评估指标,得分偏高*
|
68 |
+
|
69 |
+
## Citation
|
70 |
+
*论文即将发表*
|
71 |
+
|
72 |
+
## Further References
|
73 |
+
*Add any useful further references.*
|
app.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import evaluate
|
2 |
+
from evaluate.utils import launch_gradio_widget
|
3 |
+
|
4 |
+
# Load the metric module by its identifier and launch the Gradio widget for it.
module = evaluate.load("yuyijiong/quad_match_score")
|
5 |
+
launch_gradio_widget(module)
|
6 |
+
|
7 |
+
# Example of calling the metric directly (kept commented out so the script only launches the widget):
# predictions=["a | b | c | pos"]
|
8 |
+
# references=["a | b | c | pos & e | f | g | neg"]
|
9 |
+
#
|
10 |
+
# module.compute(predictions=predictions, references=references)
|
quad_match_score.py
ADDED
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""TODO: Add a description here."""
|
15 |
+
|
16 |
+
import copy
|
17 |
+
import re
|
18 |
+
from typing import List, Dict, Union,Callable
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
|
22 |
+
import datasets
|
23 |
+
import evaluate
|
24 |
+
from rouge_chinese import Rouge
|
25 |
+
from scipy.optimize import linear_sum_assignment
|
26 |
+
|
27 |
+
# TODO: Add BibTeX citation
|
28 |
+
_CITATION = """\
|
29 |
+
@InProceedings{huggingface:module,
|
30 |
+
title = {A great new module},
|
31 |
+
authors={huggingface, Inc.},
|
32 |
+
year={2020}
|
33 |
+
}
|
34 |
+
"""
|
35 |
+
|
36 |
+
# TODO: Add description of the module here
|
37 |
+
_DESCRIPTION = """\
|
38 |
+
evaluate sentiment quadruples.
|
39 |
+
评估生成模型的情感四元组
|
40 |
+
"""
|
41 |
+
|
42 |
+
|
43 |
+
# TODO: Add description of the arguments of the module here
|
44 |
+
_KWARGS_DESCRIPTION = """
|
45 |
+
Calculates how good are predictions given some references, using certain scores
|
46 |
+
Args:
|
47 |
+
predictions: list of predictions to score. Each predictions
|
48 |
+
should be a string with tokens separated by spaces.
|
49 |
+
references: list of reference for each prediction. Each
|
50 |
+
reference should be a string with tokens separated by spaces.
|
51 |
+
Returns:
|
52 |
+
score: sentiment quadruple match score
|
53 |
+
|
54 |
+
Examples:
|
55 |
+
Examples should be written in doctest format, and should illustrate how
|
56 |
+
to use the function.
|
57 |
+
|
58 |
+
>>> my_new_module = evaluate.load("my_new_module")
|
59 |
+
>>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
|
60 |
+
>>> print(results)
|
61 |
+
{'accuracy': 1.0}
|
62 |
+
"""
|
63 |
+
|
64 |
+
|
65 |
+
def compute_quadruple_f1(y_pred: List[str], y_true: Union[List[str], List[List[str]]],
                         return_rp=False, **kwargs) -> Union[float, Dict[str, float]]:
    """Compute the exact-match F1 between predicted and reference tuples.

    Args:
        y_pred: one tuple string per sample.
        y_true: one entry per sample; either a tuple string, or a list of
            tuple strings when the sample has several acceptable references.
            With several references, the one sharing the most tuples with the
            prediction is used (the first one on ties).
        return_rp: when true, return precision and recall alongside F1.
        **kwargs: forwarded to ``CommentUnitsSim.from_str`` (e.g. ``tuple_len``,
            ``sep_token1``, ``sep_token2``).

    Returns:
        The F1 value rounded to 4 decimals, or a dict with the keys
        ``precision``, ``recall`` and ``f1`` when ``return_rp`` is true.
    """
    assert len(y_pred) == len(y_true)
    correct, pred_num, true_num = 0, 0, 0

    for pred, true in zip(y_pred, y_true):
        pred_units = CommentUnitsSim.from_str(pred, **kwargs)

        # A plain string is a single reference; anything else is treated as a
        # collection of alternative references.
        if isinstance(true, str):
            candidates = [CommentUnitsSim.from_str(true, **kwargs)]
        else:
            candidates = [CommentUnitsSim.from_str(t, **kwargs) for t in true]

        # Keep the reference with the highest number of identical tuples.
        hits = [pred_units.compare_same(candidate) for candidate in candidates]
        best_hits = max(hits)
        best_reference = candidates[hits.index(best_hits)]

        correct += best_hits
        pred_num += pred_units.num
        true_num += best_reference.num

    # Results are rounded to 4 decimals; empty denominators score 0 instead of
    # relying on an epsilon that would leak into the returned values.
    precision = round(correct / pred_num, 4) if pred_num else 0.0
    recall = round(correct / true_num, 4) if true_num else 0.0
    if precision + recall > 0:
        f1 = round(2 * precision * recall / (precision + recall), 4)
    else:
        f1 = 0.0

    if return_rp:
        return {"precision": precision, "recall": recall, "f1": f1}
    return f1
|
100 |
+
|
101 |
+
# 计算rougel的f1值
|
102 |
+
def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
    """Return the average ROUGE-L F1 between predicted and reference texts.

    Args:
        text_pred_list: predicted texts.
        text_true_list: reference texts, aligned with ``text_pred_list``.

    Returns:
        The averaged ROUGE-L F1, or 0.0 when the lists are empty, when the
        first prediction is blank, or when the ROUGE scorer rejects the input.
    """
    assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
    # Nothing to score.
    if not text_pred_list:
        return 0.0
    # A blank first prediction scores 0.
    if not text_pred_list[0].strip():
        return 0.0

    rouge = Rouge()
    # If the first prediction contains Chinese characters, split every text
    # into single characters so that each character is scored as one token.
    if re.search(u"[\u4e00-\u9fa5]+", text_pred_list[0]):
        text_pred_list = [' '.join(list(text_pred)) for text_pred in text_pred_list]
        text_true_list = [' '.join(list(text_true)) for text_true in text_true_list]

    try:
        scores = rouge.get_scores(text_pred_list, text_true_list, avg=True)
    except ValueError:
        # NOTE(review): the scorer is expected to raise ValueError for a
        # hypothesis/reference without any scorable sentence (e.g. blank or
        # punctuation only) — treat that as "no overlap" rather than crash.
        return 0.0

    return scores['rouge-l']['f']
|
117 |
+
|
118 |
+
# 记录四元组的函数
|
119 |
+
class CommentUnitsSim:
    """Container for the sentiment tuples extracted from one comment.

    Every tuple is stored as a dict with the keys ``target_text``,
    ``opinion_text``, ``aspect`` and ``polarity``. The string ``'None'`` marks
    an element that is not part of the chosen tuple format.
    """

    # Order of the elements inside a full quadruple.
    _QUAD_KEYS = ("target_text", "opinion_text", "aspect", "polarity")
    # Placeholder for an element that is absent from the tuple format.
    _MISSING = 'None'
    # Polarity labels accepted in any of the supported spellings.
    _VALID_POLARITIES = ('positive', 'negative', 'neutral', 'pos', 'neg', 'neu', '积极', '消极', '中性')
    _ZH_PATTERN = re.compile('[\u4e00-\u9fa5]')

    def __init__(self, data: List[Dict[str, str]], data_source: object = None, abnormal=False, language=None):
        """Store a private copy of ``data``.

        Args:
            data: list of tuple dicts; the keys ``target`` / ``opinion`` are
                renamed to ``target_text`` / ``opinion_text``.
            data_source: the raw object the tuples were parsed from.
            abnormal: whether the source was malformed.
            language: ``'zh'`` or ``'en'``; detected from the target and
                opinion texts when omitted.
        """
        self.data_source = data_source
        self.abnormal = abnormal
        # Work on a copy so the caller's dicts are never modified.
        data = copy.deepcopy(data)
        for quad_dict in data:
            if 'target' in quad_dict:
                quad_dict['target_text'] = quad_dict.pop('target')
            if 'opinion' in quad_dict:
                quad_dict['opinion_text'] = quad_dict.pop('opinion')

        self.data = data
        self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性',
                               'pos': '积极', 'neg': '消极', 'neu': '中性',
                               '积极': '积极', '消极': '消极', '中性': '中性'}
        self.polarity_zh2en = {'积极': 'pos', '消极': 'neg', '中性': 'neu',
                               'pos': 'pos', 'neg': 'neg', 'neu': 'neu',
                               'positive': 'pos', 'negative': 'neg', 'neutral': 'neu'}

        self.language = language if language is not None else 'zh' if self.check_zh() else 'en'
        self.none_sign = 'null'

    @property
    def num(self):
        """Number of stored tuples."""
        return len(self.data)

    def check_zh(self):
        """Return True if any target or opinion text contains a Chinese character."""
        for quad_dict in self.data:
            for key in ('target_text', 'opinion_text'):
                if self._ZH_PATTERN.search(str(quad_dict.get(key, ''))):
                    return True
        return False

    def check_polarity(self):
        """Return True if every polarity is a known label.

        On the first unknown label the instance is flagged as abnormal and
        False is returned.
        """
        for quad_dict in self.data:
            if quad_dict['polarity'] not in self._VALID_POLARITIES:
                self.abnormal = True
                return False
        return True

    def _convert_polarity(self, mapping):
        """Rewrite every polarity through ``mapping``; the 'None' placeholder is kept."""
        for quad_dict in self.data:
            polarity = quad_dict['polarity']
            # Tuple formats without a polarity carry the placeholder; there is
            # nothing to translate for them.
            if polarity == self._MISSING:
                continue
            quad_dict['polarity'] = mapping[polarity]
        return self

    def convert_polarity_en2zh(self):
        """Convert polarities to their Chinese labels. Raises KeyError on an unknown label."""
        return self._convert_polarity(self.polarity_en2zh)

    def convert_polarity_zh2en(self):
        """Convert polarities to the short English labels. Raises KeyError on an unknown label."""
        return self._convert_polarity(self.polarity_zh2en)

    def del_duplicate(self):
        """Drop repeated tuples, keeping the first occurrence of each."""
        new_data = []
        for quad_dict in self.data:
            if quad_dict not in new_data:
                new_data.append(quad_dict)
        self.data = new_data
        return self

    def check_target_opinion_null(self):
        """Return True if some tuple has both target and opinion equal to the null sign."""
        return any(quad_dict['target_text'] == self.none_sign and quad_dict['opinion_text'] == self.none_sign
                   for quad_dict in self.data)

    def check_any_null(self):
        """Return True if some tuple has its target or its opinion equal to the null sign."""
        return any(quad_dict['target_text'] == self.none_sign or quad_dict['opinion_text'] == self.none_sign
                   for quad_dict in self.data)

    @staticmethod
    def _resolve_tuple_index(tuple_len: Union[int, list, str]) -> list:
        """Turn ``tuple_len`` into the list of quadruple positions it selects.

        An int ``n`` selects the first ``n`` positions, a list is used as is,
        and a string such as '013' becomes ``[0, 1, 3]``.
        """
        if isinstance(tuple_len, int):
            return list(range(tuple_len))
        if isinstance(tuple_len, list):
            return tuple_len
        if isinstance(tuple_len, str):
            return [int(i) for i in tuple_len]
        raise Exception('tuple_len参数错误')

    @staticmethod
    def _space_after_separators(text: str, sep_token1: str, sep_token2: str) -> str:
        """Insert a space after every single-character separator that lacks one.

        The whole string is scanned; a separator that is the last character is
        left untouched.
        """
        separators = {sep for sep in (sep_token1.strip(), sep_token2.strip()) if len(sep) == 1}
        pieces = []
        for i, char in enumerate(text):
            pieces.append(char)
            if char in separators and i + 1 < len(text) and text[i + 1] != ' ':
                pieces.append(' ')
        return ''.join(pieces)

    @classmethod
    def from_str(cls, quadruple_str: str, tuple_len: Union[int, list, str] = 4, format_code=0, sep_token1=' & ', sep_token2=' | '):
        """Parse a tuple string such as ``'a | b | c | pos & d | e | f | neg'``.

        Args:
            quadruple_str: the string to parse.
            tuple_len: which quadruple positions each tuple contains (see
                ``_resolve_tuple_index``).
            format_code: only 0 is supported.
            sep_token1: separator between tuples.
            sep_token2: separator between the elements of one tuple.

        Returns:
            A new instance. It is flagged ``abnormal`` when a tuple had too
            many elements (truncated), too few elements (padded in front with
            'None'), or an unknown polarity.
        """
        # Make sure every separator is followed by a space.
        quadruple_str = cls._space_after_separators(quadruple_str, sep_token1, sep_token2)

        tuple_index = cls._resolve_tuple_index(tuple_len)
        if format_code != 0:
            raise Exception('answer_format参数错误')

        quadruple_keys = [cls._QUAD_KEYS[i] for i in tuple_index]
        valid_polarities = cls._VALID_POLARITIES + (cls._MISSING,)
        data = []
        abnormal = False

        for quadruple in quadruple_str.split(sep_token1):
            quadruple_split = [unit.strip() for unit in quadruple.split(sep_token2)]
            if len(quadruple_split) > len(quadruple_keys):
                print('quadruple格式错误,过多元素', quadruple_str)
                abnormal = True
                # Too long: truncate.
                quadruple_split = quadruple_split[0:len(quadruple_keys)]
            elif len(quadruple_split) < len(quadruple_keys):
                print('quadruple格式错误,过少元素', quadruple_str)
                abnormal = True
                # Too short: pad in front with 'None'.
                quadruple_split = [cls._MISSING] * (len(quadruple_keys) - len(quadruple_split)) + quadruple_split

            # Elements outside the chosen format keep the 'None' placeholder.
            q = dict.fromkeys(cls._QUAD_KEYS, cls._MISSING)
            q.update(zip(quadruple_keys, quadruple_split))

            # Check that the polarity is a known label.
            if q['polarity'] not in valid_polarities:
                print('quadruple格式错误,极性格式不对', quadruple_str)
                abnormal = True

            data.append(q)

        return cls(data, quadruple_str, abnormal)

    @classmethod
    def from_list(cls, quadruple_list: List[List[str]], **kwargs):
        """Build an instance from ``[target, opinion, aspect, polarity]`` lists."""
        data = [
            {"target_text": quadruple[0], "opinion_text": quadruple[1], "aspect": quadruple[2],
             "polarity": quadruple[3]}
            for quadruple in quadruple_list
        ]
        return cls(data, quadruple_list, **kwargs)

    @classmethod
    def from_list_dict(cls, quadruple_list: List[dict], **kwargs):
        """Build an instance from tuple dicts; missing keys are filled with 'None'.

        The dicts in ``quadruple_list`` are left unmodified.
        """
        data = []
        for quadruple in quadruple_list:
            q = dict.fromkeys(cls._QUAD_KEYS, cls._MISSING)
            for key, value in quadruple.items():
                # Accept the short key names as aliases.
                if key == 'target':
                    key = 'target_text'
                elif key == 'opinion':
                    key = 'opinion_text'
                q[key] = value
            data.append(q)

        return cls(data, quadruple_list, **kwargs)

    def to_list(self):
        """Return the tuples as ``[target, opinion, aspect, polarity]`` lists."""
        return [[quad_dict[key] for key in self._QUAD_KEYS] for quad_dict in self.data]

    def get_quadruple_str(self, format_code=0, tuple_len: Union[int, list, str] = 4, sep_token1=' & ', sep_token2=' | '):
        """Serialise the tuples back into a string.

        The stored polarities are first converted (in place) to the labels of
        ``self.language``. When only the polarity is selected
        (``tuple_len`` resolves to ``[3]``) the merged overall polarity is
        returned; for ``[2, 3]`` repeated (aspect, polarity) pairs are emitted
        once.
        """
        tuple_index = self._resolve_tuple_index(tuple_len)

        try:
            # Use the polarity labels of the instance's language.
            if self.language == 'zh':
                self.convert_polarity_en2zh()
            else:
                self.convert_polarity_zh2en()
        except (KeyError, TypeError) as err:
            print('语言参数错误', self.data)
            print(self.language)
            raise Exception('语言参数错误') from err

        # Only the polarity is requested: return the overall polarity.
        if tuple_index == [3]:
            return self.merge_polarity()

        if format_code != 0:
            raise Exception('answer_format参数错误')

        selected_keys = [self._QUAD_KEYS[i] for i in tuple_index]
        new_text_list = [sep_token2.join(quad_dict[key] for key in selected_keys) for quad_dict in self.data]

        # Several tuples may share the same aspect; keep each
        # (aspect, polarity) pair once, in first-seen order.
        if tuple_index == [2, 3]:
            new_text_list = list(dict.fromkeys(new_text_list))

        return sep_token1.join(new_text_list)

    def compare_same(self, other) -> int:
        """Count the tuples of this instance that also appear in ``other``."""
        return sum(1 for quad_dict in self.data if quad_dict in other.data)

    def check_target_repeat(self):
        """Return True if two tuples share the same target."""
        target_list = self.get_target_list()
        return len(target_list) != len(set(target_list))

    def check_opinion_repeat(self):
        """Return True if two tuples share the same opinion."""
        opinion_list = self.get_opinion_list()
        return len(opinion_list) != len(set(opinion_list))

    def check_aspect_repeat(self):
        """Return True if two tuples share the same aspect."""
        aspect_list = self.get_aspect_list()
        return len(aspect_list) != len(set(aspect_list))

    def get_aspect_list(self):
        """Return the aspect of every tuple."""
        return [quad_dict['aspect'] for quad_dict in self.data]

    def get_target_list(self):
        """Return the target of every tuple."""
        return [quad_dict['target_text'] for quad_dict in self.data]

    def get_opinion_list(self):
        """Return the opinion of every tuple."""
        return [quad_dict['opinion_text'] for quad_dict in self.data]

    def get_polarity_list(self):
        """Return the polarity of every tuple."""
        return [quad_dict['polarity'] for quad_dict in self.data]

    def merge_polarity(self):
        """Merge all polarities into one overall label.

        Mixed positive and negative (or neither) yields neutral; otherwise the
        single polarity present is returned. The label language follows
        ``self.language``.
        """
        polarity_list = self.get_polarity_list()
        if self.language == 'en':
            positive, negative, neutral = 'pos', 'neg', 'neu'
        else:
            positive, negative, neutral = '积极', '消极', '中性'

        has_positive = positive in polarity_list
        has_negative = negative in polarity_list
        if has_positive and has_negative:
            return neutral
        if has_positive:
            return positive
        if has_negative:
            return negative
        return neutral

    def check_opinion_in_comment(self, comment_text):
        """Return False if some opinion (other than '*') does not occur in ``comment_text``."""
        for quad_dict in self.data:
            if quad_dict['opinion_text'] != '*' and quad_dict['opinion_text'] not in comment_text:
                return False
        return True

    def check_target_in_comment(self, comment_text):
        """Return False if some target (other than '*') does not occur in ``comment_text``."""
        for quad_dict in self.data:
            if quad_dict['target_text'] != '*' and quad_dict['target_text'] not in comment_text:
                return False
        return True

    @staticmethod
    def get_similarity(units1, units2: 'CommentUnitsSim'):
        """Placeholder; not implemented and returns None."""
        pass

    def apply(self, func: Callable, field: str):
        """Replace ``field`` of every tuple with ``func(value)`` and return self."""
        for quad_dict in self.data:
            quad_dict[field] = func(quad_dict[field])
        return self
|
453 |
+
|
454 |
+
|
455 |
+
#四元组匹配函数
|
456 |
+
class CommentUnitsMatch:
    """Optimal matching between predicted and reference sentiment tuples.

    The cost of pairing two tuples is a weighted sum of per-element costs:
    exact-match costs for aspect and polarity, and ``1 - ROUGE-L F1`` for the
    target and opinion texts.
    """

    _FEATURES = ('target_text', 'opinion_text', 'aspect', 'polarity')

    def __init__(self, target_weight=0.5, opinion_weight=0.5, aspect_weight=0.5, polarity_weight=0.5):
        """Store the four element weights, normalised so that they sum to 1."""
        weight_sum = target_weight + opinion_weight + aspect_weight + polarity_weight
        self.target_weight = target_weight / weight_sum
        self.opinion_weight = opinion_weight / weight_sum
        self.aspect_weight = aspect_weight / weight_sum
        self.polarity_weight = polarity_weight / weight_sum

    def set_zero(self, feature: str = 'polarity'):
        """Set the weight of ``feature`` to zero (weights are not renormalised)."""
        if feature == 'polarity':
            self.polarity_weight = 0
        elif feature == 'aspect':
            self.aspect_weight = 0
        elif 'opinion' in feature:
            self.opinion_weight = 0
        elif 'target' in feature:
            self.target_weight = 0
        else:
            raise Exception('feature参数错误')

    def re_normalize(self):
        """Rescale the weights so that they sum to 1 again."""
        weight_sum = self.target_weight + self.opinion_weight + self.aspect_weight + self.polarity_weight
        self.target_weight = self.target_weight / weight_sum
        self.opinion_weight = self.opinion_weight / weight_sum
        self.aspect_weight = self.aspect_weight / weight_sum
        self.polarity_weight = self.polarity_weight / weight_sum

    @staticmethod
    def _feature_missing(units1, units2, feature: str) -> bool:
        """Return True if ``feature`` cannot be compared for this pair.

        A feature counts as missing when the first tuple of either side lacks
        it or carries the 'None' placeholder, or when either side has no
        tuples at all.
        """
        if not units1.data or not units2.data:
            return True
        first1 = units1.data[0].get(feature)
        first2 = units2.data[0].get(feature)
        return first1 is None or first2 is None or first1 == 'None' or first2 == 'None'

    def get_cost_matrix(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'polarity'):
        """Return the exact-match cost matrix for ``feature``.

        Entry ``[i][j]`` is 0 when tuple ``i`` of ``units1`` and tuple ``j`` of
        ``units2`` agree on ``feature`` and 1 otherwise. A missing feature
        yields an all-zero matrix. The matcher's weights are not modified.
        """
        if self._feature_missing(units1, units2, feature):
            return np.zeros((len(units1.data), len(units2.data)))

        return np.array([
            [0 if quad_dict1[feature] == quad_dict2[feature] else 1 for quad_dict2 in units2.data]
            for quad_dict1 in units1.data
        ])

    def get_cost_matrix_rouge(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'target_text'):
        """Return the ROUGE-based cost matrix for a text ``feature``.

        Entry ``[i][j]`` is 0 for identical texts and ``1 - ROUGE-L F1``
        otherwise. A missing feature yields an all-zero matrix. The matcher's
        weights are not modified.
        """
        if self._feature_missing(units1, units2, feature):
            return np.zeros((len(units1.data), len(units2.data)))

        return np.array([
            [
                0 if quad_dict1[feature] == quad_dict2[feature]
                else 1 - get_rougel_f1([quad_dict1[feature]], [quad_dict2[feature]])
                for quad_dict2 in units2.data
            ]
            for quad_dict1 in units1.data
        ])

    def _effective_weights(self, units1, units2):
        """Return the per-feature weights to use for this pair only.

        Missing features get weight 0 and the remaining weights are rescaled
        to sum to 1. Returns None when every weighted feature is missing.
        """
        weights = {
            'target_text': self.target_weight,
            'opinion_text': self.opinion_weight,
            'aspect': self.aspect_weight,
            'polarity': self.polarity_weight,
        }
        missing = [feature for feature in self._FEATURES if self._feature_missing(units1, units2, feature)]
        if not missing:
            return weights

        for feature in missing:
            weights[feature] = 0
        remaining = sum(weights.values())
        if remaining <= 0:
            return None
        return {feature: weight / remaining for feature, weight in weights.items()}

    def match_units(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', one_match=True) -> tuple:
        """Match the tuples of ``units1`` (prediction) to ``units2`` (reference).

        Args:
            units1: predicted tuples (rows of the cost matrix).
            units2: reference tuples (columns of the cost matrix).
            one_match: when true, use a one-to-one assignment that minimises
                the total cost; otherwise every reference tuple is paired with
                its cheapest prediction, so a prediction may be reused.

        Returns:
            ``(cost_per_quadruple, TP, FP, FN)``. The cost is averaged over
            the larger of the two tuple counts, with every unmatched tuple
            costing 1. ``TP`` is the summed score (1 - cost) of the matched
            pairs, ``FP`` and ``FN`` are the tuple counts of each side minus
            ``TP``.
        """
        shape = (len(units1.data), len(units2.data))
        max_units_num = max(units1.num, units2.num)
        # Nothing on either side: a trivially perfect, empty match.
        if max_units_num == 0:
            return 0.0, 0, 0, 0

        # Weights are derived per call so that a feature missing in one pair
        # never changes how later pairs are scored.
        weights = self._effective_weights(units1, units2)
        if weights is None:
            # No comparable feature at all: treat every pairing as a mismatch.
            cost_matrix = np.ones(shape)
        else:
            cost_matrix = (
                weights['target_text'] * self.get_cost_matrix_rouge(units1, units2, feature='target_text')
                + weights['opinion_text'] * self.get_cost_matrix_rouge(units1, units2, feature='opinion_text')
                + weights['aspect'] * self.get_cost_matrix(units1, units2, feature='aspect')
                + weights['polarity'] * self.get_cost_matrix(units1, units2, feature='polarity')
            )
        score_matrix = 1 - cost_matrix

        if one_match:
            # One-to-one assignment with minimal total cost.
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
        else:
            # Allow one prediction to serve several references.
            row_ind = np.argmin(cost_matrix, axis=0)
            col_ind = np.arange(len(units2.data))

        # Cost and score of the chosen pairs.
        cost = cost_matrix[row_ind, col_ind].sum()
        TP = score_matrix[row_ind, col_ind].sum()
        FP = units1.num - TP
        FN = units2.num - TP

        # Every tuple left without a partner costs 1.
        cost += (max_units_num - len(row_ind))

        cost_per_quadruple = cost / max_units_num
        if cost_per_quadruple > 1 or cost_per_quadruple < 0:
            print('cost错误', cost_per_quadruple, 'pred:', units1.data, 'true:', units2.data)
            print(self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)

        return cost_per_quadruple, TP, FP, FN
|
593 |
+
|
594 |
+
|
595 |
+
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class QuadMatch(evaluate.Metric):
    """Matching-based metric for sentiment tuples (target | opinion | aspect | polarity)."""

    def _info(self):
        """Describe the metric and the accepted input formats."""
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Two accepted formats: several references per prediction, or a
            # single reference string per prediction.
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence")),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            # Homepage of the module for documentation
            homepage="http://module.homepage",
            # Additional links to the codebase or references
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"]
        )

    def _download_and_prepare(self, dl_manager):
        """Optional: download external resources useful to compute the scores"""
        pass

    def _compute(self,
                 predictions: List[str],
                 references: Union[List[str], List[List[str]]],
                 quad_weights: tuple = (1, 1, 1, 1),
                 **kwargs) -> dict:
        '''
        Score predicted sentiment tuples against references.

        :param predictions: list of predictions of sentiment quads
        :param references: list of references of sentiment quads; an entry may
            be a list of strings when a sample has several acceptable
            references, in which case the best-matching one is used
        :param quad_weights: weight of target,opinion,aspect,polarity for cost compute

        :param kwargs:
        :param tuple_len: indicate the format of the quad, see the following mapping
        :param sep_token1: the token to separate quads
        :param sep_token2: the token to separate units of one quad

        :return: dict with the average matching score, the F1 of the optimal
            match and the F1 of exact match

        #mapping
        id2prompt={'0123':"quadruples (target | opinion | aspect | polarity)",
                   '':"quadruples (target | opinion | aspect | polarity)",
                   '01':'pairs (target | opinion)',
                   '012':'triples (target | opinion | aspect)',
                   '013':'triples (target | opinion | polarity)',
                   '023':'triples (target | aspect | polarity)',
                   '23':'pairs (aspect | polarity)',
                   '03':'pairs (target | polarity)',
                   '13':'pairs (opinion | polarity)',
                   '3':'single (polarity)'}
        '''
        # A single sample may be passed as a bare string; wrap it before the
        # length check, which would otherwise compare character counts.
        if isinstance(predictions, str):
            predictions = [predictions]
            references = [references]
        assert len(predictions) == len(references), "predictions and references differ in length"

        cost = 0
        TP, FP, FN = 0, 0, 0
        scored = 0

        for pred, true in zip(predictions, references):
            pred_units = CommentUnitsSim.from_str(pred, **kwargs)
            # A list means several acceptable references for this sample.
            if isinstance(true, str):
                candidates = [CommentUnitsSim.from_str(true, **kwargs)]
            elif isinstance(true, list):
                candidates = [CommentUnitsSim.from_str(t, **kwargs) for t in true]
            else:
                print("true的类型不对", true)
                continue

            # A fresh matcher per comparison keeps the weights used for one
            # pair from influencing any other pair.
            results = [
                CommentUnitsMatch(*quad_weights).match_units(pred_units, candidate, one_match=True)
                for candidate in candidates
            ]
            # The best reference is the one with the LOWEST cost (first on ties).
            cost_, TP_, FP_, FN_ = min(results, key=lambda result: result[0])
            cost += cost_
            TP += TP_
            FP += FP_
            FN += FN_
            scored += 1

        # Average cost over the samples that were actually scored.
        cost = cost / scored if scored else 1.0

        # F1 of the optimal match from TP/FP/FN, guarding empty denominators.
        precision_match = TP / (TP + FP) if (TP + FP) > 0 else 0.0
        recall_match = TP / (TP + FN) if (TP + FN) > 0 else 0.0
        if precision_match + recall_match > 0:
            f1_match = 2 * precision_match * recall_match / (precision_match + recall_match)
        else:
            f1_match = 0.0

        f1 = compute_quadruple_f1(y_pred=predictions, y_true=references, **kwargs)

        # The score is 1 - cost.
        return {'ave match score of weight '+str(quad_weights): float(1 - cost),
                'f1 score of optimal match of weight '+str(quad_weights): float(f1_match),
                'f1 score of exact match': f1}
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/huggingface/evaluate@main
|
2 |
+
rouge_chinese
|
3 |
+
scipy
|
tests.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Expected results for the metric with default weights and separators.
# The single prediction matches the first of the two reference tuples exactly
# and leaves the second one unmatched, so precision is 1 and recall is 0.5.
test_cases = [
    {
        "predictions": ["a | b | c | pos"],
        "references": ["a | b | c | pos & e | f | g | neg"],
        "result": {'ave match score of weight (1, 1, 1, 1)': 0.5,
                   'f1 score of exact match': 0.6667,
                   'f1 score of optimal match of weight (1, 1, 1, 1)': 0.6666666666666666}
    }
]
|