Shaltiel committed on
Commit
fe84b5e
·
1 Parent(s): 7135a84

Added translation task

Browse files
custom_tasks.py CHANGED
@@ -9,11 +9,12 @@ Author:
9
  from src.custom_tasks.heq_task import *
10
  from src.custom_tasks.sentiment_task import *
11
  from src.custom_tasks.winograd_task import *
 
12
 
13
  ## MODULE LOGIC
14
  # You should not need to touch this
15
  # Convert to dict for lighteval
16
- TASKS_TABLE = [task.as_dict() for task in [heq_task, sentiment_task, winograd_task]]
17
 
18
  if __name__ == "__main__":
19
  print(t["name"] for t in TASKS_TABLE)
 
9
  from src.custom_tasks.heq_task import *
10
  from src.custom_tasks.sentiment_task import *
11
  from src.custom_tasks.winograd_task import *
12
+ from src.custom_tasks.translation_task import *
13
 
14
  ## MODULE LOGIC
15
  # You should not need to touch this
16
  # Convert to dict for lighteval
17
# One serialized entry per custom task, in the order the tasks are listed.
TASKS_TABLE = [
    cfg.as_dict()
    for cfg in (heq_task, sentiment_task, winograd_task, translation_task)
]
18
 
19
  if __name__ == "__main__":
20
  print(t["name"] for t in TASKS_TABLE)
src/about.py CHANGED
@@ -21,5 +21,5 @@ TASKS_HARNESS = [task.value.benchmark for task in Tasks]
21
  # ---------------------------------------------------
22
 
23
  # TASKS_LIGHTEVAL = "lighteval|anli:r1|0|0,lighteval|logiqa|0|0"
24
- tasks = ['heq-qa-tlnls', 'sentiment-acc', 'winograd-acc']
25
  TASKS_LIGHTEVAL = ','.join(f'custom|{t}|0|0' for t in tasks)# + ',leaderboard|arc:challenge|0|0'
 
21
  # ---------------------------------------------------
22
 
23
  # TASKS_LIGHTEVAL = "lighteval|anli:r1|0|0,lighteval|logiqa|0|0"
24
# Names of the custom tasks that are turned into lighteval task specs below.
tasks = ["heq-qa-tlnls", "sentiment-acc", "winograd-acc", "he-en-trans-bleu"]
# Every task becomes a "custom|<task>|0|0" spec; specs are joined with commas.
# Disabled optional suffix: + ',leaderboard|arc:challenge|0|0'
TASKS_LIGHTEVAL = ",".join([f"custom|{task_name}|0|0" for task_name in tasks])
src/custom_tasks/translation_task.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import string
3
+ from lighteval.tasks.lighteval_task import LightevalTaskConfig
4
+ from lighteval.metrics import Metrics, MetricCategory
5
+ from lighteval.metrics.utils import CorpusLevelMetric, MetricUseCase
6
+ from aenum import extend_enum
7
+ import numpy as np
8
+ from lighteval.tasks.requests import Doc
9
+ from Levenshtein import distance
10
+ import collections
11
+ from lighteval.utils import as_list
12
+ import sacrebleu
13
+
14
def trans_prompt_fn(line, task_name: str = None):
    """Build the lighteval ``Doc`` for one dataset row.

    The stripped ``line["prompt"]`` becomes the query. The stripped
    ``line["response"][0]`` is the single choice and is marked as the gold
    answer (index 0). The instruction is left empty.
    """
    source_text = line["prompt"].strip()
    reference = line["response"][0].strip()
    return Doc(
        task_name=task_name,
        query=source_text,
        choices=[reference],
        gold_index=[0],
        instruction="",
    )
26
+
27
def translation_eval_fn(golds: list[str], predictions: list[str], formatted_doc: "Doc | None" = None):
    """Score a single prediction against its references with sentence BLEU.

    Args:
        golds: Reference strings passed to sacrebleu as ``references``.
        predictions: Model outputs for the sample; exactly one is required.
        formatted_doc: Not used by this function.

    Returns:
        The sacrebleu sentence-level BLEU ``score`` divided by 100, as a float.

    Raises:
        ValueError: If ``predictions`` does not contain exactly one item.
    """
    # Reject an empty list as well as multiple items: otherwise
    # ``predictions[0]`` below would fail with an IndexError instead of the
    # intended ValueError.
    if len(predictions) != 1:
        raise ValueError("Predictions should have one item")
    bleu = sacrebleu.sentence_bleu(hypothesis=predictions[0], references=golds)
    return float(bleu.score / 100)
31
+
32
# Custom "sentence_bleu" metric: translation_eval_fn is the per-sample scorer
# and np.mean is the corpus-level aggregation function.
sentence_bleu = CorpusLevelMetric(
    metric="sentence_bleu",
    higher_is_better=True,
    category=MetricCategory.GENERATIVE,
    use_case=MetricUseCase.TRANSLATION,
    sample_level_fn=translation_eval_fn,
    corpus_level_fn=np.mean,
)
# Add the metric to lighteval's Metrics enum under the name "sentence_bleu".
extend_enum(Metrics, "sentence_bleu", sentence_bleu)
41
+
42
# A simple task (like hellaswag): a single subset and one evaluation setup.
translation_task = LightevalTaskConfig(
    name="he-en-trans-bleu",
    suite=["custom"],
    # The prompt function must be defined in this file or imported from
    # src/lighteval/tasks/tasks_prompt_formatting.py.
    prompt_function="trans_prompt_fn",
    hf_repo="dicta-hebrew-llm-leaderboard/tests",
    hf_subset="default",
    hf_avail_splits=["en2he", "he2en"],
    evaluation_splits=["en2he", "he2en"],
    metric=["sentence_bleu", "bleu_1", "bleu_4"],
    generation_size=220,
    stop_sequence=["\n"],
)