import re

import evaluate
import datasets
_DESCRIPTION = """
Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics:
1. Precision: The ratio of correctly predicted cells to the total number of cells in the predicted table
2. Recall: The ratio of correctly predicted cells to the total number of cells in the reference table
3. F1 Score: The harmonic mean of precision and recall
These metrics help evaluate the accuracy of table data extraction or generation.
"""
_KWARGS_DESCRIPTION = """
Args:
predictions (`str`): Predicted table in Markdown format.
references (`str`): Reference table in Markdown format.
Returns:
dict: A dictionary containing the following metrics:
- precision (`float`): Precision score, range [0,1]
- recall (`float`): Recall score, range [0,1]
- f1 (`float`): F1 score, range [0,1]
- true_positives (`int`): Number of correctly predicted cells
- false_positives (`int`): Number of incorrectly predicted cells
- false_negatives (`int`): Number of cells that were not predicted
Examples:
>>> accuracy_metric = evaluate.load("accuracy")
>>> results = accuracy_metric.compute(
... predictions="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |",
... references="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"
... )
>>> print(results)
{'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
"""
_CITATION = """
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )
    def _extract_markdown_table(self, text):
        """Extract the Markdown table portion of `text` as one compact string."""
        # Strip newlines and spaces so the table collapses into one line of
        # "|cell|cell|"-style segments.
        text = text.replace('\n', '')
        text = text.replace(" ", "")
        # Match runs of pipe-delimited, non-empty cells, e.g. "|desire|5|8|7|5|9|".
        pattern = r'\|(?:[^|]+\|)+[^|]+\|'
        matches = re.findall(pattern, text)
        if matches:
            return ''.join(matches)
        return None
    def _table_to_dict(self, table_str):
        """Convert a compacted Markdown table string into a (possibly nested) dict."""
        result_dict = {}
        table_str = table_str.lstrip("|").rstrip("|")
        # "||" marks a row boundary once the table has been compacted.
        parts = table_str.split('||')
        # Drop the Markdown separator row ("--|--|...").
        parts = [part for part in parts if "--" not in part]
        legends = parts[0].split("|")
        rows = len(parts)
        if rows == 2:
            # Header plus a single data row: map column name -> value.
            nums = parts[1].split("|")
            for i in range(len(nums)):
                result_dict[legends[i]] = float(nums[i])
        elif rows >= 3:
            # Header plus several labeled rows: map row label -> {column -> value}.
            for i in range(1, rows):
                pre_row = parts[i].split("|")
                label = pre_row[0]
                result_dict[label] = {}
                for j in range(1, len(pre_row)):
                    result_dict[label][legends[j - 1]] = float(pre_row[j])
        else:
            return None
        return result_dict
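    # Illustrative sketch of the parse (hypothetical table, traced against the
    # logic above): the Markdown table
    #     | | a | b ||--|--|--|| x | 1 | 2 || y | 3 | 4 |
    # becomes the nested dict
    #     {"x": {"a": 1.0, "b": 2.0}, "y": {"a": 3.0, "b": 4.0}}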
    def _markdown_to_dict(self, markdown_str):
        """Parse a Markdown string into a dict, or return None if no table is found."""
        table_str = self._extract_markdown_table(markdown_str)
        if table_str:
            return self._table_to_dict(table_str)
        return None
    def _calculate_table_metrics(self, pred_table, true_table):
        """Recursively count matching cells and derive precision/recall/F1."""
        true_positives = 0
        false_positives = 0
        false_negatives = 0
        # Walk every key/value pair in the predicted table.
        for key, pred_value in pred_table.items():
            if key in true_table:
                true_value = true_table[key]
                if isinstance(pred_value, dict) and isinstance(true_value, dict):
                    # Nested row: recurse and accumulate its cell counts.
                    nested_metrics = self._calculate_table_metrics(pred_value, true_value)
                    true_positives += nested_metrics['true_positives']
                    false_positives += nested_metrics['false_positives']
                    false_negatives += nested_metrics['false_negatives']
                elif pred_value == true_value:
                    # Matching scalar cell.
                    true_positives += 1
                else:
                    # Cell exists in both tables but the values differ.
                    false_positives += 1
                    false_negatives += 1
            else:
                # Predicted key absent from the reference table (a whole
                # missing row counts once, not per cell).
                false_positives += 1
        # Count reference keys that were never predicted.
        for key in true_table:
            if key not in pred_table:
                false_negatives += 1
        # Compute the final metrics, guarding against division by zero.
        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }
    def _compute(self, predictions, references):
        # `evaluate` passes predictions/references in as lists of strings;
        # joining them supports the single-table use case this metric targets.
        predictions = "".join(predictions)
        references = "".join(references)
        return self._calculate_table_metrics(
            self._markdown_to_dict(predictions), self._markdown_to_dict(references)
        )
def main():
    accuracy_metric = Accuracy()
    # Compute the metrics for one predicted/reference table pair.
    results = accuracy_metric.compute(
        predictions=["""
        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
        """],  # predicted table
        references=["""
        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
        """],  # reference table
    )
    print(results)  # prints the metrics dict

if __name__ == '__main__':
    main()
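# A minimal loading sketch (the Space id below is a placeholder, not the real
# path; substitute the "<user>/<space>" this file is actually published under):
#
#     import evaluate
#     table_metric = evaluate.load("<user>/table-accuracy", module_type="metric")
#     results = table_metric.compute(predictions=[pred_md], references=[ref_md])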