File size: 3,374 Bytes
970dac4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aee30b
970dac4
4aee30b
970dac4
 
 
 
 
 
 
 
 
 
4aee30b
970dac4
 
 
 
 
 
 
 
4aee30b
970dac4
 
 
 
 
 
 
 
4aee30b
970dac4
 
4aee30b
970dac4
 
 
 
4aee30b
970dac4
 
4aee30b
970dac4
4aee30b
970dac4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Dialog Serializers.

Dialog serializers are the way to take dialog data and turn it into
text that can be fed to the model.

The format of the dialog is:

dialog = [
    {"user": "hello", "system": "hi"},
    {"user": "kkk", "system": ""},
    {"user": "kkk", "system": ""},
]
"""
from typing import Any, Dict, List, Optional

from .formats import SystemFormat
from .operators import InstanceFieldOperator


class SerializeDialog(InstanceFieldOperator):
    """Serializes dialog data for feeding into a model.

    A dialog is a list of turns, each a dict with a ``"user"`` and a
    ``"system"`` entry. Every turn is rendered with a per-turn template
    derived from a format's ``demo_format`` and the rendered turns are
    concatenated into a single string.

    The template is trimmed at the dialog boundaries: the first turn starts at
    the ``{user}`` placeholder, and the last turn ends right after the
    ``{system}`` placeholder (or right after ``{user}`` when the last system
    response is exported separately through ``last_response_to_field``).

    Attributes:
        format (Optional[SystemFormat]): Format whose ``demo_format`` is used
            as the turn template. When ``None``, the format is read from
            ``instance["recipe_metadata"]["format"]``.
        last_response_to_field (Optional[str]): When set, the ``"system"``
            value of the last turn is stored in this instance field and is
            left out of the serialized dialog.
        context_field (Optional[str]): Instance field whose value is prepended
            to the serialized dialog, followed by ``context_separator``.
        context_separator (str): Text placed between the context and the
            first turn. Only used when ``context_field`` is set.

    NOTE(review): the field holding the dialog and the field receiving the
    serialized text are presumably configured through attributes inherited
    from ``InstanceFieldOperator`` (e.g. ``field`` / ``to_field``) -- confirm
    against the base class.
    """

    format: Optional[SystemFormat] = None
    last_response_to_field: Optional[str] = None
    context_field: Optional[str] = None
    context_separator: str = " "

    def standardize_format(self, demo_format: str) -> str:
        """Rewrite a demo format's placeholders into dialog placeholders.

        ``{source}`` becomes ``{user}``, ``{target}`` becomes ``{system}`` and
        ``{target_prefix}`` is removed.
        """
        turn_format = demo_format.replace("{source}", "{user}")
        turn_format = turn_format.replace("{target}", "{system}")
        return turn_format.replace("{target_prefix}", "")

    def slice_first_turn(self, turn_format: str) -> str:
        """Drop everything that precedes the ``{user}`` placeholder.

        Raises:
            ValueError: If ``turn_format`` has no ``{user}`` placeholder.
        """
        return turn_format[turn_format.index("{user}") :]

    def slice_last_turn(self, turn_format: str) -> str:
        """Drop everything that follows the ``{system}`` placeholder.

        Raises:
            ValueError: If ``turn_format`` has no ``{system}`` placeholder.
        """
        return turn_format[: turn_format.index("{system}") + len("{system}")]

    def slice_last_response(self, turn_format: str) -> str:
        """Drop everything that follows the ``{user}`` placeholder.

        Raises:
            ValueError: If ``turn_format`` has no ``{user}`` placeholder.
        """
        return turn_format[: turn_format.index("{user}") + len("{user}")]

    def get_turn_format(self, turn_format: str, step: int, length: int) -> str:
        """Return the template for turn number ``step`` of ``length`` turns.

        Middle turns use ``turn_format`` unchanged; the first and last turns
        are trimmed. A single-turn dialog is both first and last, so both
        trims apply.
        """
        if step == 0:
            turn_format = self.slice_first_turn(turn_format)
        if step == length - 1:
            turn_format = self.slice_last_turn(turn_format)
            if self.last_response_to_field is not None:
                # The last system response is exported separately, so the
                # serialized dialog must stop right after the last user turn.
                turn_format = self.slice_last_response(turn_format)
        return turn_format

    def get_general_turn_format(self, instance: Dict[str, Any]) -> str:
        """Return the untrimmed turn template for ``instance``.

        Uses ``self.format`` when set, otherwise the format found under
        ``instance["recipe_metadata"]["format"]``.
        """
        general_format = (
            instance["recipe_metadata"]["format"]
            if self.format is None
            else self.format
        )
        return self.standardize_format(general_format.demo_format)

    def process_instance_value(
        self, structured_dialog: List[Dict[str, str]], instance: Dict[str, Any]
    ):
        """Serialize ``structured_dialog`` into a single string.

        Args:
            structured_dialog: The dialog turns; each turn supplies the values
                for the placeholders of the turn template.
            instance: The instance being processed. Read for the context and
                (when ``self.format`` is ``None``) the format; written to when
                ``last_response_to_field`` is set.

        Returns:
            The optional context followed by all rendered turns.

        Raises:
            ValueError: If ``last_response_to_field`` is set but the dialog
                has no turns, so there is no last response to export.
        """
        export_last_response = self.last_response_to_field is not None
        if export_last_response and not structured_dialog:
            # Previously this surfaced as an UnboundLocalError on the loop
            # variable; fail with an actionable message instead.
            raise ValueError(
                "Cannot store the last system response in "
                f"'{self.last_response_to_field}': the dialog has no turns."
            )

        parts = []
        if self.context_field is not None:
            parts.append(instance[self.context_field] + self.context_separator)

        general_turn_format = self.get_general_turn_format(instance)
        num_turns = len(structured_dialog)
        for step, turn in enumerate(structured_dialog):
            turn_format = self.get_turn_format(general_turn_format, step, num_turns)
            parts.append(turn_format.format(**turn))

        if export_last_response:
            instance[self.last_response_to_field] = structured_dialog[-1]["system"]
        return "".join(parts)