# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metrics for natural-language-to-Bash generation on the tldr dataset."""

import re
from collections import Counter, defaultdict

import datasets
import evaluate
import numpy as np
from sacrebleu.metrics import BLEU

_CITATION = """\
@article{zhou2022doccoder,
  title={DocCoder: Generating Code by Retrieving and Reading Docs},
  author={Zhou, Shuyan and Alon, Uri and Xu, Frank F and Jiang, Zhengbao and Neubig, Graham},
  journal={arXiv preprint arXiv:2207.05987},
  year={2022}
}
"""

_DESCRIPTION = """\
Evaluation metrics for natural language to Bash generation. The preprocessing is
customized for the [`tldr`](https://github.com/tldr-pages/tldr) dataset: commands are
cleaned and their `{{variable}}` placeholders anonymized before comparison.
"""

_KWARGS_DESCRIPTION = """
Args:
    predictions: list of str. The predicted Bash commands.
    references: list of str. The reference Bash commands.
Returns:
    **template_matching**: exact-match accuracy after cleaning and variable anonymization.
    **command_accuracy**: accuracy of predicting the correct Bash command name (e.g., `ls`).
    **bleu_char**: character-level BLEU score.
    **token_recall/token_precision/token_f1**: recall/precision/F1 over the predicted tokens.
"""

VAR_STR = "[[VAR]]"


def clean_command(s):
    """Normalize a command: drop `sudo`, strip quoting, and blank out pipe/redirect operators."""
    s = s.replace("sudo", "").strip()
    s = s.replace("`", "").replace('"', "").replace("'", "")
    s = s.replace("|", " ").replace(">", " ").replace("<", " ")
    s = " ".join(s.split())
    return s


def anonymize_command(s):
    """Replace each distinct `{{variable}}` with a numbered placeholder such as `[[VAR]]_0`."""
    s = s.replace("={", " {")
    # Assign an incrementing id to each variable, in order of first appearance.
    var_to_pc_holder = defaultdict(lambda: len(var_to_pc_holder))
    for var in re.findall(r"{{(.*?)}}", s):
        _ = var_to_pc_holder[var]
    for var, idx in var_to_pc_holder.items():
        var_str = "{{%s}}" % var
        s = s.replace(var_str, f"{VAR_STR}_{idx}")
    return s


def clean_anonymize_command(s):
    return anonymize_command(clean_command(s))


def get_bag_of_words(cmd):
    cmd = clean_anonymize_command(cmd)
    tokens = cmd.strip().split()
    return tokens


def calc_template_matching(gold, pred):
    ag = clean_anonymize_command(gold)
    ap = clean_anonymize_command(pred)
    m = {'template_matching': int(ag == ap)}
    return m


def token_prf(tok_gold, tok_pred, match_blank=False):
    if match_blank and len(tok_gold) == 0:  # the gold command is empty
        if len(tok_pred) == 0:
            m = {'r': 1, 'p': 1, 'f1': 1}
        else:
            m = {'r': 0, 'p': 0, 'f1': 0}
    else:
        tok_gold_dict = Counter(tok_gold)
        tok_pred_dict = Counter(tok_pred)
        tokens = set([*tok_gold_dict] + [*tok_pred_dict])
        hit = 0
        for token in tokens:
            hit += min(tok_gold_dict.get(token, 0), tok_pred_dict.get(token, 0))
        p = hit / (sum(tok_pred_dict.values()) + 1e-10)
        r = hit / (sum(tok_gold_dict.values()) + 1e-10)
        f1 = 2 * p * r / (p + r + 1e-10)
        m = {'r': r, 'p': p, 'f1': f1}
    return m


def measure_bag_of_word(gold, pred):
    tok_gold = get_bag_of_words(gold)
    tok_pred = get_bag_of_words(pred)
    m = token_prf(tok_gold, tok_pred)
    # The command name is the first token of the cleaned command.
    gold_cmd = tok_gold[0] if len(tok_gold) else "NONE_GOLD"
    pred_cmd = tok_pred[0] if len(tok_pred) else "NONE_PRED"
    m = {**m, 'command_accuracy': int(gold_cmd == pred_cmd)}
    return m


def tldr_metrics(references, predictions):
    assert len(references) == len(predictions)
    metric_list = defaultdict(list)
    for ref, pred in zip(references, predictions):
        for k, v in calc_template_matching(ref, pred).items():
            metric_list[k].append(v)
        for k, v in measure_bag_of_word(ref, pred).items():
            metric_list[k].append(v)
    for k, v in metric_list.items():
        metric_list[k] = np.mean(v)

    def clean_for_bleu(s):
        """Clean like `clean_command`, then replace `{{variable}}`s with `$<id>` placeholders."""
        s = clean_command(s)
        s = s.replace("={", " {")
        var_to_pc_holder = defaultdict(lambda: len(var_to_pc_holder))
        for var in re.findall(r"{{(.*?)}}", s):
            _ = var_to_pc_holder[var]
        for var, idx in var_to_pc_holder.items():
            var_str = "{{%s}}" % var
            s = s.replace(var_str, f"${idx}")
        return s

    # Character-level BLEU: sacrebleu's `char` tokenizer splits the strings into characters.
    bleu = BLEU(tokenize='char')
    predictions = [clean_for_bleu(x) for x in predictions]
    references = [clean_for_bleu(x) for x in references]
    bleu_score = bleu.corpus_score(predictions, [references]).score
    metric_list['bleu_char'] = bleu_score
    return metric_list


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TLDREval(evaluate.Metric):
    """Evaluate natural-language-to-Bash generation."""

    def _info(self):
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference.
            features=datasets.Features({
                "predictions": datasets.Value("string", id="sequence"),
                "references": datasets.Value("string", id="sequence"),
            }),
            # Homepage of the module for documentation.
            homepage="https://github.com/shuyanzhou/docprompting",
            # Additional links to the codebase or references.
            codebase_urls=["https://github.com/shuyanzhou/docprompting"],
            reference_urls=["https://github.com/shuyanzhou/docprompting"],
        )

    def _compute(self, predictions, references):
        """Returns the scores."""
        metrics = tldr_metrics(references, predictions)
        # Rename for better display.
        metrics['token_recall'] = metrics.pop('r')
        metrics['token_precision'] = metrics.pop('p')
        metrics['token_f1'] = metrics.pop('f1')
        return dict(metrics)
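

# A minimal usage sketch, not part of the metric itself. Direct instantiation of
# TLDREval below is an illustrative shortcut; loading the script through
# `evaluate.load("<path/to/this/script.py>")` is the usual route, and the example
# commands are made up for demonstration.
if __name__ == "__main__":
    tldr_eval = TLDREval()
    results = tldr_eval.compute(
        predictions=["tar -xzvf {{archive.tar.gz}}"],
        references=["sudo tar -xzvf {{source.tar.gz}}"],
    )
    # `sudo` is stripped and both variables anonymize to the same placeholder,
    # so template_matching and command_accuracy should both come out as 1.0.
    print(results)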