mathemakitten commited on
Commit
0fa8793
1 Parent(s): 101ad9a
Files changed (1) hide show
  1. glue-suite-v2.py +36 -32
glue-suite-v2.py CHANGED
@@ -1,8 +1,9 @@
1
-
2
  from typing import Optional, Union, Callable
3
  from dataclasses import dataclass
4
  from datasets import Dataset
5
 
 
 
6
  @dataclass
7
  class SubTask:
8
  model_or_pipeline: Optional[Union[str, "Pipeline", Callable, "PreTrainedModel", "TFPreTrainedModel"]] = None
@@ -13,35 +14,38 @@ class SubTask:
13
  args_for_task: Optional[dict] = None
14
 
15
 
16
- preprocessor = lambda x: x["text"].lower()
 
 
 
17
 
18
- suite = [
19
- SubTask(
20
- data="imdb",
21
- split="test",
22
- data_preprocessor=preprocessor,
23
- args_for_task={
24
- "metric": "accuracy",
25
- "input_column": "text",
26
- "label_column": "label",
27
- "label_mapping": {
28
- "LABEL_0": 0.0,
29
- "LABEL_1": 1.0
30
- }
31
- }
32
- ),
33
- SubTask(
34
- data="sst2",
35
- split="test[:10]",
36
- data_preprocessor=preprocessor,
37
- args_for_task={
38
- "metric": "accuracy",
39
- "input_column": "sentence",
40
- "label_column": "label",
41
- "label_mapping": {
42
- "LABEL_0": 0.0,
43
- "LABEL_1": 1.0
44
- }
45
- }
46
- )
47
- ]
 
 
1
  from typing import Optional, Union, Callable
2
  from dataclasses import dataclass
3
  from datasets import Dataset
4
 
5
+ from evaluate.evaluation_suite import EvaluationSuite
6
+
7
  @dataclass
8
  class SubTask:
9
  model_or_pipeline: Optional[Union[str, "Pipeline", Callable, "PreTrainedModel", "TFPreTrainedModel"]] = None
 
14
  args_for_task: Optional[dict] = None
15
 
16
 
17
+ class EvaluationSuite:
18
+
19
+ def __init__(self):
20
+ self.preprocessor = None #lambda x: x["text"].lower()
21
 
22
+ self.suite = [
23
+ SubTask(
24
+ data="imdb",
25
+ split="test",
26
+ data_preprocessor=self.preprocessor,
27
+ args_for_task={
28
+ "metric": "accuracy",
29
+ "input_column": "text",
30
+ "label_column": "label",
31
+ "label_mapping": {
32
+ "LABEL_0": 0.0,
33
+ "LABEL_1": 1.0
34
+ }
35
+ }
36
+ ),
37
+ SubTask(
38
+ data="sst2",
39
+ split="test[:10]",
40
+ data_preprocessor=self.preprocessor,
41
+ args_for_task={
42
+ "metric": "accuracy",
43
+ "input_column": "sentence",
44
+ "label_column": "label",
45
+ "label_mapping": {
46
+ "LABEL_0": 0.0,
47
+ "LABEL_1": 1.0
48
+ }
49
+ }
50
+ )
51
+ ]