# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TODO: Add a description here."""

import re
from collections import Counter, defaultdict

import numpy as np
from sacrebleu.metrics import BLEU

import datasets
import evaluate

# BibTeX citation for the accompanying paper
_CITATION = """\
@article{zhou2022doccoder,
  title={DocCoder: Generating Code by Retrieving and Reading Docs},
  author={Zhou, Shuyan and Alon, Uri and Xu, Frank F and Jiang, Zhengbao and Neubig, Graham},
  journal={arXiv preprint arXiv:2207.05987},
  year={2022}
}
"""

_DESCRIPTION = """\
The evaluation metrics for natural language to bash generation.
The preprocessing is customized for [`tldr`](https://github.com/tldr-pages/tldr) dataset where we first conduct annoymization on the variables.
"""


_KWARGS_DESCRIPTION = """
predictions: list of str. The predictions
references: list of str. The references

Return
- **template_matching**: the exact match accuracy
- **command_accuracy**: accuracy of predicting the correct bash command name (e.g., `ls`)
- **bleu_char**: char bleu score
- **token recall/precision/f1**: the recall/precision/f1 of the predicted tokens
"""

VAR_STR = "[[VAR]]"


def clean_command(s):
    """Normalize a bash command: drop `sudo`, strip quotes, and treat pipes/redirections as separators."""
    s = s.replace("sudo", "").strip()
    s = s.replace("`", "").replace('"', "").replace("'", "")
    # treat '|', '>' and '<' as plain whitespace
    s = s.replace("|", " ").replace(">", " ").replace("<", " ")
    s = " ".join(s.split())
    return s
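# Illustrative example (hypothetical input): clean_command('sudo ls -l | grep "foo"')
# returns "ls -l grep foo" -- sudo, quotes and pipes are stripped and whitespace collapsed.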

def anonymize_command(s):
    """Replace each distinct {{variable}} with a numbered [[VAR]]_<idx> placeholder."""
    s = s.replace("={", " {")
    var_to_pc_holder = defaultdict(lambda: len(var_to_pc_holder))
    for var in re.findall("{{(.*?)}}", s):
        _ = var_to_pc_holder[var]
    for var, idx in var_to_pc_holder.items():
        var_str = "{{%s}}" % var
        s = s.replace(var_str, f"{VAR_STR}_{idx}")
    return s
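# Illustrative example (hypothetical input): anonymize_command("tar cf {{target.tar}} {{file1}}")
# returns "tar cf [[VAR]]_0 [[VAR]]_1" -- each distinct variable is numbered in order of
# first appearance.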

def clean_anonymize_command(s):
    return anonymize_command(clean_command(s))


def get_bag_of_words(cmd):
    cmd = clean_anonymize_command(cmd)
    tokens = cmd.strip().split()
    return tokens

def calc_template_matching(gold, pred):
    ag = clean_anonymize_command(gold)
    ap = clean_anonymize_command(pred)
    m = {'template_matching': int(ag == ap)}
    return m

def token_prf(tok_gold, tok_pred, match_blank=False):
    if match_blank and len(tok_gold) == 0:  # empty gold: only an empty prediction is correct
        if len(tok_pred) == 0:
            m = {'r': 1, 'p': 1, 'f1': 1}
        else:
            m = {'r': 0, 'p': 0, 'f1': 0}
    else:
        tok_gold_dict = Counter(tok_gold)
        tok_pred_dict = Counter(tok_pred)
        tokens = set([*tok_gold_dict] + [*tok_pred_dict])
        hit = 0
        for token in tokens:
            hit += min(tok_gold_dict.get(token, 0), tok_pred_dict.get(token, 0))
        p = hit / (sum(tok_pred_dict.values()) + 1e-10)
        r = hit / (sum(tok_gold_dict.values()) + 1e-10)
        f1 = 2 * p * r / (p + r + 1e-10)
        m = {'r': r, 'p': p, 'f1': f1}
    return m
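# Worked example (illustrative): with tok_gold = ["ls", "-l", "[[VAR]]_0"] and
# tok_pred = ["ls", "[[VAR]]_0"], the multiset overlap is 2 tokens, so
# precision ~ 2/2 = 1.0, recall ~ 2/3 and f1 ~ 0.8 (up to the 1e-10 smoothing terms).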

def measure_bag_of_word(gold, pred):
    tok_gold = get_bag_of_words(gold)
    tok_pred = get_bag_of_words(pred)
    m = token_prf(tok_gold, tok_pred)  # token-level P/R/F1 over the whole command
    # distinct sentinels so that an empty gold or prediction never counts as a command match
    gold_cmd = tok_gold[0] if len(tok_gold) else "NONE_GOLD"
    pred_cmd = tok_pred[0] if len(tok_pred) else "NONE_PRED"
    m = {**m, 'command_accuracy': int(gold_cmd == pred_cmd)}

    return m

def tldr_metrics(references, predictions):
    """Compute corpus-level tldr metrics (averaged over examples) plus character-level BLEU."""
    assert len(references) == len(predictions)
    metric_list = defaultdict(list)
    for ref, pred in zip(references, predictions):
        for k, v in calc_template_matching(ref, pred).items():
            metric_list[k].append(v)
        for k, v in measure_bag_of_word(ref, pred).items():
            metric_list[k].append(v)

    for k, v in metric_list.items():
        metric_list[k] = np.mean(v)

    def clean_for_bleu(s):
        # same cleaning as clean_command, but variables become "$<idx>" placeholders for BLEU
        s = clean_command(s)
        s = s.replace("={", " {")
        var_to_pc_holder = defaultdict(lambda: len(var_to_pc_holder))
        for var in re.findall("{{(.*?)}}", s):
            _ = var_to_pc_holder[var]
        for var, idx in var_to_pc_holder.items():
            var_str = "{{%s}}" % var
            s = s.replace(var_str, f"${idx}")
        return s

    def to_characters(s):
        # character splitting is handled by sacrebleu's 'char' tokenizer below, so this is a pass-through
        return s

    # character-level BLEU over the cleaned commands
    bleu = BLEU(tokenize='char')
    predictions = [to_characters(clean_for_bleu(x)) for x in predictions]
    references = [to_characters(clean_for_bleu(x)) for x in references]
    bleu_score = bleu.corpus_score(predictions, [references]).score
    metric_list['bleu_char'] = bleu_score
    return metric_list

@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TLDREval(evaluate.Metric):
    """Evaluate Bash scripts."""

    def _info(self):
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features({
                "predictions": datasets.Value("string", id="sequence"),
                "references": datasets.Value("string", id="sequence"),
            }),
            # Homepage of the module for documentation
            homepage="https://github.com/shuyanzhou/docprompting",
            # Additional links to the codebase or references
            codebase_urls=["https://github.com/shuyanzhou/docprompting"],
            reference_urls=["https://github.com/shuyanzhou/docprompting"]
        )

    def _compute(self, predictions, references):
        """Returns the scores"""
        metrics = tldr_metrics(references, predictions)
        # rename for better display
        metrics['token_recall'] = metrics.pop('r')
        metrics['token_precision'] = metrics.pop('p')
        metrics['token_f1'] = metrics.pop('f1')
        return dict(metrics)
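

# Minimal usage sketch (the local path "tldr_eval.py" below is an assumption for
# illustration, not part of the original script):
#
#   import evaluate
#   metric = evaluate.load("tldr_eval.py", module_type="metric")
#   results = metric.compute(
#       predictions=["tar cf {{target.tar}} {{file1}}"],
#       references=["tar cf {{target.tar}} {{file1}}"],
#   )
#   # `results` contains template_matching, command_accuracy, bleu_char,
#   # token_recall, token_precision and token_f1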