import json
from dataclasses import field
from typing import Any, Dict, List, Optional

from datasets import Features, Sequence, Value

from .operator import StreamInstanceOperatorValidator

UNITXT_DATASET_SCHEMA = Features(
    {
        "source": Value("string"),
        "target": Value("string"),
        "references": Sequence(Value("string")),
        "metrics": Sequence(Value("string")),
        "group": Value("string"),
        "postprocessors": Sequence(Value("string")),
        "task_data": Value(dtype="string"),
    }
)
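
# Illustrative example of an instance that conforms to UNITXT_DATASET_SCHEMA
# (the field values below are made up; every key must be present for
# Features.encode_example to accept the instance):
#
#     UNITXT_DATASET_SCHEMA.encode_example(
#         {
#             "source": "Question: What is 2+2?\nAnswer:",
#             "target": "4",
#             "references": ["4"],
#             "metrics": ["metrics.accuracy"],
#             "group": "unitxt",
#             "postprocessors": ["to_string_stripped"],
#             "task_data": '{"question": "What is 2+2?", "answer": "4"}',
#         }
#     )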

# UNITXT_METRIC_SCHEMA = Features({
#     "predictions": Value("string", id="sequence"),
#     "target": Value("string", id="sequence"),
#     "references": Value("string", id="sequence"),
#     "metrics": Value("string", id="sequence"),
#     'group': Value('string'),
#     'postprocessors': Value("string", id="sequence"),
# })


class ToUnitxtGroup(StreamInstanceOperatorValidator):
    group: str
    metrics: Optional[List[str]] = None
    postprocessors: List[str] = field(default_factory=lambda: ["to_string_stripped"])
    remove_unnecessary_fields: bool = True

    def _to_lists_of_keys_and_values(self, dct: Dict[str, str]):
        # Split a dict into parallel key/value lists, stringifying the values.
        return {
            "key": list(dct.keys()),
            "value": [str(value) for value in dct.values()],
        }

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # Merge the task inputs and outputs into a single JSON string, since
        # the dataset schema stores task_data as a plain string field.
        task_data = {**instance["inputs"], **instance["outputs"]}
        instance["task_data"] = json.dumps(task_data)

        if self.remove_unnecessary_fields:
            # Drop every field that is not part of the dataset schema.
            keys_to_delete = [
                key for key in instance if key not in UNITXT_DATASET_SCHEMA
            ]
            for key in keys_to_delete:
                del instance[key]

        instance["group"] = self.group
        if self.metrics is not None:
            instance["metrics"] = self.metrics
        if self.postprocessors is not None:
            instance["postprocessors"] = self.postprocessors
        return instance

    def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
        # verify the instance has the required schema
        assert instance is not None, "Instance is None"
        assert isinstance(
            instance, dict
        ), f"Instance should be a dict, got {type(instance)}"
        assert all(
            key in instance for key in UNITXT_DATASET_SCHEMA
        ), f"Instance should have the following keys: {UNITXT_DATASET_SCHEMA}. Instance is: {instance}"
        # encode_example raises if the instance cannot be encoded with the schema
        UNITXT_DATASET_SCHEMA.encode_example(instance)
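
# Minimal usage sketch (assumes the operator can be constructed with keyword
# arguments for its declared fields, as unitxt operators typically are; the
# metric name and instance fields are illustrative):
#
#     op = ToUnitxtGroup(group="my_task", metrics=["metrics.accuracy"])
#     instance = {
#         "source": "Question: What is 2+2?\nAnswer:",
#         "target": "4",
#         "references": ["4"],
#         "inputs": {"question": "What is 2+2?"},
#         "outputs": {"answer": "4"},
#     }
#     instance = op.process(instance)
#     # instance now holds "group", "metrics", "postprocessors", and a
#     # JSON-encoded "task_data"; fields outside UNITXT_DATASET_SCHEMA
#     # (here "inputs" and "outputs") have been removed.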