|
from dataclasses import field |
|
from typing import Any, Dict, List, Optional |
|
|
|
from datasets import Features, Sequence, Value |
|
|
|
from .operator import StreamInstanceOperatorValidator |
|
|
|
# Feature schema for an instance in its final form.
# Every leaf value is a string: "source", "target" and "group" are single
# strings, while "references", "metrics" and "postprocessors" are sequences
# of strings.  "additional_inputs" is a sequence of {"key", "value"} string
# records; per the `datasets` documentation a Sequence wrapping a dict is
# represented as a dict of lists, i.e. {"key": [...], "value": [...]}.
UNITXT_DATASET_SCHEMA = Features(
    {
        "source": Value(dtype="string"),
        "target": Value(dtype="string"),
        "references": Sequence(Value(dtype="string")),
        "metrics": Sequence(Value(dtype="string")),
        "group": Value(dtype="string"),
        "postprocessors": Sequence(Value(dtype="string")),
        "additional_inputs": Sequence(
            {
                "key": Value(dtype="string"),
                "value": Value(dtype="string"),
            }
        ),
    }
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ToUnitxtGroup(StreamInstanceOperatorValidator):
    """Bring a stream instance into the shape of ``UNITXT_DATASET_SCHEMA``.

    ``process`` merges the instance's ``inputs`` and ``outputs`` into an
    ``additional_inputs`` entry, optionally removes every key that is not part
    of ``UNITXT_DATASET_SCHEMA``, and writes the configured ``group``,
    ``metrics`` and ``postprocessors`` onto the instance.  ``validate`` checks
    that an instance carries every schema key and can be encoded by the schema.
    """

    # Value written to the "group" field of every processed instance.
    group: str
    # Metric names written to every instance; None leaves "metrics" as-is.
    metrics: Optional[List[str]] = None
    # Postprocessor names written to every instance; None leaves the field as-is.
    postprocessors: Optional[List[str]] = field(
        default_factory=lambda: ["to_string_stripped"]
    )
    # When True, keys that are not part of UNITXT_DATASET_SCHEMA are deleted.
    remove_unnecessary_fields: bool = True

    def _to_lists_of_keys_and_values(self, dict: Dict[str, Any]) -> Dict[str, List[str]]:
        """Split a mapping into parallel lists of its keys and stringified values.

        Args:
            dict: Mapping to split.  The parameter name shadows the builtin and
                is kept only so existing keyword callers keep working.

        Returns:
            ``{"key": [...], "value": [...]}`` where both lists follow the
            mapping's iteration order and every value has been passed through
            ``str``.
        """
        # Alias the argument so the body does not read like a use of the builtin.
        mapping = dict
        return {
            "key": list(mapping),
            "value": [str(value) for value in mapping.values()],
        }

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Rewrite ``instance`` in place into the final schema layout.

        Args:
            instance: Instance to convert; it must contain ``"inputs"`` and
                ``"outputs"`` mappings.  It is mutated and returned.
            stream_name: Name of the originating stream; not used here.

        Returns:
            The same ``instance`` object after conversion.

        Raises:
            KeyError: If ``instance`` has no ``"inputs"`` or ``"outputs"`` key.
        """
        # On a key clash the value from "outputs" wins, since it is unpacked last.
        instance["additional_inputs"] = self._to_lists_of_keys_and_values(
            {**instance["inputs"], **instance["outputs"]}
        )

        if self.remove_unnecessary_fields:
            # Materialize the list first: a dict cannot change size while it is
            # being iterated.
            extra_keys = [key for key in instance if key not in UNITXT_DATASET_SCHEMA]
            for key in extra_keys:
                del instance[key]

        instance["group"] = self.group
        if self.metrics is not None:
            instance["metrics"] = self.metrics
        if self.postprocessors is not None:
            instance["postprocessors"] = self.postprocessors
        return instance

    def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
        """Check that ``instance`` conforms to ``UNITXT_DATASET_SCHEMA``.

        The checks raise ``AssertionError`` explicitly instead of using
        ``assert`` statements, so they still run when Python is started with
        ``-O`` while callers keep seeing the same exception type and messages.

        Args:
            instance: Instance to check.
            stream_name: Name of the originating stream; not used here.

        Raises:
            AssertionError: If ``instance`` is None, is not a dict, or lacks
                one of the schema's keys.
        """
        if instance is None:
            raise AssertionError("Instance is None")
        if not isinstance(instance, dict):
            raise AssertionError(f"Instance should be a dict, got {type(instance)}")
        if not all(key in instance for key in UNITXT_DATASET_SCHEMA):
            raise AssertionError(
                f"Instance should have the following keys: {UNITXT_DATASET_SCHEMA}. Instance is: {instance}"
            )
        # The encoded result is discarded; presumably this call is made so that
        # values the schema cannot encode surface as an error -- confirm against
        # the `datasets` documentation for Features.encode_example.
        UNITXT_DATASET_SCHEMA.encode_example(instance)
|
|