import collections
import re
import string

import numpy as np
from aenum import extend_enum
from Levenshtein import distance

from lighteval.metrics import Metrics, MetricCategory
from lighteval.metrics.utils import CorpusLevelMetric, MetricUseCase
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc

from ..envs import OWNER
def get_tokens(s):
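    """Split a string into tokens after SQuAD-style normalization."""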
if not s:
return []
return normalize_answer(s).split()
ARTICLES_REGEX = re.compile(r"\b(a|an|the)\b", re.UNICODE)
def normalize_answer(s):
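    """Lowercase, then drop punctuation, articles, extra whitespace, and special tokens.

    Example (illustrative input): normalize_answer("The answer!</s>") -> "answer"
    """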
def remove_articles(text):
return ARTICLES_REGEX.sub(" ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
    # Strip special tokens first, then apply the SQuAD-style normalization chain.
    s = s.replace("<pad>", "").replace("</s>", "").strip()
    return white_space_fix(remove_articles(remove_punc(lower(s))))
def compute_f1(a_gold, a_pred):
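    """SQuAD-style token-overlap F1 between a gold answer and a prediction."""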
gold_toks = get_tokens(a_gold)
pred_toks = get_tokens(a_pred)
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
if len(gold_toks) == 0 or len(pred_toks) == 0:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def normalized_edit_similarity(p1, p2):
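    """Similarity in [0, 1]: 1 minus the Levenshtein distance over the longer length.

    Example: "kitten" vs. "sitting" has distance 3 and max length 7 -> ~0.571.
    """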
    return 1 - distance(p1, p2) / max(len(p1), len(p2))
def compute_token_edit(a_gold, a_pred):
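    """Soft F1: each gold token is credited with its best edit similarity to any
    prediction token, rather than requiring exact token overlap."""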
gold_toks = get_tokens(a_gold)
pred_toks = get_tokens(a_pred)
if len(gold_toks) == 0 or len(pred_toks) == 0:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
    num_same = sum(
        max(normalized_edit_similarity(gold_t, pred_t) for pred_t in pred_toks)
        for gold_t in gold_toks
    )
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def tlnls(a_gold, a_pred):
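    """TLNLS: pick the scorer based on the prediction's character content.

    Mostly-digit predictions (years, amounts, ...) use exact token F1, where a
    single wrong digit should not earn partial credit; textual predictions use
    the softer token-edit F1, which tolerates small spelling variations.
    """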
digit_count = sum(1 for char in a_pred if char.isdigit())
if digit_count < len(a_pred) / 2:
return compute_token_edit(a_gold, a_pred)
else:
return compute_f1(a_gold, a_pred)
def heq_eval_fn(golds: list[str], predictions: list[str], formatted_doc: Doc = None):
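    """Score one prediction against all gold answers and keep the best TLNLS score."""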
    if len(predictions) != 1:
        raise ValueError("Predictions should contain exactly one item")
    pred = re.sub(r"<[^>]+>", "", predictions[0]).strip()  # strip XML/HTML tags
    return max(tlnls(gold, pred) for gold in golds)
heq_tlnls_metric = CorpusLevelMetric(
metric="heq_tlnls",
higher_is_better=True,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.ACCURACY,
corpus_level_fn=np.mean,
sample_level_fn=heq_eval_fn
)
extend_enum(Metrics, 'heq_tlnls_metric', heq_tlnls_metric)
def heq_prompt_fn(line, task_name: str = None):
"""Defines how to go from a dataset line to a doc object.
Follow examples in src/lighteval/tasks/tasks_prompt_formatting.py, or get more info
about what this function should do in the README.
"""
return Doc(
task_name=task_name,
query=line["prompt"].strip(),
choices=[resp.strip() for resp in line["response"]],
gold_index=list(range(len(line["response"]))),
instruction="",
)
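# Assumed row schema for the dataset above (an assumption, not verified here):
#   {"prompt": "<question with context>", "response": ["<reference answer>", ...]}
# Every reference answer is listed as a gold choice, so heq_eval_fn scores the
# model output against all of them and keeps the best match.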
# This is how you create a simple task (like hellaswag) that has a single subset
# attached to it and a single possible evaluation.
heq_task = LightevalTaskConfig(
name="heq-qa-tlnls",
prompt_function="heq_prompt_fn", # must be defined in the file or imported from src/lighteval/tasks/tasks_prompt_formatting.py
suite=["custom"],
hf_repo=f"{OWNER}/tests",
hf_subset="default",
hf_avail_splits=["heq"],
evaluation_splits=["heq"],
metric=['heq_tlnls_metric'],
stop_sequence=['\n'],
    generation_size=64,
)
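# NOTE: depending on the lighteval version, a custom-task module may also need to
# expose its tasks through a module-level table (commonly named TASKS_TABLE) so
# the loader can discover them; check the custom-task template for your version.

if __name__ == "__main__":
    # Minimal sanity checks for the scoring functions (illustrative strings only,
    # not taken from the dataset).
    assert heq_eval_fn(["the cat sat"], ["cat sat"]) == 1.0  # article removed by normalization
    assert tlnls("1948", "1984") < 1.0  # mostly-digit answers fall back to exact token F1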