Elron committed on
Commit
066c396
·
1 Parent(s): a254196

Upload schema.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. schema.py +56 -0
schema.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import Features, Sequence, Value
2
+ from .operator import StreamInstanceOperatorValidator
3
+
4
+ from typing import Dict, Any, List
5
+
6
+ from dataclasses import field
7
+
8
# Feature schema shared by the instances handled in this module: three plain
# string fields plus three variable-length sequences of strings.
UNITXT_DATASET_SCHEMA = Features(
    {
        # Single string fields.
        "source": Value("string"),
        "target": Value("string"),
        # Sequence-of-strings fields.
        "references": Sequence(Value("string")),
        "metrics": Sequence(Value("string")),
        # Single string field.
        "group": Value("string"),
        # Sequence-of-strings field.
        "postprocessors": Sequence(Value("string")),
    }
)
18
+
19
+ # UNITXT_METRIC_SCHEMA = Features({
20
+ # "predictions": Value("string", id="sequence"),
21
+ # "target": Value("string", id="sequence"),
22
+ # "references": Value("string", id="sequence"),
23
+ # "metrics": Value("string", id="sequence"),
24
+ # 'group': Value('string'),
25
+ # 'postprocessors': Value("string", id="sequence"),
26
+ # })
27
+
28
+
29
class ToUnitxtGroup(StreamInstanceOperatorValidator):
    """Tag an instance with group/metrics/postprocessors and check it against the schema.

    Attributes:
        group: Value written to ``instance["group"]`` on every processed instance.
        metrics: When not ``None``, written to ``instance["metrics"]``; when
            ``None`` the instance's existing value (if any) is left as is.
        postprocessors: When not ``None``, written to
            ``instance["postprocessors"]``. Defaults to ``["to_string"]``.
        remove_unnecessary_fields: When true, keys that are not part of
            ``UNITXT_DATASET_SCHEMA`` are removed from the instance before
            the fields above are written.
    """

    group: str
    metrics: List[str] = None
    postprocessors: List[str] = field(default_factory=lambda: ["to_string"])
    remove_unnecessary_fields: bool = True

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Prune non-schema keys (optionally) and stamp the configured fields.

        The instance is modified in place and the same dict object is returned.

        Args:
            instance: The instance dict to update.
            stream_name: Not used by this method.

        Returns:
            The same ``instance`` object after modification.
        """
        if self.remove_unnecessary_fields:
            # Collect the keys to drop *before* deleting: removing entries
            # while iterating the live dict raises
            # "RuntimeError: dictionary changed size during iteration".
            extra_keys = [key for key in instance if key not in UNITXT_DATASET_SCHEMA]
            for key in extra_keys:
                del instance[key]

        instance["group"] = self.group
        if self.metrics is not None:
            instance["metrics"] = self.metrics
        if self.postprocessors is not None:
            instance["postprocessors"] = self.postprocessors

        return instance

    def validate(self, instance: Dict[str, Any], stream_name: str = None):
        """Check that ``instance`` is a dict containing every schema key.

        Args:
            instance: The instance to check.
            stream_name: Not used by this method.

        Raises:
            AssertionError: If ``instance`` is ``None``, is not a dict, or is
                missing any key of ``UNITXT_DATASET_SCHEMA``.
        """
        # NOTE: these checks are ``assert`` statements, so they are skipped
        # when Python runs with assertions disabled (``-O``).
        assert instance is not None, "Instance is None"
        assert isinstance(instance, dict), f"Instance should be a dict, got {type(instance)}"
        assert all(
            key in instance for key in UNITXT_DATASET_SCHEMA
        ), f"Instance should have the following keys: {UNITXT_DATASET_SCHEMA}"
        # Let the schema attempt to encode the instance; any error it raises
        # propagates to the caller.
        UNITXT_DATASET_SCHEMA.encode_example(instance)