# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TODO: Add a description here."""
import re
from collections import Counter, defaultdict

import datasets
import evaluate
import numpy as np
from sacrebleu.metrics import BLEU


_CITATION = """\
@article{zhou2022doccoder,
  title={DocCoder: Generating Code by Retrieving and Reading Docs},
  author={Zhou, Shuyan and Alon, Uri and Xu, Frank F and Jiang, Zhengbao and Neubig, Graham},
  journal={arXiv preprint arXiv:2207.05987},
  year={2022}
}
"""
_DESCRIPTION = """\
The evaluation metrics for natural language to bash generation.
The preprocessing is customized for [`tldr`](https://github.com/tldr-pages/tldr) dataset where we first conduct annoymization on the variables.
"""

_KWARGS_DESCRIPTION = """
Args:
    predictions: list of str. The predicted Bash commands.
    references: list of str. The reference Bash commands.
Returns:
    - **template_matching**: exact match accuracy after cleaning and variable anonymization
    - **command_accuracy**: accuracy of predicting the correct Bash command name (e.g., `ls`)
    - **bleu_char**: character-level BLEU score
    - **token_recall/token_precision/token_f1**: recall/precision/F1 of the predicted tokens
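
Example (an illustrative sketch; the module id passed to `evaluate.load` below is an
assumption, and a local path to this script should work as well):
    >>> tldr_eval = evaluate.load("tldr_eval")
    >>> results = tldr_eval.compute(
    ...     predictions=["tar -czf {{output.tar.gz}} {{path/to/directory}}"],
    ...     references=["tar -czf {{target.tar.gz}} {{path/to/directory}}"],
    ... )
    >>> sorted(results.keys())
    ['bleu_char', 'command_accuracy', 'template_matching', 'token_f1', 'token_precision', 'token_recall']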
"""
VAR_STR = "[[VAR]]"


def clean_command(s):
    """Remove `sudo`, quotes/backticks, and pipe/redirection characters; collapse whitespace."""
    s = s.replace("sudo", "").strip()
    s = s.replace("`", "").replace('"', "").replace("'", "")
    # replace pipe and redirection characters ('|', '>', '<') with spaces
    s = s.replace("|", " ").replace(">", " ").replace("<", " ")
    s = " ".join(s.split())
    return s


def anonymize_command(s):
    """Replace each distinct {{variable}} placeholder with an indexed [[VAR]]_i token."""
    s = s.replace("={", " {")
    # assign placeholder ids to variables in order of first appearance
    var_to_pc_holder = defaultdict(lambda: len(var_to_pc_holder))
    for var in re.findall("{{(.*?)}}", s):
        _ = var_to_pc_holder[var]
    for var, id in var_to_pc_holder.items():
        var_str = "{{%s}}" % var
        s = s.replace(var_str, f"{VAR_STR}_{id}")
    # s = re.sub("{{.*?}}", VAR_STR, s)
    return s


def clean_anonymize_command(s):
    return anonymize_command(clean_command(s))


def get_bag_of_words(cmd):
    cmd = clean_anonymize_command(cmd)
    tokens = cmd.strip().split()
    return tokens


def calc_template_matching(gold, pred):
    ag = clean_anonymize_command(gold)
    ap = clean_anonymize_command(pred)
    m = {'template_matching': int(ag == ap)}
    return m


def token_prf(tok_gold, tok_pred, match_blank=False):
    if match_blank and len(tok_gold) == 0:  # the gold command is empty, i.e. nothing should be generated
        if len(tok_pred) == 0:
            m = {'r': 1, 'p': 1, 'f1': 1}
        else:
            m = {'r': 0, 'p': 0, 'f1': 0}
    else:
        tok_gold_dict = Counter(tok_gold)
        tok_pred_dict = Counter(tok_pred)
        tokens = set([*tok_gold_dict] + [*tok_pred_dict])
        # count token overlap, respecting multiplicities
        hit = 0
        for token in tokens:
            hit += min(tok_gold_dict.get(token, 0), tok_pred_dict.get(token, 0))
        p = hit / (sum(tok_pred_dict.values()) + 1e-10)
        r = hit / (sum(tok_gold_dict.values()) + 1e-10)
        f1 = 2 * p * r / (p + r + 1e-10)
        m = {'r': r, 'p': p, 'f1': f1}
    return m
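
# Worked example (illustrative, not from the original repository):
#   token_prf(["tar", "-czf", "[[VAR]]_0"], ["tar", "-xzf", "[[VAR]]_0"])
#   -> hit = 2 ("tar" and "[[VAR]]_0"), so p = r = 2/3 and f1 = 2/3 (up to the 1e-10 smoothing terms)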


def measure_bag_of_word(gold, pred):
    tok_gold = get_bag_of_words(gold)
    tok_pred = get_bag_of_words(pred)
    m = token_prf(tok_gold, tok_pred)  # token-level P/R/F1 over the whole command
    # the command name is the first token (e.g. `ls`); empty sequences never match
    gold_cmd = tok_gold[0] if len(tok_gold) else "NONE_GOLD"
    pred_cmd = tok_pred[0] if len(tok_pred) else "NONE_PRED"
    m = {**m, 'command_accuracy': int(gold_cmd == pred_cmd)}
    return m


def tldr_metrics(references, predictions):
    assert len(references) == len(predictions)
    metric_list = defaultdict(list)
    for ref, pred in zip(references, predictions):
        for k, v in calc_template_matching(ref, pred).items():
            metric_list[k].append(v)
        for k, v in measure_bag_of_word(ref, pred).items():
            metric_list[k].append(v)
    for k, v in metric_list.items():
        metric_list[k] = np.mean(v)

    def clean_for_bleu(s):
        # same cleaning as clean_command, but variables become $0, $1, ... instead of [[VAR]]_i
        s = s.replace("sudo", "").strip()
        s = s.replace("`", "").replace('"', "").replace("'", "")
        # replace pipe and redirection characters ('|', '>', '<') with spaces
        s = s.replace("|", " ").replace(">", " ").replace("<", " ")
        s = " ".join(s.split())
        s = s.replace("={", " {")
        var_to_pc_holder = defaultdict(lambda: len(var_to_pc_holder))
        for var in re.findall("{{(.*?)}}", s):
            _ = var_to_pc_holder[var]
        for var, id in var_to_pc_holder.items():
            var_str = "{{%s}}" % var
            s = s.replace(var_str, f"${id}")
        return s

    def to_characters(s):
        # currently a no-op: character-level tokenization is delegated to sacrebleu's 'char' tokenizer
        # s = s.replace(" ", "")
        # s = ' '.join(list(s))
        return s

    # character-level BLEU
    bleu = BLEU(tokenize='char')
    predictions = [to_characters(clean_for_bleu(x)) for x in predictions]
    references = [to_characters(clean_for_bleu(x)) for x in references]
    bleu_score = bleu.corpus_score(predictions, [references]).score
    metric_list['bleu_char'] = bleu_score
    return metric_list


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TLDREval(evaluate.Metric):
    """Evaluate generated Bash commands against reference commands."""

    def _info(self):
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features({
                "predictions": datasets.Value("string", id="sequence"),
                "references": datasets.Value("string", id="sequence"),
            }),
            # Homepage of the module for documentation
            homepage="https://github.com/shuyanzhou/docprompting",
            # Additional links to the codebase or references
            codebase_urls=["https://github.com/shuyanzhou/docprompting"],
            reference_urls=["https://github.com/shuyanzhou/docprompting"],
        )

    def _compute(self, predictions, references):
        """Returns the scores."""
        metrics = tldr_metrics(references, predictions)
        # rename for better display
        metrics['token_recall'] = metrics.pop('r')
        metrics['token_precision'] = metrics.pop('p')
        metrics['token_f1'] = metrics.pop('f1')
        return dict(metrics)
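

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original evaluation protocol): call the
    # underlying scoring function directly on a toy prediction/reference pair.
    example_references = ["tar -czf {{target.tar.gz}} {{path/to/directory}}"]
    example_predictions = ["sudo tar -czf {{output.tar.gz}} {{path/to/directory}}"]
    print(dict(tldr_metrics(example_references, example_predictions)))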