"""TLDR metric for evaluating generated bash commands against reference commands."""

import re
from collections import Counter, defaultdict

import numpy as np

import datasets
import evaluate
from sacrebleu.metrics import BLEU

|
_CITATION = """\
@article{zhou2022doccoder,
  title={DocCoder: Generating Code by Retrieving and Reading Docs},
  author={Zhou, Shuyan and Alon, Uri and Xu, Frank F and Jiang, Zhengbao and Neubig, Graham},
  journal={arXiv preprint arXiv:2207.05987},
  year={2022}
}
"""
|
|
|
_DESCRIPTION = """\
This metric evaluates the quality of a generated bash command against a
reference command. It reports template matching, command-name accuracy,
token-level precision/recall/F1 over anonymized tokens, and character-level
BLEU.
"""


_KWARGS_DESCRIPTION = """
Args:
    predictions: list of str. The generated bash commands.
    references: list of str. The reference bash commands.

Returns:
    template_matching: fraction of predictions whose cleaned, anonymized form
        exactly matches the reference.
    command_accuracy: fraction of predictions whose first token (the command
        name) matches the reference's first token.
    token_recall, token_precision, token_f1: averaged bag-of-words overlap
        between the anonymized prediction and reference tokens.
    bleu_char: character-level corpus BLEU over the cleaned commands.
"""
|
|
|
# Placeholder token used when anonymizing {{...}} variables.
VAR_STR = "[[VAR]]"

|
def clean_command(s):
    # Drop "sudo", backticks, and quotes, treat shell pipe/redirection
    # operators as plain whitespace, then normalize spacing.
    s = s.replace("sudo", "").strip()
    s = s.replace("`", "").replace('"', "").replace("'", "")
    s = s.replace("|", " ").replace(">", " ").replace("<", " ")
    s = " ".join(s.split())
    return s
|
|
|
def anonymize_command(s):
    # Split "--opt={{var}}" so each placeholder becomes its own token.
    s = s.replace("={", " {")
    # Assign each distinct {{var}} placeholder an integer id in order of
    # first appearance.
    var_to_pc_holder = defaultdict(lambda: len(var_to_pc_holder))
    for var in re.findall("{{(.*?)}}", s):
        _ = var_to_pc_holder[var]
    for var, idx in var_to_pc_holder.items():
        var_str = "{{%s}}" % var
        s = s.replace(var_str, f"{VAR_STR}_{idx}")
    return s


def clean_anonymize_command(s):
    return anonymize_command(clean_command(s))
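
# A quick illustration (the values follow directly from the helpers above):
#
#   clean_anonymize_command("sudo tar -xvf {{archive}} > {{log}}")
#   # -> "tar -xvf [[VAR]]_0 [[VAR]]_1"
#
# "sudo", quotes, and shell operators are stripped, and each distinct
# {{...}} placeholder is replaced by an indexed [[VAR]]_i token, so two
# commands that differ only in placeholder names compare equal.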
|
|
|
|
|
def get_bag_of_words(cmd):
    cmd = clean_anonymize_command(cmd)
    tokens = cmd.strip().split()
    return tokens


def calc_template_matching(gold, pred):
    ag = clean_anonymize_command(gold)
    ap = clean_anonymize_command(pred)
    m = {'template_matching': int(ag == ap)}
    return m
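
# For example (follows from clean_anonymize_command above):
#
#   calc_template_matching("sudo rm -f {{path}}", "rm -f {{file}}")
#   # -> {'template_matching': 1}
#
# since both sides anonymize to "rm -f [[VAR]]_0".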
|
|
|
def token_prf(tok_gold, tok_pred, match_blank=False):
    if match_blank and len(tok_gold) == 0:
        # An empty reference is matched only by an empty prediction.
        if len(tok_pred) == 0:
            m = {'r': 1, 'p': 1, 'f1': 1}
        else:
            m = {'r': 0, 'p': 0, 'f1': 0}
    else:
        # Multiset overlap: each token contributes up to its minimum count
        # in gold and prediction.
        tok_gold_dict = Counter(tok_gold)
        tok_pred_dict = Counter(tok_pred)
        tokens = set([*tok_gold_dict] + [*tok_pred_dict])
        hit = 0
        for token in tokens:
            hit += min(tok_gold_dict.get(token, 0), tok_pred_dict.get(token, 0))
        # The small epsilon guards against division by zero on empty inputs.
        p = hit / (sum(tok_pred_dict.values()) + 1e-10)
        r = hit / (sum(tok_gold_dict.values()) + 1e-10)
        f1 = 2 * p * r / (p + r + 1e-10)
        m = {'r': r, 'p': p, 'f1': f1}
    return m
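
# Worked example (the numbers follow from the overlap formula above):
#
#   token_prf(["tar", "-xvf", "[[VAR]]_0"],
#             ["tar", "-xvf", "[[VAR]]_0", "-C"])
#   # hit = 3, so p = 3/4 = 0.75, r = 3/3 = 1.0,
#   # f1 = 2 * 0.75 * 1.0 / 1.75 ≈ 0.857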
|
|
|
def measure_bag_of_word(gold, pred):
    tok_gold = get_bag_of_words(gold)
    tok_pred = get_bag_of_words(pred)
    m = token_prf(tok_gold, tok_pred)
    # The first token is the command name (e.g. "tar"), compared directly.
    gold_cmd = tok_gold[0] if len(tok_gold) else "NONE_GOLD"
    pred_cmd = tok_pred[0] if len(tok_pred) else "NONE_PRED"
    m = {**m, 'command_accuracy': int(gold_cmd == pred_cmd)}
    return m
|
|
|
def tldr_metrics(references, predictions):
    assert len(references) == len(predictions)
    metric_list = defaultdict(list)
    for ref, pred in zip(references, predictions):
        for k, v in calc_template_matching(ref, pred).items():
            metric_list[k].append(v)
        for k, v in measure_bag_of_word(ref, pred).items():
            metric_list[k].append(v)

    # Average the per-example scores.
    for k, v in metric_list.items():
        metric_list[k] = np.mean(v)

    def clean_for_bleu(s):
        # Same cleaning as clean_command, but placeholders become $0, $1, ...
        # so BLEU compares short anonymous slots rather than placeholder text.
        s = clean_command(s)
        s = s.replace("={", " {")
        var_to_pc_holder = defaultdict(lambda: len(var_to_pc_holder))
        for var in re.findall("{{(.*?)}}", s):
            _ = var_to_pc_holder[var]
        for var, idx in var_to_pc_holder.items():
            var_str = "{{%s}}" % var
            s = s.replace(var_str, f"${idx}")
        return s

    # Character tokenization is handled by sacrebleu itself via
    # BLEU(tokenize='char'), so no extra splitting is needed here.
    bleu = BLEU(tokenize='char')
    predictions = [clean_for_bleu(x) for x in predictions]
    references = [clean_for_bleu(x) for x in references]
    bleu_score = bleu.corpus_score(predictions, [references]).score
    metric_list['bleu_char'] = bleu_score
    return metric_list
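
# A minimal end-to-end sketch (the inputs are illustrative):
#
#   scores = tldr_metrics(
#       references=["tar -xvf {{archive}}"],
#       predictions=["tar -xvf {{archive}}"],
#   )
#   # template_matching, command_accuracy, r, p, f1 are all ~1.0 here,
#   # and bleu_char is 100.0 for an exact match.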
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TLDREval(evaluate.Metric):
    """Evaluate Bash scripts."""

    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features({
                "predictions": datasets.Value("string", id="sequence"),
                "references": datasets.Value("string", id="sequence"),
            }),
            homepage="https://github.com/shuyanzhou/docprompting",
            codebase_urls=["https://github.com/shuyanzhou/docprompting"],
            reference_urls=["https://github.com/shuyanzhou/docprompting"],
        )
|
|
|
    def _compute(self, predictions, references):
        """Returns the averaged scores over all prediction/reference pairs."""
        metrics = tldr_metrics(references, predictions)

        # Rename the terse token_prf keys to descriptive names.
        metrics['token_recall'] = metrics.pop('r')
        metrics['token_precision'] = metrics.pop('p')
        metrics['token_f1'] = metrics.pop('f1')
        return dict(metrics)
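
# Usage sketch (the load path is illustrative, not part of this file):
#
#   import evaluate
#   metric = evaluate.load("path/to/this/module")  # hypothetical local path
#   results = metric.compute(
#       predictions=["tar -xvf {{archive}}"],
#       references=["tar -xvf {{archive}}"],
#   )
#   # keys: template_matching, command_accuracy, token_recall,
#   #       token_precision, token_f1, bleu_char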