# Hugging Face Spaces metric script: cell-level table matching (precision / recall / F1).
import re
import json
import evaluate
import datasets
_DESCRIPTION = """ | |
Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics: | |
1. Precision: The ratio of correctly predicted cells to the total number of cells in the predicted table | |
2. Recall: The ratio of correctly predicted cells to the total number of cells in the reference table | |
3. F1 Score: The harmonic mean of precision and recall | |
These metrics help evaluate the accuracy of table data extraction or generation. | |
""" | |
_KWARGS_DESCRIPTION = """ | |
Args: | |
predictions (`str`): Predicted table in Markdown format. | |
references (`str`): Reference table in Markdown format. | |
Returns: | |
dict: A dictionary containing the following metrics: | |
- precision (`float`): Precision score, range [0,1] | |
- recall (`float`): Recall score, range [0,1] | |
- f1 (`float`): F1 score, range [0,1] | |
- true_positives (`int`): Number of correctly predicted cells | |
- false_positives (`int`): Number of incorrectly predicted cells | |
- false_negatives (`int`): Number of cells that were not predicted | |
Examples: | |
>>> accuracy_metric = evaluate.load("accuracy") | |
>>> results = accuracy_metric.compute( | |
... predictions="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |", | |
... references="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |" | |
... ) | |
>>> print(results) | |
{'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3} | |
""" | |
_CITATION = """ | |
@article{scikit-learn, | |
title={Scikit-learn: Machine Learning in {P}ython}, | |
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. | |
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. | |
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and | |
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.}, | |
journal={Journal of Machine Learning Research}, | |
volume={12}, | |
pages={2825--2830}, | |
year={2011} | |
} | |
""" | |
class Accuracy(evaluate.Metric):
    """Cell-level matching metric between two Markdown tables.

    Parses the predicted and reference tables into dicts keyed by the
    header/row labels and reports precision, recall, F1 and the raw
    TP/FP/FN counts over individual cells.
    """

    def _info(self):
        """Describe the metric (docs, citation, expected input features)."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _extract_markdown_table(self, text):
        """Return the whitespace-stripped pipe-delimited table found in *text*, or None.

        Newlines and spaces are removed first, so rows end up concatenated
        and separated by '||' boundaries.
        """
        text = text.replace('\n', '')
        text = text.replace(" ", "")
        # One or more '|cell|cell|...' runs; joined back into a single string.
        pattern = r'\|(?:[^|]+\|)+[^|]+\|'
        matches = re.findall(pattern, text)
        if matches:
            return ''.join(matches)
        return None

    def _table_to_dict(self, table_str):
        """Convert a compacted table string into a (possibly nested) dict of floats.

        A two-row table (header + one data row) maps header -> value.
        Three or more rows map row-label -> {header -> value}.
        Returns None if fewer than two non-separator rows are present.
        NOTE(review): cell values are converted with float(); non-numeric
        cells will raise ValueError — presumably inputs are numeric tables.
        """
        result_dict = {}
        table_str = table_str.lstrip("|").rstrip("|")
        # Rows are separated by '||' after whitespace removal.
        parts = table_str.split('||')
        # Drop the Markdown header/body separator row ('--|--|...').
        parts = [part for part in parts if "--" not in part]
        legends = parts[0].split("|")
        rows = len(parts)
        if rows == 2:
            # Single data row: flat mapping header -> value.
            nums = parts[1].split("|")
            for i in range(len(nums)):
                result_dict[legends[i]] = float(nums[i])
        elif rows >= 3:
            # Multiple data rows: first cell of each row is its label.
            for i in range(1, rows):
                pre_row = parts[i]
                pre_row = pre_row.split("|")
                label = pre_row[0]
                result_dict[label] = {}
                for j in range(1, len(pre_row)):
                    result_dict[label][legends[j - 1]] = float(pre_row[j])
        else:
            return None
        return result_dict

    def _markdown_to_dict(self, markdown_str):
        """Parse a Markdown table string into a dict, or None if no table found."""
        table_str = self._extract_markdown_table(markdown_str)
        if table_str:
            return self._table_to_dict(table_str)
        else:
            return None

    def _calculate_table_metrics(self, pred_table, true_table):
        """Compute cell-level TP/FP/FN and derived precision/recall/F1.

        Nested dicts (multi-row tables) are compared recursively; a
        predicted row label missing from the reference counts as a single
        false positive (its cells are not expanded), and vice versa.
        """
        true_positives = 0
        false_positives = 0
        false_negatives = 0
        # Walk every key/value pair of the predicted table.
        for key, pred_value in pred_table.items():
            if key in true_table:
                true_value = true_table[key]
                if isinstance(pred_value, dict) and isinstance(true_value, dict):
                    # Recurse into matching rows and accumulate their counts.
                    nested_metrics = self._calculate_table_metrics(pred_value, true_value)
                    true_positives += nested_metrics['true_positives']
                    false_positives += nested_metrics['false_positives']
                    false_negatives += nested_metrics['false_negatives']
                # Cell values match exactly.
                elif pred_value == true_value:
                    true_positives += 1
                else:
                    # Wrong value: penalized on both sides.
                    false_positives += 1
                    false_negatives += 1
            else:
                false_positives += 1
        # Reference entries never predicted at all.
        for key in true_table:
            if key not in pred_table:
                false_negatives += 1
        # Derive the final scores, guarding against division by zero.
        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }

    def _compute(self, predictions, references):
        """Entry point called by `evaluate`: join inputs and score the tables."""
        predictions = "".join(predictions)
        references = "".join(references)
        return self._calculate_table_metrics(self._markdown_to_dict(predictions), self._markdown_to_dict(references))
def main():
    """Smoke-test the metric on a pair of small Markdown tables."""
    accuracy_metric = Accuracy()
    # Compute the metrics for a predicted vs. reference table.
    results = accuracy_metric.compute(
        predictions=["""
        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
        """],  # predicted table
        references=["""
        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
        """],  # reference table
    )
    print(results)  # expected: precision/recall/f1 all 0.7


if __name__ == '__main__':
    main()