facat committed on
Commit
33a6f85
1 Parent(s): 9199665
Files changed (3)
  1. tasks.py +22 -4
  2. tlem.py +12 -37
  3. utils.py +28 -0
tasks.py CHANGED
@@ -6,11 +6,20 @@ from typing import Any, Optional, Protocol, Iterable, Callable
 import logging
 import pandas as pd
 from functools import partial
+from datasets.utils.logging import disable_progress_bar
 
 from .utils import *
 
 from evaluate import load
 from collections import defaultdict
+import sys
+
+# if sys.version_info >= (3, 9):
+#     from functools import cache
+# else:
+#     from functools import lru_cache as cache
+
+disable_progress_bar()
 
 
 def fake_pipeline(prompts: Iterable[str]) -> list[str]:
@@ -100,6 +109,7 @@ class Task:
         )
         return metric
 
+    # @cache
     def run(
         self,
         pipeline,
@@ -129,14 +139,20 @@ def multichoice(responses: Any, references: list[str]):
     else:
        responses = decode_choice(responses)
 
-    # return [
-    #     int(response == reference) for reference, response in zip(references, responses)
-    # ]
+    return responses, references
+
+
+def multichoice_zh(responses: Any, references: list[str]):
+    if isinstance(responses[0], str):
+        responses = [extract_choice_zh(response) for response in responses]
+    else:
+        responses = decode_choice(responses)
+
     return responses, references
 
 
 class Metrics:
-    cmmlu = multichoice
+    cmmlu = multichoice_zh
     mmlu = multichoice
 
     def gsm8k(responses: list[str], answers: list[str | int]):
@@ -299,6 +315,7 @@ class CMMLU:
             .to_dict()
         )
         suite = defaultdict(list)
+        cls.categories["all"] = list(finer_categories.keys())
         for k, v in cls.categories.items():
             for subject in v:
                 suite[k].extend(
@@ -429,6 +446,7 @@ class MMLU:
             .to_dict()
         )
         suite = defaultdict(list)
+        cls.categories["all"] = list(finer_categories.keys())
         for k, v in cls.categories.items():
             for subject in v:
                 suite[k].extend(
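
In short, tasks.py now silences the datasets progress bars at import time, keeps a commented-out version-gated functools.cache fallback, routes the Chinese benchmark through a new multichoice_zh metric (wired as Metrics.cmmlu), and registers an "all" category built from finer_categories for both CMMLU and MMLU. The sketch below is not from the commit: the category names and the stand-in task strings are made up, and it only illustrates what the added cls.categories["all"] line does to the existing suite-building loop.

# Sketch (not part of the commit) of the effect of the added line:
# with an "all" key listing every finer category, the same loop also
# builds one task list that covers every subject.
from collections import defaultdict

finer_categories = {"agronomy": ["agronomy"], "anatomy": ["anatomy"]}  # hypothetical
categories = {"STEM": ["anatomy"]}                                     # hypothetical

categories["all"] = list(finer_categories.keys())  # the added line in the diff

suite = defaultdict(list)
for k, v in categories.items():
    for subject in v:
        suite[k].append(f"task:{subject}")  # stand-in for the real Task objects

print(dict(suite))
# {'STEM': ['task:anatomy'], 'all': ['task:agronomy', 'task:anatomy']}

Because suite is a dict keyed by category, Suite.run in tlem.py can then report one combined score per key, including the new "all" aggregate.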
tlem.py CHANGED
@@ -16,32 +16,7 @@ import pandas as pd
 from .tasks import *
 from .utils import is_equiv
 
-# %%
-
-# %cd ../tlem
-
-# %load_ext ipytorch
-# %ls
-
-
-# TODO: Add BibTeX citation
-_CITATION = """\
-"""
-
-# TODO: Add description of the module here
-_DESCRIPTION = """\
-"""
 
-
-# TODO: Add description of the arguments of the module here
-_KWARGS_DESCRIPTION = """
-"""
-
-# TODO: Define external resources urls if needed
-BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
-
-
-# @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class ReasoningMetric(evaluate.Metric):
     """TODO: Short description of my evaluation module."""
 
@@ -59,9 +34,9 @@ class ReasoningMetric(evaluate.Metric):
         return evaluate.EvaluationModuleInfo(
             # This is the description that will appear on the modules page.
             # module_type="measurement",
-            description=_DESCRIPTION,
-            citation=_CITATION,
-            inputs_description=_KWARGS_DESCRIPTION,
+            description="",
+            citation="",
+            inputs_description="",
             # This defines the format of each prediction and reference
             features=features,
             # Homepage of the module for documentation
@@ -106,26 +81,30 @@ class Suite(EvaluationSuite):
     def run(
         self,
         model_or_pipeline: Any,
-        name="tlem",
     ) -> dict[str, float]:
         self.assert_suite_nonempty()
 
         def run_tasks(tasks):
-            for task in tqdm(tasks):
+            for task in (bar := tqdm(tasks, leave=False)):
+                bar.desc = f"complete {task.name}."
                 if task.name not in self.cached_result:
                     self.cached_result[task.name] = task.run(model_or_pipeline)
             results = [self.cached_result[task.name] for task in tasks]
             return pd.DataFrame(results).mean().to_dict()
 
         if isinstance(self.suite, dict):
-            for category, tasks in tqdm(self.suite.items()):
-                logging.warning(f"Combined results: {category}:{run_tasks(tasks)}")
+            for category, tasks in (bar := tqdm(self.suite.items())):
+                bar.desc = f"complete {category}."
+                logging.warning(f"Combined results {category}: {run_tasks(tasks)}")
         else:
             logging.warning(f"Combined results: {run_tasks(self.suite)}")
 
         return self.cached_result
 
     def add(self, name):
+        self.load(name)
+
+    def load(self, name):
         chat = False
         match name:
             case _ if "chat" in name:
@@ -146,8 +125,4 @@ class Suite(EvaluationSuite):
     def __init__(self, name="tlem"):
         super().__init__(name)
         self.cached_result = {}
-
-        self.suite = [
-            # TASK_REGISTRY["gsm8k"],
-            # TASK_REGISTRY["competition_math"],
-        ]
+        self.suite = []
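
The unused template scaffolding (_CITATION, _DESCRIPTION, _KWARGS_DESCRIPTION, BAD_WORDS_URL, IPython magics) is dropped in favour of empty strings passed directly to EvaluationModuleInfo, Suite.run loses its unused name argument, add now delegates to a new load method, and both loops label their progress bars. The snippet below is a standalone sketch of that progress-bar pattern only: the walrus operator binds the tqdm bar so its description can be updated per item. The task names and the sleep call are placeholders, not code from the repo.

# Standalone sketch of the pattern used in Suite.run: bind the tqdm bar with
# the walrus operator, then update its description for each item.
from time import sleep
from tqdm import tqdm

tasks = ["gsm8k", "mmlu-stem", "cmmlu-all"]  # hypothetical task names

for task in (bar := tqdm(tasks, leave=False)):
    bar.desc = f"complete {task}."  # shown on the bar's next refresh
    sleep(0.1)                      # stand-in for task.run(model_or_pipeline)

Setting bar.desc directly (rather than calling set_description) takes effect on the next refresh, which is enough here because the description only needs to track the current item.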
 
 
 
 
utils.py CHANGED
@@ -9,6 +9,34 @@ NUMERIC_IN_ZH = (
 )
 
 
+def extract_choice_zh(gen):
+    # 答案是A | 选项是A | 应该选A选项
+    res = re.search(
+        r"(?:(?:选|选择|选定)[::]?\s*|(?:(?:答案|选项)(?![^ABCD]{0,10}?(?:不|非)[^ABCD]{0,10}?(?:是|选|为|:|:|】))[^ABCD]{0,10}?(?:是|选|为|:|:|】))[^ABCD]{0,10}?)(A|B|C|D)(?:选项)?(?:\)|。|\.|,|,|.|、|A|B|C|D|$|:|:|\)|))",
+        gen,
+    )
+
+    # A选项正确 | A选项符合题意
+    if res is None:
+        res = re.search(
+            r"(A|B|C|D)(?:选?项)?(?![^ABCD]{0,4}?(?:不|非)[^ABCD]{0,4}?(?:正确|对[的,。:]|符合))[^ABCD]{0,4}?(?:正确|对[的,。:]|符合)",
+            gen,
+        )
+
+    # 直接输出 A
+    if res is None:
+        res = re.search(r"^[\((]?(A|B|C|D)(?:。|\)|)|\.|,|,|.|:|:|$)", gen)
+
+    # 获取第一个出现的字母
+    if res is None:
+        res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)
+    if res is None:
+        res = "A"
+    if isinstance(res, str):
+        return res
+    return res.group(1)
+
+
 def extract_choice(gen):
     # answer is A | choice is A | choose A
     res = re.search(
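
extract_choice_zh mirrors the existing English extract_choice and tries progressively looser patterns: explicit phrases such as 答案是A / 选项是A / 应该选A选项 ("the answer is A", "choose option A"), then A选项正确 / A选项符合题意 ("option A is correct / fits the question"), then a bare letter at the start of the output (直接输出 A, a directly emitted letter), then the first standalone A-D anywhere (获取第一个出现的字母), and finally a default of "A". Below is a small usage sketch with made-up model outputs; the import path is an assumption, and the letters in the comments are what the fallback order above is expected to yield.

# Usage sketch with hypothetical generations; import path assumed.
from tlem.utils import extract_choice_zh

print(extract_choice_zh("答案是B,因为……"))          # explicit "the answer is B"   -> 'B'
print(extract_choice_zh("C选项正确。"))              # "option C is correct"        -> 'C'
print(extract_choice_zh("D。其余选项均不符合题意"))    # bare leading letter          -> 'D'
print(extract_choice_zh("无法确定"))                 # nothing matched, default     -> 'A'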