File size: 3,311 Bytes
a8f310f
4868000
f8c1b22
4868000
 
 
 
 
 
 
 
 
 
a8f310f
4868000
a8f310f
 
 
 
 
 
 
 
 
8d5bd0c
 
 
 
a8f310f
 
 
8d5bd0c
 
 
 
a8f310f
8d5bd0c
a8f310f
8d5bd0c
4868000
 
 
 
 
8d5bd0c
 
 
 
 
 
 
 
 
a8f310f
 
 
8d5bd0c
 
a8f310f
8d5bd0c
 
a8f310f
8d5bd0c
 
 
 
 
 
 
 
 
 
 
a8f310f
 
 
8d5bd0c
a8f310f
 
 
8d5bd0c
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import Any, Dict, List, Optional

from .operator import StreamInstanceOperator


class Tasker:
    """Empty base class; it defines no attributes or methods of its own."""


class FormTask(Tasker, StreamInstanceOperator):
    """Split an instance into the task's input fields, output fields and metrics.

    ``process`` picks the fields named in ``inputs`` and ``outputs`` out of an
    instance dict and returns them, together with ``metrics``, in a new dict.
    """

    # Names of the instance fields copied into the "inputs" part of the result.
    inputs: List[str]
    # Names of the instance fields copied into the "outputs" part of the result.
    outputs: List[str]
    # Returned unchanged under the "metrics" key of every processed instance.
    metrics: List[str]
    # Must be a subset of ``inputs``; enforced by ``verify``.
    # NOTE(review): class-level mutable default -- presumably the base class
    # copies field defaults per instance; confirm before mutating it in place.
    augmentable_inputs: List[str] = []

    def verify(self):
        """Check that every augmentable input is also declared in ``inputs``.

        Raises:
            AssertionError: If an entry of ``augmentable_inputs`` is not in ``inputs``.
        """
        for augmentable_input in self.augmentable_inputs:
            assert (
                augmentable_input in self.inputs
            ), f"augmentable_input {augmentable_input} is not part of {self.inputs}"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build the ``inputs`` / ``outputs`` / ``metrics`` view of one instance.

        Args:
            instance: Mapping that must contain every name listed in
                ``self.inputs`` and ``self.outputs``.
            stream_name: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the keys ``"inputs"`` and ``"outputs"`` (sub-dicts of
            ``instance``) and ``"metrics"`` (``self.metrics`` itself, not a copy).

        Raises:
            KeyError: If a declared input or output name is missing from
                ``instance``. The message lists the missing and available names.
        """
        try:
            inputs = {key: instance[key] for key in self.inputs}
        except KeyError as e:
            raise KeyError(
                f"Unexpected FormTask input column names ({[key for key in self.inputs if key not in instance]})."
                f" The available input names: {list(instance.keys())}"
            ) from e
        try:
            outputs = {key: instance[key] for key in self.outputs}
        except KeyError as e:
            raise KeyError(
                f"Unexpected FormTask output column names: {[key for key in self.outputs if key not in instance]}"
                f" \n available names:{list(instance.keys())}\n given output names:{self.outputs}"
            ) from e

        return {
            "inputs": inputs,
            "outputs": outputs,
            "metrics": self.metrics,
        }


class MultipleChoiceTask(FormTask):
    """FormTask that labels the choices with letters and rewrites the target.

    After the base ``process`` has run, the list stored under ``choices_field``
    in the inputs is replaced by a single enumerated string (for example
    ``"A. yes\\nB. no"`` with the default settings), and the first output is
    replaced by the label of the matching choice.
    """

    # Key, inside the processed inputs, that holds the sequence of choices.
    choices_field: str = "choices"
    # String placed between the enumerated choices in the joined text.
    choices_separator: str = "\n"
    # String placed between a choice's label and its text.
    enumeration_suffix: str = ". "
    # When True the target is "<label><suffix><text>"; otherwise only "<label>".
    use_text_in_target: bool = False
    # Labels, by position; its length caps the number of supported choices.
    alphabet: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    def process_single_choice(
        self, choice: str, index: int, use_text: bool = True
    ) -> str:
        """Return the label for position ``index``, optionally followed by the text.

        Args:
            choice: Text of the choice.
            index: Position of the choice, used to look up its label in ``alphabet``.
            use_text: If True, append ``enumeration_suffix`` and ``choice`` to the label.

        Raises:
            ValueError: If ``index`` is beyond the end of ``alphabet``.
        """
        try:
            processed_choice = f"{self.alphabet[index]}"
        except IndexError as e:
            raise ValueError(
                f"Too many choices, the length of alphabet '{self.alphabet}': {len(self.alphabet)} is the limit"
            ) from e
        if use_text:
            processed_choice += f"{self.enumeration_suffix}{choice}"
        return processed_choice

    def process_choices(self, choices: List[str]) -> str:
        """Enumerate all ``choices`` and join them with ``choices_separator``."""
        return self.choices_separator.join(
            [
                self.process_single_choice(choice, index)
                for index, choice in enumerate(choices)
            ]
        )

    def process_target(self, choices, target_index):
        """Return the processed form of the choice at ``target_index``.

        Whether the choice text is included is controlled by ``use_text_in_target``.
        """
        return self.process_single_choice(
            choices[target_index], target_index, use_text=self.use_text_in_target
        )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Run the base processing, then enumerate the choices and the target.

        Args:
            instance: Mapping with the declared input and output fields. The
                field named by ``choices_field`` must be among the inputs.
            stream_name: Forwarded unchanged to the base ``process``.

        Returns:
            The base result, with the choices replaced by their enumerated
            string and the first output replaced by the processed target.

        Raises:
            ValueError: If the target value is not one of the choices, or if
                there are more choices than letters in ``alphabet``.
        """
        result = super().process(instance, stream_name)
        # Only the first declared output is treated as the target.
        target_key, target_value = next(iter(result["outputs"].items()))
        choices = result["inputs"][self.choices_field]
        try:
            target_index_in_choices = choices.index(target_value)
        except ValueError as e:
            # The bare "x is not in list" error does not say which field or
            # which choices were involved, so re-raise with that context.
            raise ValueError(
                f"The target value '{target_value}' of output '{target_key}' "
                f"is not one of the choices: {choices}"
            ) from e

        processed_choices = self.process_choices(choices)
        processed_target = self.process_target(choices, target_index_in_choices)

        result["inputs"][self.choices_field] = processed_choices
        result["outputs"][target_key] = processed_target

        return result