facat committed on
Commit
84e1d00
1 Parent(s): 5ad9651

output in dataset

Browse files
Files changed (1) hide show
  1. tasks.py +17 -4
tasks.py CHANGED
@@ -62,6 +62,7 @@ class Task:
62
  metric_name: str | tuple[str, str] = ("sustech/tlem", "mmlu")
63
  input_column: str = "question"
64
  label_column: str = ""
 
65
  prompt: Optional[Callable | str] = None
66
  few_shot: int = 0
67
  few_shot_from: Optional[str] = None
@@ -85,7 +86,6 @@ class Task:
85
  )
86
  }
87
  self.label_column = self.label_column or self.input_column
88
- self.outputs = []
89
 
90
  def __eq__(self, __value: object) -> bool:
91
  return self.name == __value.name
@@ -98,6 +98,10 @@ class Task:
98
  def labels(self):
99
  return self.dataset[self.label_column]
100
 
 
 
 
 
101
  @cached_property
102
  def dataset(self):
103
  ds = (
@@ -160,20 +164,29 @@ class Task:
160
  # logging.info(f"{self.name}:{results}")
161
  return results
162
 
163
- # @cache
164
  def run(
165
  self,
166
  pipeline,
167
  ):
168
- self.outputs = self.outputs or pipeline(self.samples)
 
 
 
169
 
170
  return self.result
171
 
172
  async def arun(self, pipeline):
173
- self.outputs = self.outputs or await pipeline(self.samples)
 
 
174
 
175
  return self.result
176
 
 
 
 
 
 
177
 
178
  def multichoice(responses: Any, references: list[str]):
179
  if isinstance(responses[0], str):
 
62
  metric_name: str | tuple[str, str] = ("sustech/tlem", "mmlu")
63
  input_column: str = "question"
64
  label_column: str = ""
65
+ output_column: str = "generated_text"
66
  prompt: Optional[Callable | str] = None
67
  few_shot: int = 0
68
  few_shot_from: Optional[str] = None
 
86
  )
87
  }
88
  self.label_column = self.label_column or self.input_column
 
89
 
90
  def __eq__(self, __value: object) -> bool:
91
  return self.name == __value.name
 
98
  def labels(self):
99
  return self.dataset[self.label_column]
100
 
101
+ @cached_property
102
+ def outputs(self):
103
+ return self.dataset[self.output_column]
104
+
105
  @cached_property
106
  def dataset(self):
107
  ds = (
 
164
  # logging.info(f"{self.name}:{results}")
165
  return results
166
 
 
167
  def run(
168
  self,
169
  pipeline,
170
  ):
171
+ if self.output_column not in self.dataset.column_names:
172
+ self.dataset = self.dataset.add_column(
173
+ self.output_column, pipeline(self.samples)
174
+ )
175
 
176
  return self.result
177
 
178
  async def arun(self, pipeline):
179
+ self.dataset = self.dataset.add_column(
180
+ self.output_column, await pipeline(self.samples)
181
+ )
182
 
183
  return self.result
184
 
185
+ def save(self, path):
186
+ self.dataset.select_columns(
187
+ [self.input_column, self.output_column, self.label_column]
188
+ ).save_to_disk(path)
189
+
190
 
191
  def multichoice(responses: Any, references: list[str]):
192
  if isinstance(responses[0], str):