File size: 4,173 Bytes
b686823
 
 
 
 
7e34f5e
b686823
 
 
 
acaeec7
b88b6bf
b686823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7798457
b686823
 
adf0b2e
5d21832
b686823
 
 
 
 
 
 
 
 
7e34f5e
b686823
 
 
 
 
 
 
 
adf0b2e
 
eb22e2e
b686823
 
 
 
 
 
 
 
 
a08a2d9
b686823
 
 
7e34f5e
7798457
 
fe281bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import re
import string
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.metrics import Metrics, MetricCategory
from lighteval.metrics.utils import CorpusLevelMetric, MetricUseCase
from aenum import extend_enum
import numpy as np
from lighteval.tasks.requests import Doc
from Levenshtein import distance
import collections
from lighteval.utils import as_list
from ..envs import OWNER

def get_tokens(s):
    """Split *s* into normalized whitespace-separated tokens.

    Falsy input (``None`` or ``""``) yields an empty list; anything else is
    run through :func:`normalize_answer` first and then split on whitespace.
    """
    return normalize_answer(s).split() if s else []

ARTICLES_REGEX = re.compile(r"\b(a|an|the)\b", re.UNICODE)
# Translation table that deletes every character in string.punctuation.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

def normalize_answer(s):
    """Normalize an answer string before token comparison.

    Steps, applied in this order:
      1. remove the literal ``<pad>`` and ``</s>`` markers and strip outer
         whitespace;
      2. lower-case the text;
      3. delete ASCII punctuation (``string.punctuation``);
      4. replace the English articles "a", "an", "the" with a space;
      5. collapse whitespace runs into single spaces and trim the ends.
    """
    text = s.replace('<pad>', '').replace('</s>', '').strip()
    text = text.lower()
    text = text.translate(_PUNCTUATION_TABLE)
    text = ARTICLES_REGEX.sub(" ", text)
    return " ".join(text.split())

def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def normalized_edit_similarity(p1, p2):
    """Return ``1 - levenshtein(p1, p2) / max(len(p1), len(p2))``.

    The result lies in ``[0, 1]``. It is 1.0 for identical strings and 0.0 when
    every position of the longer string must be edited. Two empty strings are
    identical and score 1.0; the bare formula would divide by zero there.
    """
    longest = max(len(p1), len(p2))
    if longest == 0:
        return 1.0
    return 1 - distance(p1, p2) / longest

def compute_token_edit(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    num_same = sum([max([normalized_edit_similarity(gold_t, pred_t) for pred_t in pred_toks]) for gold_t in gold_toks])
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def tlnls(a_gold, a_pred):
    """Token-Level Normalized Levenshtein Similarity of a prediction vs. gold.

    If at least half of the raw prediction's characters are digits, the answer
    is scored with exact token F1 (:func:`compute_f1`). Presumably this is so
    that near-miss numbers earn no partial credit. All other predictions use
    the edit-similarity soft F1 (:func:`compute_token_edit`).

    An empty prediction satisfies ``0 >= 0`` and therefore goes to exact F1.
    """
    digit_chars = sum(ch.isdigit() for ch in a_pred)
    mostly_numeric = digit_chars >= len(a_pred) / 2
    scorer = compute_f1 if mostly_numeric else compute_token_edit
    return scorer(a_gold, a_pred)

# Matches any <...> tag, e.g. "<pad>" or "</s>", so it can be stripped from outputs.
_XML_TAG_RE = re.compile(r"<[^>]+>")

def heq_eval_fn(golds: list[str], predictions: list[str], formatted_doc: "Doc | None" = None):
    """Sample-level HeQ score: the best TLNLS of the prediction against any gold.

    Args:
        golds: Acceptable reference answers for the sample.
        predictions: Model outputs; exactly one is expected.
        formatted_doc: Not used by this function.

    Returns:
        The maximum :func:`tlnls` score over ``golds``. An empty ``golds``
        makes ``max`` raise ``ValueError``.

    Raises:
        ValueError: if ``predictions`` does not contain exactly one item. An
            empty list used to surface as a bare ``IndexError``.
    """
    if len(predictions) != 1:
        raise ValueError("Predictions should have one item")
    # Drop tag-like tokens before stripping surrounding whitespace.
    pred = _XML_TAG_RE.sub('', predictions[0]).strip()
    return max(tlnls(gold, pred) for gold in golds)
    
# HeQ TLNLS metric: heq_eval_fn scores each sample, and the corpus score is the
# arithmetic mean (np.mean) of those per-sample scores.
heq_tlnls_metric = CorpusLevelMetric(
    metric="heq_tlnls",
    sample_level_fn=heq_eval_fn,
    corpus_level_fn=np.mean,
    category=MetricCategory.GENERATIVE,
    use_case=MetricUseCase.ACCURACY,
    higher_is_better=True,
)
# Add it to the Metrics enum as `heq_tlnls_metric`, the name task configs list.
extend_enum(Metrics, 'heq_tlnls_metric', heq_tlnls_metric)

def heq_prompt_fn(line, task_name: str = None):
    """Convert one HeQ dataset row into a lighteval ``Doc``.

    The row's ``"prompt"`` (whitespace-stripped) becomes the query. Every entry
    of its ``"response"`` list, stripped, becomes a choice. ``gold_index``
    covers every choice index, so all responses are treated as gold. No
    instruction prefix is set.
    """
    query = line["prompt"].strip()
    answers = [answer.strip() for answer in line["response"]]
    return Doc(
        task_name=task_name,
        query=query,
        choices=answers,
        gold_index=list(range(len(answers))),
        instruction="",
    )

# Single-subset task: the "heq" split of {OWNER}/tests (subset "default").
# Answers are generated freely (generation_size=64, stopping at the first
# newline) and scored with the 'heq_tlnls_metric' entry of the Metrics enum.
heq_task = LightevalTaskConfig(
    name="heq-qa-tlnls",
    suite=["custom"],
    prompt_function="heq_prompt_fn",  # referenced by name; must match the prompt function defined in this file
    hf_repo=f"{OWNER}/tests",
    hf_subset="default",
    hf_avail_splits=["heq"],
    evaluation_splits=["heq"],
    generation_size=64,
    stop_sequence=['\n'],
    metric=['heq_tlnls_metric'],
)