pminervini committed
Commit: b25a00b
Parent: d265631
src/backend/tasks/halueval/halueval_dialogue.yaml CHANGED
@@ -7,8 +7,8 @@ validation_split: data
 test_split: data
 num_fewshot: 0
 doc_to_text: !function utils.doc_to_text_dialogue
-doc_to_target: !function utils.doc_to_target_qa
-process_results: !function utils.process_results_qa
+doc_to_target: !function utils.doc_to_target
+process_results: !function utils.process_results
 metric_list:
   - metric: em
     aggregation: mean
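The renamed references only work because lm-evaluation-harness resolves the "!function module.name" YAML tag to a Python callable defined next to the task config. The snippet below is a minimal, illustrative sketch of that resolution step, not the harness's actual loader; only the utils.doc_to_target path comes from this commit, everything else is an assumption for the example.

# Illustrative sketch of an "!function" tag resolver (the harness ships its own loader).
# Assumes utils.py is importable, e.g. the task directory is on sys.path.
import importlib
import yaml

def _function_constructor(loader, node):
    # "utils.doc_to_target" -> module "utils", attribute "doc_to_target"
    module_name, func_name = loader.construct_scalar(node).rsplit(".", 1)
    return getattr(importlib.import_module(module_name), func_name)

yaml.SafeLoader.add_constructor("!function", _function_constructor)
config = yaml.safe_load(open("halueval_dialogue.yaml"))
# config["doc_to_target"] is now the function object from utils.py

This is why the YAML references and the function definitions in utils.py further down have to be renamed in the same commit.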
src/backend/tasks/halueval/halueval_qa.yaml CHANGED
@@ -7,8 +7,8 @@ validation_split: data
 test_split: data
 num_fewshot: 0
 doc_to_text: !function utils.doc_to_text_qa
-doc_to_target: !function utils.doc_to_target_qa
-process_results: !function utils.process_results_qa
+doc_to_target: !function utils.doc_to_target
+process_results: !function utils.process_results
 metric_list:
   - metric: em
     aggregation: mean
src/backend/tasks/halueval/halueval_summarization.yaml CHANGED
@@ -7,8 +7,8 @@ validation_split: data
 test_split: data
 num_fewshot: 0
 doc_to_text: !function utils.doc_to_text_summarization
-doc_to_target: !function utils.doc_to_target_qa
-process_results: !function utils.process_results_qa
+doc_to_target: !function utils.doc_to_target
+process_results: !function utils.process_results
 metric_list:
   - metric: em
     aggregation: mean
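All three task configs now point at the same pair of hooks, utils.doc_to_target and utils.process_results, so a rename that misses one file only fails at task-load time. A throwaway check like the following (illustrative only, not part of the commit, assuming utils.py is importable from the task directory) confirms every referenced hook exists:

# Sanity-check that the hooks referenced in the three YAML files exist in utils.py.
import utils

for name in ("doc_to_text_qa", "doc_to_text_dialogue", "doc_to_text_summarization",
             "doc_to_target", "process_results"):
    assert callable(getattr(utils, name, None)), f"missing hook: utils.{name}"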
src/backend/tasks/halueval/utils.py CHANGED
@@ -102,11 +102,11 @@ def doc_to_text_summarization(doc: dict[str, str]) -> str:
     return doc_text
 
 
-def doc_to_target_qa(doc: dict[str, str]) -> str:
+def doc_to_target(doc: dict[str, str]) -> str:
     return doc['hallucination']
 
 
-def compute_metrics_qa(gold_answer: str, prediction: str) -> dict[str, float]:
+def compute_metrics(gold_answer: str, prediction: str) -> dict[str, float]:
     is_correct = True
 
     if ("Yes" in prediction and "No" in prediction) or ("Yes" not in prediction and "No" not in prediction):
@@ -122,13 +122,15 @@ def compute_metrics_qa(gold_answer: str, prediction: str) -> dict[str, float]:
     if is_correct:
         res["em"] = 1.0 if is_exact else 0.0
 
+    res["acc"] = 1.0 if (is_correct and is_exact) else 0.0
+
     return res
 
 
-def process_results_qa(doc: dict[str, str], results: list[str]):
+def process_results(doc: dict[str, str], results: list[str]):
     # results is e.g., ['Yes']
-    gold_list = doc_to_target_qa(doc)
+    gold_list = doc_to_target(doc)
     # gold_list is e.g., 'yes'
     prediction = results[0].strip().split("\n")[0]
-    scores = compute_metrics_qa(gold_list, prediction)
+    scores = compute_metrics(gold_list, prediction)
     return scores
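Taken together, the per-document scoring path is doc_to_target and compute_metrics wired through process_results. A minimal usage sketch, reusing the example values from the inline comments in the diff (the one-key doc dict and the remarks about the returned keys are assumptions for illustration, everything else follows the code shown above):

# Illustrative call, not part of the commit; assumes utils.py is importable.
from utils import process_results

doc = {"hallucination": "yes"}   # gold label, as returned by doc_to_target(doc)
results = ["Yes"]                # raw model completion for this doc
scores = process_results(doc, results)
# scores is the per-doc metric dict built by compute_metrics: "acc" is always set,
# and "em" is set when the prediction is well-formed, per the logic shown above.
print(scores)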