pminervini committed
Commit 2561b63
1 Parent(s): f00379a
halueval-cli.py ADDED
@@ -0,0 +1,48 @@
+#!/usr/bin/env python
+
+from huggingface_hub import snapshot_download
+
+from src.backend.envs import EVAL_REQUESTS_PATH_BACKEND
+from src.backend.manage_requests import get_eval_requests
+from src.backend.manage_requests import EvalRequest
+from src.backend.run_eval_suite import run_evaluation
+
+from lm_eval.tasks import initialize_tasks, include_task_folder
+from lm_eval import tasks, evaluator, utils
+
+from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task
+from src.envs import QUEUE_REPO
+
+
+def main():
+    snapshot_download(repo_id=QUEUE_REPO, revision="main", local_dir=EVAL_REQUESTS_PATH_BACKEND, repo_type="dataset", max_workers=60)
+
+    PENDING_STATUS = "PENDING"
+    RUNNING_STATUS = "RUNNING"
+    FINISHED_STATUS = "FINISHED"
+    FAILED_STATUS = "FAILED"
+
+    status = [PENDING_STATUS, RUNNING_STATUS, FINISHED_STATUS, FAILED_STATUS]
+
+    # Get eval requests in any of the statuses above; narrow `status` to run only a subset
+    eval_requests: list[EvalRequest] = get_eval_requests(job_status=status, hf_repo=QUEUE_REPO, local_dir=EVAL_REQUESTS_PATH_BACKEND)
+    eval_request = [r for r in eval_requests if 'bloom-560m' in r.model][0]
+
+    task_names = ['halueval_qa']
+
+    include_task_folder("src/backend/tasks/")
+    initialize_tasks('INFO')
+
+    print(tasks.ALL_TASKS)
+
+    task_names = utils.pattern_match(task_names, tasks.ALL_TASKS)
+
+    print(f"Selected Tasks: {task_names}")
+
+    results = evaluator.simple_evaluate(model="hf-auto", model_args=eval_request.get_model_args(), tasks=task_names, num_fewshot=0,
+                                        batch_size=4, device=DEVICE, use_cache=None, limit=8, write_out=True)
+
+    print('AAA', results)
+
+if __name__ == "__main__":
+    main()
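Note: `evaluator.simple_evaluate` returns a dictionary whose "results" entry maps each task name to its aggregated metrics, which is what the final print statement dumps. A minimal sketch of reading the HaluEval QA scores back out of it; the exact metric key names depend on the lm-eval version pinned here, so treat them as an assumption:

# Hypothetical follow-up inside main(), after simple_evaluate() returns:
halueval_scores = results["results"].get("halueval_qa", {})
print("HaluEval QA:", halueval_scores)  # expected to contain the aggregated "em" / "correctness" values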
src/backend/tasks/halueval/halueval_qa.yaml CHANGED
@@ -25,7 +25,7 @@ metric_list:
   - metric: em
     aggregation: mean
     higher_is_better: true
-  - metric: f1
+  - metric: correctness
     aggregation: mean
     higher_is_better: true
 metadata:
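The metric names in this `metric_list` must match the keys that the task's `process_results_qa` returns per document (see the utils.py change below), and `mean` aggregation averages those per-document values. A minimal standalone sketch of that relationship, not the harness's own aggregation code; note that the new `compute_metrics_qa` only reports `em` for well-formed predictions, and how lm-eval itself handles the missing key is an assumption here:

# Per-document score dicts, shaped like process_results_qa's return values.
per_doc_scores = [
    {"correctness": 1.0, "em": 1.0},  # well-formed "Yes"/"No" answer matching the gold label
    {"correctness": 1.0, "em": 0.0},  # well-formed answer, wrong label
    {"correctness": 0.0},             # ambiguous output: no "em" reported
]

def mean(values):
    return sum(values) / len(values)

# "mean" aggregation per metric, averaging only over documents that report the key.
aggregated = {name: mean([d[name] for d in per_doc_scores if name in d])
              for name in ("em", "correctness")}
print(aggregated)  # {'em': 0.5, 'correctness': 0.666...}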
src/backend/tasks/halueval/utils.py CHANGED
@@ -36,52 +36,39 @@ You should try your best to determine if the answer contains non-factual or hall
 
 
 def doc_to_text_qa(doc: dict[str, str]) -> str:
+    # print('XXX doc_to_text_qa')
     doc_text = QA_INSTURCTIONS + "\n\n#Question#: " + doc["question"] + "\n#Answer#: " + doc["answer"] + "\n#Your Judgement#:"
     return doc_text
 
 
 def doc_to_target_qa(doc: dict[str, str]) -> str:
+    # print('XXX doc_to_target_qa')
     return doc['hallucination']
 
 
-def em(gold_list: list[str], predictions: list[str]):
-    # tests for exact match and on the normalised answer (compute_exact)
-    em_sum = 0.0
-    if len(gold_list) > 1:
-        for i in range(len(gold_list)):
-            gold_answers = gold_list[0:i] + gold_list[i + 1 :]
-            # predictions compared against (n) golds and take maximum
-            em_sum += max(squad_metrics.compute_exact(a, predictions) for a in gold_answers)
-    else:
-        em_sum += max(squad_metrics.compute_exact(a, predictions) for a in gold_list)
-    return em_sum / max(1, len(gold_list))
+def compute_metrics_qa(gold_answer: str, prediction: str) -> dict[str, float]:
+    is_correct = True
 
+    if ("Yes" in prediction and "No" in prediction) or ("Yes" not in prediction and "No" not in prediction):
+        is_correct = False
+    elif "Yes" in prediction:
+        prediction = "yes"
+    elif "No" in prediction:
+        prediction = "no"
 
-def compute_metrics(gold_list: list[str], predictions: list[str]) -> dict[str, float]:
-    f1_sum = 0.0
-    em_sum = 0.0
+    is_exact = (gold_answer == prediction)
 
-    is_correct_lst = []
-    is_exact_lst = []
+    res = {"correctness": 1.0 if is_correct else 0.0}
+    if is_correct:
+        res["em"] = 1.0 if is_exact else 0.0
 
-    if len(gold_list) > 1:
-        for i in range(len(gold_list)):
-            gold_answers = gold_list[0:i] + gold_list[i + 1 :]
-            # predictions compared against (n) golds and take maximum
-            em_sum += max(squad_metrics.compute_exact(a, predictions) for a in gold_answers)
-            f1_sum += max(squad_metrics.compute_f1(a, predictions) for a in gold_answers)
-    else:
-        em_sum += max(squad_metrics.compute_exact(a, predictions) for a in gold_list)
-        f1_sum += max(squad_metrics.compute_f1(a, predictions) for a in gold_list)
+    return res
 
-    return {
-        "em": em_sum / max(1, len(gold_list)),
-        "f1": f1_sum / max(1, len(gold_list)),
-    }
 
-
-def process_results_qa(doc: dict[str, str], results):
+def process_results_qa(doc: dict[str, str], results: list[str]):
+    # results is e.g., ['Yes']
     gold_list = doc_to_target_qa(doc)
-    pred = results[0].strip().split("\n")[0]
-    scores = compute_metrics(gold_list, pred)
+    # gold_list is e.g., 'yes'
+    prediction = results[0].strip().split("\n")[0]
+    scores = compute_metrics_qa(gold_list, prediction)
     return scores
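As a quick sanity check of the new metric (assuming the repository root is on PYTHONPATH and src/backend/tasks/halueval is importable as a package), the behaviour of `compute_metrics_qa` on a few example predictions:

from src.backend.tasks.halueval.utils import compute_metrics_qa

print(compute_metrics_qa("yes", "Yes"))          # {'correctness': 1.0, 'em': 1.0}
print(compute_metrics_qa("no", "Yes"))           # {'correctness': 1.0, 'em': 0.0}
print(compute_metrics_qa("yes", "Yes and No"))   # {'correctness': 0.0} -- ambiguous output
print(compute_metrics_qa("yes", "I think not"))  # {'correctness': 0.0} -- no "Yes"/"No" found

Gold answers from doc_to_target_qa are lowercase ("yes"/"no") while the prompt asks the model to answer "Yes" or "No", which is why compute_metrics_qa lowercases the prediction before the exact-match check.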