yuyijiong commited on
Commit
83cb3c8
1 Parent(s): e28c9cd

Upload 6 files

Browse files
Files changed (5) hide show
  1. README.md +66 -5
  2. app.py +10 -0
  3. quad_match_score.py +727 -0
  4. requirements.txt +3 -0
  5. tests.py +9 -0
README.md CHANGED
@@ -1,12 +1,73 @@
1
  ---
2
  title: Quad Match Score
3
- emoji: 🦀
4
- colorFrom: pink
5
- colorTo: purple
 
 
 
6
  sdk: gradio
7
- sdk_version: 3.24.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Quad Match Score
3
+ datasets:
4
+ - SemEval2016 Task5
5
+ tags:
6
+ - evaluate
7
+ - metric
8
+ description: "TODO: add a description here"
9
  sdk: gradio
10
+ sdk_version: 3.19.1
11
  app_file: app.py
12
  pinned: false
13
  ---
14
 
15
+ # Metric Card for My Metric
16
+
17
+ ***Module Card Instructions:*** *评估生成模型的情感四元组抽取结果.*
18
+
19
+ ## Metric Description
20
+ *评估生成模型的情感四元组抽取结果.*
21
+
22
+ ## How to Use
23
+ ```python
24
+ import evaluate
25
+
26
+ module = evaluate.load("yuyijiong/my_metric")
27
+
28
+ predictions=["food | good | food#taste | pos"]
29
+ references=["food | good | food#taste | pos & service | bad | service#general | neg"]
30
+
31
+ module.compute(predictions=predictions, references=references)
32
+
33
+ ```
34
+
35
+ ### Inputs
36
+ *List all input arguments in the format below*
37
+ - **predictions** *(List[str]): 模型生成的四元组,列表中每个字符串代表一个样本的生成结果.*
38
+ - **references** *(Union[List[str],List[List[str]]):
39
+ 人工标注的四元组,列表中每个字符串代表一个样本的标签.如果列表元素为list,代表多个reference,评估时取最高分*
40
+ - **weights** *(Tuple[float, float, float, float]):默认为(1,1,1,1),分别代表四个方面的评估指标的权重*
41
+ - **tuple_len** *(str): indicate the format of the quad, see the following mapping
42
+ 指示四元组的格式,默认为'0123'。对应关系如下所示*
43
+ ```
44
+ {'0123': "四元组(对象 | 观点 | 方面 | 极性)",
45
+ '01':'二元组(对象 | 观点)',
46
+ '012':'三元组(对象 | 观点 | 方面)',
47
+ '013':'三元组(对象 | 观点 | 极性)',
48
+ '023':'三元组(对象 | 方面 | 极性)',
49
+ '23':'二元组(方面 | 极性)',
50
+ '03':'二元组(对象 | 极性)',
51
+ '13':'二元组(观点 | 极性)',
52
+ '3':'单元素(极性)'}
53
+ ```
54
+ - **sep_token1**: the token to seperate quads 分割不同四元组的token
55
+ - **sep_token2**: the token to seperate units of one quad 四元组中不同元素之间的分隔token
56
+
57
+ ### Output Values
58
+
59
+ *最优匹配 f1值、最优匹配样本平均得分、完全匹配 f1值 组成的dict,f1值均在[0,1]之间*
60
+
61
+ *例如: {'ave match score of weight (1, 1, 1, 1)': 0.375,
62
+ 'f1 score of exact match': 0.0,
63
+ 'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}*
64
+
65
+
66
+ ## Limitations and Bias
67
+ *对比传统评估指标,得分偏高*
68
+
69
+ ## Citation
70
+ *论文即将发表*
71
+
72
+ ## Further References
73
+ *Add any useful further references.*
app.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ from evaluate.utils import launch_gradio_widget
3
+
4
+ module = evaluate.load("yuyijiong/quad_match_score")
5
+ launch_gradio_widget(module)
6
+
7
+ # predictions=["a | b | c | pos"]
8
+ # references=["a | b | c | pos & e | f | g | neg"]
9
+ #
10
+ # module.compute(predictions=predictions, references=references)
quad_match_score.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """TODO: Add a description here."""
15
+
16
+ import copy
17
+ import re
18
+ from typing import List, Dict, Union,Callable
19
+ import numpy as np
20
+
21
+
22
+ import datasets
23
+ import evaluate
24
+ from rouge_chinese import Rouge
25
+ from scipy.optimize import linear_sum_assignment
26
+
27
+ # TODO: Add BibTeX citation
28
+ _CITATION = """\
29
+ @InProceedings{huggingface:module,
30
+ title = {A great new module},
31
+ authors={huggingface, Inc.},
32
+ year={2020}
33
+ }
34
+ """
35
+
36
+ # TODO: Add description of the module here
37
+ _DESCRIPTION = """\
38
+ evaluate sentiment quadruples.
39
+ 评估生成模型的情感四元组
40
+ """
41
+
42
+
43
+ # TODO: Add description of the arguments of the module here
44
+ _KWARGS_DESCRIPTION = """
45
+ Calculates how good are predictions given some references, using certain scores
46
+ Args:
47
+ predictions: list of predictions to score. Each predictions
48
+ should be a string with tokens separated by spaces.
49
+ references: list of reference for each prediction. Each
50
+ reference should be a string with tokens separated by spaces.
51
+ Returns:
52
+ score: sentiment quadruple match score
53
+
54
+ Examples:
55
+ Examples should be written in doctest format, and should illustrate how
56
+ to use the function.
57
+
58
+ >>> my_new_module = evaluate.load("my_new_module")
59
+ >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
60
+ >>> print(results)
61
+ {'accuracy': 1.0}
62
+ """
63
+
64
+
65
+ def compute_quadruple_f1(y_pred: List[str], y_true: Union[List[str], List[List[str]]],
66
+ return_rp=False, **kwargs) -> Dict[str, float]:
67
+ assert len(y_pred) == len(y_true)
68
+ correct, pred_num, true_num = 0, 0, 0
69
+
70
+ for pred, true in zip(y_pred, y_true):
71
+
72
+ pred = CommentUnitsSim.from_str(pred, **kwargs)
73
+ # 如果true是list,说明有多个正确答案
74
+ if isinstance(true, str):
75
+ true = CommentUnitsSim.from_str(true, **kwargs)
76
+ else:
77
+ true = [CommentUnitsSim.from_str(t,**kwargs) for t in true]
78
+
79
+ # 如果true是list,说明有多个正确答案,取最高分
80
+ if isinstance(true, list):
81
+ correct_list = [pred.compare_same(t) for t in true]
82
+ correct += max(correct_list) # 获取得分最高的值
83
+ correct_index = correct_list.index(max(correct_list)) # 获取得分最高的索引
84
+ pred_num += pred.num
85
+ true_num += true[correct_index].num
86
+ else:
87
+ correct += pred.compare_same(true)
88
+ pred_num += pred.num
89
+ true_num += true.num
90
+
91
+ # 以下结果保留4位小数
92
+ precision = round(correct / pred_num, 4) + 1e-8
93
+ recall = round(correct / true_num, 4) + 1e-8
94
+ f1 = round(2 * precision * recall / (precision + recall), 4)
95
+
96
+ if return_rp:
97
+ return {"precision": precision, "recall": recall, "f1": f1}
98
+ else:
99
+ return f1
100
+
101
+ # 计算rougel的f1值
102
+ def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
103
+ assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
104
+ #如果text_pred_list[0]为空字符串或空格,则返回0
105
+ if not text_pred_list[0].strip():
106
+ return 0
107
+
108
+ rouge = Rouge()
109
+ # 判断text_true[0]是否有中文,有中文则要用空格分割
110
+ if re.search(u"[\u4e00-\u9fa5]+", text_pred_list[0]):
111
+ text_pred_list = [' '.join(list(text_pred)) for text_pred in text_pred_list]
112
+ text_true_list = [' '.join(list(text_true)) for text_true in text_true_list]
113
+
114
+ rouge_l_f1 = rouge.get_scores(text_pred_list, text_true_list, avg=True)['rouge-l']['f']
115
+
116
+ return rouge_l_f1
117
+
118
+ # 记录四元组的函数
119
+ class CommentUnitsSim:
120
+ def __init__(self, data: List[Dict[str, str]],data_source:any=None,abnormal=False,language=None):
121
+ self.data_source=data_source
122
+ self.abnormal=abnormal
123
+ data=copy.deepcopy(data)
124
+ # 如果字典有target,则改名为target_text
125
+ for quad_dict in data:
126
+ if 'target' in quad_dict:
127
+ quad_dict['target_text'] = quad_dict['target']
128
+ del quad_dict['target']
129
+ if 'opinion' in quad_dict:
130
+ quad_dict['opinion_text'] = quad_dict['opinion']
131
+ del quad_dict['opinion']
132
+
133
+ self.data = data
134
+ self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性','pos':'积极','neg':'消极','neu':'中性','积极':'积极','消极':'���极','中性':'中性'}
135
+ self.polarity_zh2en={'积极':'pos','消极':'neg','中性':'neu','pos':'pos','neg':'neg','neu':'neu','positive':'pos','negative':'neg','neutral':'neu'}
136
+
137
+ self.language=language if language is not None else 'zh' if self.check_zh() else 'en'
138
+ self.none_sign='null'
139
+
140
+ @property
141
+ def num(self):
142
+ return len(self.data)
143
+
144
+ #检查四元组中是否有中文
145
+ def check_zh(self):
146
+ for quad_dict in self.data:
147
+ if re.search('[\u4e00-\u9fa5]',quad_dict['target_text']) or re.search('[\u4e00-\u9fa5]',quad_dict['opinion_text']):
148
+ return True
149
+ return False
150
+
151
+ # 检测极性是否正确
152
+ def check_polarity(self):
153
+ #若有某个四元组的极性不是positive、negative、neutral,则返回False
154
+ for quad_dict in self.data:
155
+ if quad_dict['polarity'] not in ['positive', 'negative', 'neutral','pos','neg','neu','积极','消极','中性']:
156
+ self.abnormal=True
157
+ return False
158
+
159
+ #将极性由英文转为中文
160
+ def convert_polarity_en2zh(self):
161
+ for quad_dict in self.data:
162
+ quad_dict['polarity']=self.polarity_en2zh[quad_dict['polarity']]
163
+ return self
164
+
165
+ #将极性由中文转为英文
166
+ def convert_polarity_zh2en(self):
167
+ for quad_dict in self.data:
168
+ quad_dict['polarity']=self.polarity_zh2en[quad_dict['polarity']]
169
+ return self
170
+
171
+ #检查是否有重复的四元组,若有则删除重复的
172
+ def del_duplicate(self):
173
+ new_data=[]
174
+ for quad_dict in self.data:
175
+ if quad_dict not in new_data:
176
+ new_data.append(quad_dict)
177
+ self.data=new_data
178
+ return self
179
+
180
+ #检查是否有target和opinion都为null的四元组,若有则返回True
181
+ def check_target_opinion_null(self):
182
+ for quad_dict in self.data:
183
+ if quad_dict['target_text']=='null' and quad_dict['opinion_text']=='null':
184
+ return True
185
+ return False
186
+
187
+ #检查是否有target或opinion为null的四元组,若有则返回True
188
+ def check_any_null(self):
189
+ for quad_dict in self.data:
190
+ if quad_dict['target_text']=='null' or quad_dict['opinion_text']=='null':
191
+ return True
192
+ return False
193
+
194
+ @classmethod
195
+ def from_str(cls, quadruple_str: str, tuple_len:Union[int,list,str]=4, format_code=0, sep_token1=' & ', sep_token2=' | '):
196
+ data = []
197
+ abnormal=False
198
+ #确保分隔符后面一定是空格
199
+ for i in range(len(quadruple_str)-1):
200
+ if (quadruple_str[i] == sep_token1.strip() or quadruple_str[i] == sep_token2.strip()) and quadruple_str[i + 1] != ' ':
201
+ quadruple_str = quadruple_str[:i + 1] + ' ' + quadruple_str[i + 1:]
202
+
203
+ # 选择几元组,即创建列表索引,从四元组中抽出n元
204
+ if isinstance(tuple_len, int):
205
+ tuple_index = list(range(tuple_len))
206
+ elif isinstance(tuple_len, list):
207
+ tuple_index = tuple_len
208
+ elif isinstance(tuple_len, str):
209
+ # 例如将‘012’转换为[0,1,2]
210
+ tuple_index = [int(i) for i in tuple_len]
211
+ else:
212
+ raise Exception('tuple_len参数错误')
213
+
214
+
215
+ for quadruple in quadruple_str.split(sep_token1):
216
+ if format_code == 0:
217
+ # quadruple可能是target|opinion|aspect|polarity,也可能是target|opinion|aspect,也可能是target|opinion,若没有则为“None”
218
+ quadruple_split=[unit.strip() for unit in quadruple.split(sep_token2)]
219
+ if len(quadruple_split)>len(tuple_index):
220
+ print('quadruple格式错误,过多元素', quadruple_str)
221
+ abnormal=True
222
+ quadruple_split=quadruple_split[0:len(tuple_index)] #过长则截断
223
+ elif len(quadruple_split)<len(tuple_index):
224
+ print('quadruple格式错误,过少元素', quadruple_str)
225
+ abnormal=True
226
+ quadruple_split=["None"]*(len(tuple_index)-len(quadruple_split))+quadruple_split #过短则补'None'
227
+
228
+ quadruple_keys=[["target_text","opinion_text","aspect","polarity"][i] for i in tuple_index]
229
+ quadruple_dict=dict(zip(quadruple_keys,quadruple_split))
230
+
231
+ q = {"target_text": 'None', "opinion_text": 'None', "aspect": 'None', "polarity": 'None'}
232
+ q.update(quadruple_dict)
233
+ #检查极性是否合法
234
+ if q['polarity'] not in ['pos','neg','neu','None','积极','消极','中性']:
235
+ print('quadruple格式错误,极性格式不对', quadruple_str)
236
+
237
+ else:
238
+ raise Exception('answer_format参数错误')
239
+
240
+ data.append(q)
241
+
242
+ return CommentUnitsSim(data,quadruple_str,abnormal)
243
+
244
+ @classmethod
245
+ def from_list(cls, quadruple_list: List[List[str]],**kwargs):
246
+ data = []
247
+ for quadruple in quadruple_list:
248
+ # #format_code='013'代表list只有四元组的第0、1、3个元素,需要扩充为4元组,空缺位置补上None
249
+ # if format_code=='013':
250
+ # quadruple.insert(2,None)
251
+
252
+ data.append(
253
+ {"target_text": quadruple[0], "opinion_text": quadruple[1], "aspect": quadruple[2],
254
+ "polarity": quadruple[3]})
255
+
256
+ return CommentUnitsSim(data,quadruple_list,**kwargs)
257
+
258
+ @classmethod
259
+ def from_list_dict(cls, quadruple_list: List[dict],**kwargs):
260
+ for quad_dict in quadruple_list:
261
+ if 'target' in quad_dict:
262
+ quad_dict['target_text'] = quad_dict['target']
263
+ del quad_dict['target']
264
+ if 'opinion' in quad_dict:
265
+ quad_dict['opinion_text'] = quad_dict['opinion']
266
+ del quad_dict['opinion']
267
+
268
+ data = []
269
+ for quadruple in quadruple_list:
270
+ #如果quadruple缺少某个key,则补上None
271
+ q={"target_text":'None',"opinion_text":'None',"aspect":'None',"polarity":'None'}
272
+ q.update(quadruple)
273
+ data.append(q)
274
+
275
+ return CommentUnitsSim(data,quadruple_list,**kwargs)
276
+
277
+ #转化为list,即只保留字典的value
278
+ def to_list(self):
279
+ data = []
280
+ for quad_dict in self.data:
281
+ data.append([quad_dict['target_text'],quad_dict['opinion_text'],quad_dict['aspect'],quad_dict['polarity']])
282
+ return data
283
+
284
+ # 将data转换为n元组字符串
285
+ def get_quadruple_str(self, format_code=0, tuple_len:Union[int,list,str]=4,sep_token1=' & ',sep_token2=' | '):
286
+ new_text_list = []
287
+ # 选择几元组,即创建列表索引,从四元组中抽出n元
288
+ if isinstance(tuple_len, int):
289
+ tuple_index = list(range(tuple_len))
290
+ elif isinstance(tuple_len, list):
291
+ tuple_index = tuple_len
292
+ elif isinstance(tuple_len, str):
293
+ # 例如将‘012’转换为[0,1,2]
294
+ tuple_index = [int(i) for i in tuple_len]
295
+ else:
296
+ raise Exception('tuple_len参数错误')
297
+
298
+ try:
299
+ #若语言为中文,则使用中文极性
300
+ if self.language=='zh':
301
+ self.convert_polarity_en2zh()
302
+ else:
303
+ self.convert_polarity_zh2en()
304
+ except:
305
+ print('语言参数错误',self.data)
306
+ print(self.language)
307
+ raise Exception('语言参数错误')
308
+
309
+ #若tuple_index==[3],则返回综合情感极性
310
+ if tuple_index==[3]:
311
+ return self.merge_polarity()
312
+
313
+ for quad_dict in self.data:
314
+ # 提取target_text,如果空列表则为'',如果列表长度大于1则为','.join(list)
315
+ target_text = quad_dict['target_text']
316
+ # 提取opinion_text,如果空列表则为'',如果列表长度大于1则为','.join(list)
317
+ opinion_text = quad_dict['opinion_text']
318
+ # 提取aspect
319
+ aspect = quad_dict['aspect']
320
+ # 提取polarity
321
+ polarity = quad_dict['polarity']
322
+
323
+
324
+ # 拼接,‘|’分割
325
+ if format_code == 0:
326
+ # 根据tuple_len拼接
327
+ new_text = sep_token2.join([[target_text, opinion_text, aspect, polarity][i] for i in tuple_index])
328
+ else:
329
+ raise Exception('answer_format参数错误')
330
+
331
+ new_text_list.append(new_text)
332
+
333
+ #如果tuple_index为[2,3],则需要去除new_text_list中重复的元素,不要改变顺序。因为可能有重复的方面
334
+ if tuple_index==[2,3]:
335
+ res = []
336
+ for t in new_text_list:
337
+ if t not in res:
338
+ res.append(t)
339
+ new_text_list=res
340
+
341
+ #如果tuple_index为[3],则只保留new_text_list的第一个元素。因为只有一个情感极性
342
+ elif tuple_index==[3]:
343
+ new_text_list=new_text_list[:1]
344
+
345
+ if format_code == 0:
346
+ # 根据tuple_len拼接
347
+ return sep_token1.join(new_text_list)
348
+
349
+ # 与另一个CommentUnits对象对比,检测有几个相同的四元组
350
+ def compare_same(self, other)->int:
351
+ count = 0
352
+ for quad_dict in self.data:
353
+ if quad_dict in other.data:
354
+ count += 1
355
+ return count
356
+
357
+ # 检查自身数据的四元组中target是否有重复
358
+ def check_target_repeat(self):
359
+ target_list = []
360
+ for quad_dict in self.data:
361
+ target_list.append(quad_dict['target_text'])
362
+ return len(target_list) != len(set(target_list))
363
+
364
+ # 检查自身数据的四元组中opinion是否有重复
365
+ def check_opinion_repeat(self):
366
+ opinion_list = []
367
+ for quad_dict in self.data:
368
+ opinion_list.append(quad_dict['opinion_text'])
369
+ return len(opinion_list) != len(set(opinion_list))
370
+
371
+ # 检查自身数据的四元组中aspect是否有重复
372
+ def check_aspect_repeat(self):
373
+ aspect_list = []
374
+ for quad_dict in self.data:
375
+ aspect_list.append(quad_dict['aspect'])
376
+ return len(aspect_list) != len(set(aspect_list))
377
+
378
+ # 输出所有aspect的列表
379
+ def get_aspect_list(self):
380
+ aspect_list = []
381
+ for quad_dict in self.data:
382
+ aspect_list.append(quad_dict['aspect'])
383
+ return aspect_list
384
+
385
+ # 输出所有target的列表
386
+ def get_target_list(self):
387
+ target_list = []
388
+ for quad_dict in self.data:
389
+ target_list.append(quad_dict['target_text'])
390
+ return target_list
391
+
392
+ # 输出所有opinion的列表
393
+ def get_opinion_list(self):
394
+ opinion_list = []
395
+ for quad_dict in self.data:
396
+ opinion_list.append(quad_dict['opinion_text'])
397
+ return opinion_list
398
+
399
+ # 输出所有polarity的列表
400
+ def get_polarity_list(self):
401
+ polarity_list = []
402
+ for quad_dict in self.data:
403
+ polarity_list.append(quad_dict['polarity'])
404
+ return polarity_list
405
+
406
+ #对所有polarity进行综合
407
+ def merge_polarity(self):
408
+ polarity_list = self.get_polarity_list()
409
+ #判断是英文还是中文
410
+ if self.language == 'en':
411
+ if 'pos' in polarity_list and 'neg' in polarity_list:
412
+ return 'neu'
413
+ elif 'pos' in polarity_list:
414
+ return 'pos'
415
+ elif 'neg' in polarity_list:
416
+ return 'neg'
417
+ else:
418
+ return 'neu'
419
+ else:
420
+ if '积极' in polarity_list and '消极' in polarity_list:
421
+ return '中性'
422
+ elif '积极' in polarity_list:
423
+ return '积极'
424
+ elif '消极' in polarity_list:
425
+ return '消极'
426
+ else:
427
+ return '中性'
428
+
429
+ #检测是否有不合法opinion
430
+ def check_opinion_in_comment(self, comment_text):
431
+ for quad_dict in self.data:
432
+ if quad_dict['opinion_text'] !='*' and (not quad_dict['opinion_text'] in comment_text):
433
+ return False
434
+ return True
435
+
436
+ #检测是否有不合法target
437
+ def check_target_in_comment(self,comment_text):
438
+ for quad_dict in self.data:
439
+ if quad_dict['target_text'] !='*' and (not quad_dict['target_text'] in comment_text):
440
+ return False
441
+ return True
442
+
443
+ #计算两个四元组的相似度
444
+ @staticmethod
445
+ def get_similarity(units1, units2: 'CommentUnitsSim'):
446
+ pass
447
+
448
+ #对自身数据进行操作
449
+ def apply(self,func:Callable,field:str):
450
+ for quad_dict in self.data:
451
+ quad_dict[field] = func(quad_dict[field])
452
+ return self
453
+
454
+
455
+ #四元组匹配函数
456
+ class CommentUnitsMatch:
457
+ def __init__(self,target_weight=0.5,opinion_weight=0.5,aspect_weight=0.5,polarity_weight=0.5):
458
+ #归一化权重
459
+ weight_sum = target_weight+opinion_weight+aspect_weight+polarity_weight
460
+ self.target_weight = target_weight/weight_sum
461
+ self.opinion_weight = opinion_weight/weight_sum
462
+ self.aspect_weight = aspect_weight/weight_sum
463
+ self.polarity_weight = polarity_weight/weight_sum
464
+
465
+ #特定feature置零
466
+ def set_zero(self,feature:str='polarity'):
467
+ if feature == 'polarity':
468
+ self.polarity_weight = 0
469
+ elif feature == 'aspect':
470
+ self.aspect_weight = 0
471
+ elif 'opinion' in feature:
472
+ self.opinion_weight = 0
473
+ elif 'target' in feature:
474
+ self.target_weight = 0
475
+ else:
476
+ raise Exception('feature参数错误')
477
+
478
+ def re_normalize(self):
479
+ weight_sum = self.target_weight+self.opinion_weight+self.aspect_weight+self.polarity_weight
480
+ self.target_weight = self.target_weight/weight_sum
481
+ self.opinion_weight = self.opinion_weight/weight_sum
482
+ self.aspect_weight = self.aspect_weight/weight_sum
483
+ self.polarity_weight = self.polarity_weight/weight_sum
484
+
485
+
486
+ #计算cost矩阵
487
+ def get_cost_matrix(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',feature:str='polarity'):
488
+ pass
489
+ #检查此feature是否存在,不存在则返回全0矩阵
490
+ if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None\
491
+ or units1.data[0].get(feature)=='None' or units2.data[0].get(feature)=='None':
492
+ cost_matrix = np.zeros((len(units1.data),len(units2.data)))
493
+ #对应feature的weight也为0
494
+ self.set_zero(feature)
495
+
496
+ # 并再次归一化
497
+ self.re_normalize()
498
+
499
+ return cost_matrix
500
+
501
+ #检查两个四元组的极性是否相同,生成cost矩阵,用于匈牙利算法。不相同则cost为1,相同则cost为0
502
+ cost_matrix = []
503
+ for quad_dict1 in units1.data:
504
+ cost_list = []
505
+ for quad_dict2 in units2.data:
506
+ if quad_dict1[feature] == quad_dict2[feature]:
507
+ cost_list.append(0)
508
+ else:
509
+ cost_list.append(1)
510
+ cost_matrix.append(cost_list)
511
+
512
+ #cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
513
+ cost_matrix = np.array(cost_matrix)
514
+ return cost_matrix
515
+
516
+ #计算cost矩阵,使用rouge指标
517
+ def get_cost_matrix_rouge(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',feature:str='target_text'):
518
+ #检查此feature是否存在,不存在则返回全0矩阵
519
+ if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None\
520
+ or units1.data[0].get(feature)=='None' or units2.data[0].get(feature)=='None':
521
+ cost_matrix = np.zeros((len(units1.data),len(units2.data)))
522
+ #对应feature的weight也为0
523
+ self.set_zero(feature)
524
+ # 并再次归一化
525
+ self.re_normalize()
526
+ return cost_matrix
527
+
528
+ #检查两个四元组的极性是否相同,生成cost矩阵,用于匈牙利算法。相同则cost为0,不相同则cost为1-rougel
529
+ cost_matrix = []
530
+ for quad_dict1 in units1.data:
531
+ cost_list = []
532
+ for quad_dict2 in units2.data:
533
+ if quad_dict1[feature] == quad_dict2[feature]:
534
+ cost_list.append(0)
535
+ else:
536
+ cost_list.append(1-get_rougel_f1([quad_dict1[feature]],[quad_dict2[feature]]))
537
+ cost_matrix.append(cost_list)
538
+
539
+ #cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
540
+ cost_matrix = np.array(cost_matrix)
541
+ return cost_matrix
542
+
543
+ def match_units(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',one_match=True)->tuple:
544
+ #计算极性的cost矩阵,矩阵元素在0-1之间
545
+ cost_matrix_polarity = self.get_cost_matrix(units1, units2,feature='polarity')
546
+ #计算aspect的cost矩阵
547
+ cost_matrix_aspect = self.get_cost_matrix(units1, units2,feature='aspect')
548
+ #计算target的cost矩阵
549
+ cost_matrix_target = self.get_cost_matrix_rouge(units1, units2,feature='target_text')
550
+ #计算opinion的cost矩阵
551
+ cost_matrix_opinion = self.get_cost_matrix_rouge(units1, units2,feature='opinion_text')
552
+
553
+ #计算总的cost矩阵,矩阵元素在0-1之间。矩阵的行数为units1即pred的数量,列数为units2即true的数量
554
+ cost_matrix = self.target_weight*cost_matrix_target + self.opinion_weight*cost_matrix_opinion + \
555
+ self.aspect_weight*cost_matrix_aspect + self.polarity_weight*cost_matrix_polarity
556
+ score_matrix = 1-cost_matrix
557
+ #使用匈牙利算法进行匹配
558
+ if one_match:
559
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
560
+ else:
561
+ #允许一对多的匹配
562
+ row_ind = np.argmin(cost_matrix, axis=0)
563
+ col_ind = np.arange(len(units2.data))
564
+
565
+ max_units_num=max(units1.num,units2.num)
566
+
567
+ #计算这种匹配的cost
568
+ cost = 0
569
+ for i in range(len(row_ind)):
570
+ cost += cost_matrix[row_ind[i]][col_ind[i]]
571
+
572
+ #计算这种匹配下的TP\FP\FN
573
+ TP = 0
574
+ for i in range(len(row_ind)):
575
+ TP += score_matrix[row_ind[i]][col_ind[i]]
576
+
577
+ #len(row_ind)为pred的数量,TP为匹配上的数量
578
+ FP = units1.num-TP
579
+ FN = units2.num-TP
580
+
581
+
582
+ #匹配不上的四元组,cost为1
583
+ cost += (max_units_num-len(row_ind))
584
+
585
+ cost_per_quadruple=cost/max_units_num
586
+ if cost_per_quadruple>1 or cost_per_quadruple <0:
587
+
588
+ print('cost错误',cost_per_quadruple,'pred:',units1.data,'true:',units2.data)
589
+ print(self.target_weight,self.opinion_weight,self.aspect_weight,self.polarity_weight)
590
+
591
+ #返回的cost在0-1之间
592
+ return cost_per_quadruple,TP,FP,FN
593
+
594
+
595
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
596
+ class QuadMatch(evaluate.Metric):
597
+ """TODO: Short description of my evaluation module."""
598
+
599
+ def _info(self):
600
+ # TODO: Specifies the evaluate.EvaluationModuleInfo object
601
+ return evaluate.MetricInfo(
602
+ # This is the description that will appear on the modules page.
603
+ module_type="metric",
604
+ description=_DESCRIPTION,
605
+ citation=_CITATION,
606
+ inputs_description=_KWARGS_DESCRIPTION,
607
+ # This defines the format of each prediction and reference
608
+ features=[
609
+ datasets.Features(
610
+ {
611
+ "predictions": datasets.Value("string", id="sequence"),
612
+ "references": datasets.Sequence(datasets.Value("string", id="sequence")),
613
+ }
614
+ ),
615
+ datasets.Features(
616
+ {
617
+ "predictions": datasets.Value("string", id="sequence"),
618
+ "references": datasets.Value("string", id="sequence"),
619
+ }
620
+ ),
621
+ ],
622
+ # Homepage of the module for documentation
623
+ homepage="http://module.homepage",
624
+ # Additional links to the codebase or references
625
+ codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
626
+ reference_urls=["http://path.to.reference.url/new_module"]
627
+ )
628
+
629
+ def _download_and_prepare(self, dl_manager):
630
+ """Optional: download external resources useful to compute the scores"""
631
+ # TODO: Download external resources if needed
632
+ pass
633
+
634
+ def _compute(self,
635
+ predictions:List[str],
636
+ references: Union[List[str],List[List[str]]],
637
+ quad_weights:tuple=(1,1,1,1),
638
+ **kwargs) -> dict:
639
+ '''
640
+
641
+ :param predictions: list of predictions of sentiment quads
642
+ :param references: list of references of sentiment quads
643
+ :param quad_weights: weight of target,opinion,aspect,polarity for cost compute
644
+
645
+ :param kwargs:
646
+ :param tuple_len: indicate the format of the quad, see the following mapping
647
+ :param sep_token1: the token to seperate quads
648
+ :param sep_token2: the token to seperate units of one quad
649
+
650
+ :return:average matching score
651
+
652
+ #mapping
653
+ id2prompt={'0123':"quadruples (target | opinion | aspect | polarity)",
654
+ '':"quadruples (target | opinion | aspect | polarity)",
655
+ '01':'pairs (target | opinion)',
656
+ '012':'triples (target | opinion | aspect)',
657
+ '013':'triples (target | opinion | polarity)',
658
+ '023':'triples (target | aspect | polarity)',
659
+ '23':'pairs (aspect | polarity)',
660
+ '03':'pairs (target | polarity)',
661
+ '13':'pairs (opinion | polarity)',
662
+ '3':'single (polarity)'}
663
+
664
+ #中文版映射
665
+ id2prompt_zh={'0123': "四元组(对象 | 观点 | 方面 | 极性)",
666
+ '':"四元组(对象 | 观点 | 方面 | 极性)",
667
+ '01':'二元组(对象 | 观点)',
668
+ '012':'三元组(对象 | 观点 | 方面)',
669
+ '013':'三元组(对象 | 观点 | 极性)',
670
+ '023':'三元组(对象 | 方面 | 极性)',
671
+ '23':'二元组(方面 | 极性)',
672
+ '03':'二元组(对象 | 极性)',
673
+ '13':'二元组(观点 | 极性)',
674
+ '3':'单元素(极性)'}
675
+ '''
676
+
677
+ assert len(predictions) == len(references)
678
+ if isinstance(predictions,str):
679
+ predictions=[predictions]
680
+ references=[references]
681
+
682
+ cost=0
683
+ TP,FP,FN=0,0,0
684
+ matcher = CommentUnitsMatch(*quad_weights)
685
+
686
+ for pred, true in zip(predictions, references):
687
+
688
+ pred = CommentUnitsSim.from_str(pred,**kwargs)
689
+ # 如果true是list,说明有多个正确答案
690
+ if isinstance(true, str):
691
+ true = CommentUnitsSim.from_str(true, **kwargs)
692
+ elif isinstance(true, list):
693
+ true=[CommentUnitsSim.from_str(t, **kwargs) for t in true]
694
+ else:
695
+ print("true的类型不对",true)
696
+ continue
697
+
698
+ #如果true是list,说明有多个正确答案,取最高分
699
+ if isinstance(true, list):
700
+ cost_list=[matcher.match_units(pred,t,one_match=True) for t in true]
701
+ # 获取得分最高的值的索引,按元组中第一个元素大小排序
702
+ cost_,TP_,FP_,FN_ = cost_list[np.argmax([c[0] for c in cost_list])]
703
+ cost += cost_
704
+ TP+=TP_
705
+ FP+=FP_
706
+ FN+=FN_
707
+
708
+ else:
709
+ cost_,TP_,FP_,FN_ = matcher.match_units(pred,true,one_match=True)
710
+ cost += cost_
711
+ TP+=TP_
712
+ FP+=FP_
713
+ FN+=FN_
714
+
715
+ #平均cost
716
+ cost=cost/len(predictions)
717
+ #由TP\FP\FN计算最优匹配F1
718
+ precision_match=TP/(TP+FP)
719
+ recall_match=TP/(TP+FN)
720
+ f1_match=2*precision_match*recall_match/(precision_match+recall_match)
721
+
722
+ f1=compute_quadruple_f1(y_pred=predictions,y_true=references, **kwargs)
723
+
724
+ #取1-cost为得分
725
+ return {'ave match score of weight '+str(quad_weights):1-cost,
726
+ 'f1 score of optimal match of weight '+str(quad_weights): f1_match,
727
+ 'f1 score of exact match':f1}
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ git+https://github.com/huggingface/evaluate@main
2
+ rouge_chinese
3
+ scipy
tests.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ test_cases = [
2
+ {
3
+ "predictions": "a | b | c | pos",
4
+ "references": "a | b | c | pos & e | f | g | neg",
5
+ "result": {'ave match score of weight (1, 1, 1, 1)': 0.375,
6
+ 'f1 score of exact match': 0.0,
7
+ 'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}
8
+ }
9
+ ]