from dataclasses import field
from typing import Any, Dict, List, Optional
from datasets import Features, Sequence, Value
from .operator import StreamInstanceOperatorValidator
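# Flat features schema shared by Unitxt-produced dataset instances. Task-specific
# inputs and outputs are carried in "additional_inputs" as parallel key/value
# string lists, so every task fits the same fixed set of columns.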
UNITXT_DATASET_SCHEMA = Features(
{
"source": Value("string"),
"target": Value("string"),
"references": Sequence(Value("string")),
"metrics": Sequence(Value("string")),
"group": Value("string"),
"postprocessors": Sequence(Value("string")),
"additional_inputs": Sequence(
{"key": Value(dtype="string"), "value": Value("string")}
),
}
)
# UNITXT_METRIC_SCHEMA = Features({
# "predictions": Value("string", id="sequence"),
# "target": Value("string", id="sequence"),
# "references": Value("string", id="sequence"),
# "metrics": Value("string", id="sequence"),
# 'group': Value('string'),
# 'postprocessors': Value("string", id="sequence"),
# })
class ToUnitxtGroup(StreamInstanceOperatorValidator):
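    """Convert an instance into the flat Unitxt dataset schema.

    Merges the task inputs and outputs into "additional_inputs", attaches the
    group, metrics and postprocessors, and (optionally) drops any field that
    is not part of UNITXT_DATASET_SCHEMA.
    """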
group: str
metrics: Optional[List[str]] = None
postprocessors: List[str] = field(default_factory=lambda: ["to_string_stripped"])
remove_unnecessary_fields: bool = True
    def _to_lists_of_keys_and_values(self, data: Dict[str, str]) -> Dict[str, List[str]]:
        # Flatten a mapping into the parallel "key"/"value" string lists
        # expected by the "additional_inputs" feature.
        return {
            "key": list(data.keys()),
            "value": [str(value) for value in data.values()],
        }
def process(
self, instance: Dict[str, Any], stream_name: Optional[str] = None
) -> Dict[str, Any]:
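        # Merge task inputs and outputs into one mapping and flatten it into
        # the parallel key/value lists required by "additional_inputs".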
additional_inputs = {**instance["inputs"], **instance["outputs"]}
instance["additional_inputs"] = self._to_lists_of_keys_and_values(
additional_inputs
)
if self.remove_unnecessary_fields:
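            # Collect the keys first; the dict must not change size while iterating.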
keys_to_delete = []
for key in instance.keys():
if key not in UNITXT_DATASET_SCHEMA:
keys_to_delete.append(key)
for key in keys_to_delete:
del instance[key]
instance["group"] = self.group
if self.metrics is not None:
instance["metrics"] = self.metrics
if self.postprocessors is not None:
instance["postprocessors"] = self.postprocessors
return instance
def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
# verify the instance has the required schema
assert instance is not None, "Instance is None"
assert isinstance(
instance, dict
), f"Instance should be a dict, got {type(instance)}"
assert all(
key in instance for key in UNITXT_DATASET_SCHEMA
), f"Instance should have the following keys: {UNITXT_DATASET_SCHEMA}. Instance is: {instance}"
UNITXT_DATASET_SCHEMA.encode_example(instance)