Elron committed on
Commit
7f6dcb7
1 Parent(s): 98cdd32

Upload schema.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. schema.py +25 -9
schema.py CHANGED
@@ -1,5 +1,5 @@
1
  from dataclasses import field
2
- from typing import Any, Dict, List
3
 
4
  from datasets import Features, Sequence, Value
5
 
@@ -13,6 +13,9 @@ UNITXT_DATASET_SCHEMA = Features(
13
  "metrics": Sequence(Value("string")),
14
  "group": Value("string"),
15
  "postprocessors": Sequence(Value("string")),
 
 
 
16
  }
17
  )
18
 
@@ -32,7 +35,20 @@ class ToUnitxtGroup(StreamInstanceOperatorValidator):
32
  postprocessors: List[str] = field(default_factory=lambda: ["to_string_stripped"])
33
  remove_unnecessary_fields: bool = True
34
 
35
- def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  if self.remove_unnecessary_fields:
37
  keys_to_delete = []
38
 
@@ -42,20 +58,20 @@ class ToUnitxtGroup(StreamInstanceOperatorValidator):
42
 
43
  for key in keys_to_delete:
44
  del instance[key]
45
-
46
  instance["group"] = self.group
47
  if self.metrics is not None:
48
  instance["metrics"] = self.metrics
49
  if self.postprocessors is not None:
50
  instance["postprocessors"] = self.postprocessors
51
-
52
  return instance
53
 
54
- def validate(self, instance: Dict[str, Any], stream_name: str = None):
55
  # verify the instance has the required schema
56
- assert instance is not None, f"Instance is None"
57
- assert isinstance(instance, dict), f"Instance should be a dict, got {type(instance)}"
 
 
58
  assert all(
59
- [key in instance for key in UNITXT_DATASET_SCHEMA]
60
- ), f"Instance should have the following keys: {UNITXT_DATASET_SCHEMA}"
61
  UNITXT_DATASET_SCHEMA.encode_example(instance)
 
1
  from dataclasses import field
2
+ from typing import Any, Dict, List, Optional
3
 
4
  from datasets import Features, Sequence, Value
5
 
 
13
  "metrics": Sequence(Value("string")),
14
  "group": Value("string"),
15
  "postprocessors": Sequence(Value("string")),
16
+ "additional_inputs": Sequence(
17
+ {"key": Value(dtype="string"), "value": Value("string")}
18
+ ),
19
  }
20
  )
21
 
 
35
  postprocessors: List[str] = field(default_factory=lambda: ["to_string_stripped"])
36
  remove_unnecessary_fields: bool = True
37
 
38
def _to_lists_of_keys_and_values(self, dict: Dict[str, Any]):
    """Split a mapping into parallel lists of its keys and stringified values.

    Args:
        dict: Mapping to split. The name shadows the builtin ``dict``; it is
            kept unchanged so existing keyword callers keep working. Values
            may be of any type because each one is passed through ``str``.

    Returns:
        A dictionary with two equally long lists: ``"key"`` holds the keys in
        iteration order and ``"value"`` holds ``str(value)`` for the matching
        values, in the same order.
    """
    return {
        # Iterating a mapping yields its keys; no need to unpack items().
        "key": list(dict),
        "value": [str(value) for value in dict.values()],
    }
43
+
44
+ def process(
45
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
46
+ ) -> Dict[str, Any]:
47
+ additional_inputs = {**instance["inputs"], **instance["outputs"]}
48
+ instance["additional_inputs"] = self._to_lists_of_keys_and_values(
49
+ additional_inputs
50
+ )
51
+
52
  if self.remove_unnecessary_fields:
53
  keys_to_delete = []
54
 
 
58
 
59
  for key in keys_to_delete:
60
  del instance[key]
 
61
  instance["group"] = self.group
62
  if self.metrics is not None:
63
  instance["metrics"] = self.metrics
64
  if self.postprocessors is not None:
65
  instance["postprocessors"] = self.postprocessors
 
66
  return instance
67
 
68
def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
    """Assert that *instance* matches the dataset schema, then encode it.

    The checks run in order: the instance must not be ``None``, must be a
    ``dict``, and must contain every key that iterating
    ``UNITXT_DATASET_SCHEMA`` yields. Finally the instance is handed to
    ``UNITXT_DATASET_SCHEMA.encode_example``.

    Args:
        instance: The instance to check.
        stream_name: Not used by this method.

    Raises:
        AssertionError: If any of the three checks fails.
    """
    assert instance is not None, "Instance is None"

    instance_is_dict = isinstance(instance, dict)
    assert instance_is_dict, f"Instance should be a dict, got {type(instance)}"

    # Equivalent to all(key in instance ...), stated as "no schema key is absent".
    assert not any(
        schema_key not in instance for schema_key in UNITXT_DATASET_SCHEMA
    ), f"Instance should have the following keys: {UNITXT_DATASET_SCHEMA}. Instance is: {instance}"

    UNITXT_DATASET_SCHEMA.encode_example(instance)