Cookize commited on
Commit
7e92c24
1 Parent(s): 488ef58

ADD: BoolQ, TurthfulQA (#5)

Browse files

- ADD: truthfulqa_mc1, boolq FIX: CMMLU prompt (ed7a3db35a638f10b2693888d5c96306422ef09a)

Files changed (3) hide show
  1. tasks.py +83 -0
  2. tlem.py +4 -1
  3. utils.py +5 -5
tasks.py CHANGED
@@ -209,6 +209,7 @@ def multichoice_zh(responses: Any, references: list[str]):
209
  class Metrics:
210
  cmmlu = multichoice_zh
211
  mmlu = multichoice
 
212
  ceval = multichoice_zh
213
 
214
  def winogrande(responses: list[str], answers: list[str | int]):
@@ -269,6 +270,13 @@ class Metrics:
269
 
270
  return responses, answers
271
 
 
 
 
 
 
 
 
272
  def MATH(responses: list[str], answers: list[str]):
273
  extract_responses = sync_pipe(get_answer)(responses)
274
  extract_answers = sync_pipe(get_answer)(answers)
@@ -808,6 +816,81 @@ class BBH:
808
  return suite
809
 
810
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
811
  class CEVAL:
812
  input_column = "input"
813
  label_column = "answer"
 
209
  class Metrics:
210
  cmmlu = multichoice_zh
211
  mmlu = multichoice
212
+ truthful_qa_mc1 = multichoice
213
  ceval = multichoice_zh
214
 
215
  def winogrande(responses: list[str], answers: list[str | int]):
 
270
 
271
  return responses, answers
272
 
273
+ def boolq(responses: list[str], answers: list[str | int]):
274
+
275
+ responses = [first_capital_postprocess(response) for response in responses]
276
+ answers = ["A" if answer else "B" for answer in answers]
277
+
278
+ return responses, answers
279
+
280
  def MATH(responses: list[str], answers: list[str]):
281
  extract_responses = sync_pipe(get_answer)(responses)
282
  extract_answers = sync_pipe(get_answer)(answers)
 
816
  return suite
817
 
818
 
819
+ class BoolQ:
820
+ input_column = "input"
821
+ label_column = "answer"
822
+
823
+ @classmethod
824
+ def prompt_boolq(cls, example, chat=False):
825
+
826
+ prompt = f"{example['passage']}\nQuestion: {example['question']}\nA. Yes\nB. No\nAnswer: "
827
+
828
+ return {"input": prompt}
829
+
830
+ @classmethod
831
+ def suite(cls, chat: bool):
832
+
833
+ suite = [
834
+ Task(
835
+ dataset_name="boolq",
836
+ metric_name=("sustech/tlem", "boolq"),
837
+ input_column=cls.input_column,
838
+ label_column=cls.label_column,
839
+ prompt=partial(cls.prompt_boolq, chat=chat),
840
+ few_shot=0 if chat else 5,
841
+ few_shot_from="train",
842
+ split="validation",
843
+ )
844
+ ]
845
+
846
+ return suite
847
+
848
+ class TruthfulQAMC1:
849
+ input_column = "input"
850
+ label_column = "answer"
851
+
852
+ @classmethod
853
+ def prompt_truthful_qa(cls, example):
854
+
855
+ target = example["mc1_targets"]
856
+ choices = target["choices"]
857
+ labels = target["labels"]
858
+
859
+ prompt = f"The following is a multiple-choice question. Please choose the most suitable one as the answer to this question.\n\n"
860
+ prompt += example["question"]
861
+
862
+ answer = []
863
+
864
+ for idx, choice, label in zip(list("ABCDEFGHIJ")[:len(choices)], choices, labels):
865
+
866
+ prompt += f"\n{idx}. {choice}"
867
+
868
+ if label == 1:
869
+ answer = idx
870
+
871
+ prompt += "\nAnswer: "
872
+
873
+ return {
874
+ "input": prompt,
875
+ "answer": answer
876
+ }
877
+
878
+ @classmethod
879
+ def suite(cls):
880
+ suite = [
881
+ Task(
882
+ dataset_name=("truthful_qa", "multiple_choice"),
883
+ metric_name=("sustech/tlem", "truthful_qa_mc1"),
884
+ input_column=cls.input_column,
885
+ label_column=cls.label_column,
886
+ prompt=partial(cls.prompt_truthful_qa),
887
+ few_shot=0,
888
+ split="validation",
889
+ )
890
+ ]
891
+
892
+ return suite
893
+
894
  class CEVAL:
895
  input_column = "input"
896
  label_column = "answer"
tlem.py CHANGED
@@ -151,7 +151,10 @@ class Suite(EvaluationSuite):
151
  suite = DROP.suite()
152
  case "winogrande":
153
  suite = Winogrande.suite()
154
-
 
 
 
155
  case "mt_bench":
156
  suite = Task(
157
  dataset_name="SUSTech/mt_bench_judge",
 
151
  suite = DROP.suite()
152
  case "winogrande":
153
  suite = Winogrande.suite()
154
+ case "truthfulqa_mc1":
155
+ suite = TruthfulQAMC1.suite()
156
+ case _ if name.startswith("boolq"):
157
+ suite = BoolQ.suite(chat=chat)
158
  case "mt_bench":
159
  suite = Task(
160
  dataset_name="SUSTech/mt_bench_judge",
utils.py CHANGED
@@ -74,27 +74,27 @@ def extract_choice_zh(gen):
74
  def extract_choice(gen):
75
  # answer is A | choice is A | choose A
76
  res = re.search(
77
- r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b",
78
  gen,
79
  )
80
 
81
  # A is correct | A is right
82
  if res is None:
83
  res = re.search(
84
- r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b",
85
  gen,
86
  )
87
 
88
  # straight answer: A
89
  if res is None:
90
- res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen)
91
 
92
  # simply extract the first appearred letter
93
  if res is None:
94
- res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)
95
 
96
  if res is None:
97
- res = "A"
98
 
99
  if isinstance(res, str):
100
  return res
 
74
  def extract_choice(gen):
75
  # answer is A | choice is A | choose A
76
  res = re.search(
77
+ r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCDEFGHIJKL]{0,20}?(?:n't|not))[^ABCDEFGHIJKL]{0,10}?\b(?:|is|:|be))\b)[^ABCDEFGHIJKL]{0,20}?\b(A|B|C|D|E|F|G|H|I|J|K|L)\b",
78
  gen,
79
  )
80
 
81
  # A is correct | A is right
82
  if res is None:
83
  res = re.search(
84
+ r"\b(A|B|C|D|E|F|G|H|I|J|K|L)\b(?![^ABCDEFGHIJKL]{0,8}?(?:n't|not)[^ABCDEFGHIJKL]{0,5}?(?:correct|right))[^ABCDEFGHIJKL]{0,10}?\b(?:correct|right)\b",
85
  gen,
86
  )
87
 
88
  # straight answer: A
89
  if res is None:
90
+ res = re.search(r"^(A|B|C|D|E|F|G|H|I|J|K|L)(?:\.|,|:|$)", gen)
91
 
92
  # simply extract the first appearred letter
93
  if res is None:
94
+ res = re.search(r"(?<![a-zA-Z])(A|B|C|D|E|F|G|H|I|J|K|L)(?![a-zA-Z=])", gen)
95
 
96
  if res is None:
97
+ res = "L"
98
 
99
  if isinstance(res, str):
100
  return res