File size: 3,374 Bytes
970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 4aee30b 970dac4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
"""Dialog Serializers.

Dialog serializers take structured dialog data and turn it into
text that can be fed to the model.

A dialog is a list of turns, each a dict holding a "user" utterance and
the "system" response to it:

dialog = [
    {"user": "hello", "system": "hi"},
    {"user": "kkk", "system": ""},
    {"user": "kkk", "system": ""},
]
"""
from typing import Any, Dict, List, Optional
from .formats import SystemFormat
from .operators import InstanceFieldOperator
class SerializeDialog(InstanceFieldOperator):
    """Serializes dialog data for feeding into a model.

    Each turn of the structured dialog is rendered with the ``demo_format`` of
    a ``SystemFormat``, after renaming ``{source}``/``{target}`` to
    ``{user}``/``{system}`` and removing ``{target_prefix}``. The first turn
    drops everything before ``{user}`` and the last turn drops everything after
    ``{system}``. Optionally, the last system response is left out of the text
    and stored in a separate field instead.

    Attributes:
        field (str): The field in the input data that contains the dialog.
        format (Optional[SystemFormat]): Format whose ``demo_format`` renders
            each turn. When ``None``, the format is read from
            ``instance["recipe_metadata"]["format"]``.
        last_response_to_field (Optional[str]): When set, the last turn is
            rendered only up to and including ``{user}``, and that turn's
            ``"system"`` value is written to this field of the instance.
        context_field (Optional[str]): Field whose value is prepended to the
            dialog, followed by ``context_separator``.
        context_separator (str): Text placed between the context and the
            dialog. Defaults to a single space.
    """

    format: Optional[SystemFormat] = None
    last_response_to_field: Optional[str] = None
    context_field: Optional[str] = None
    context_separator: str = " "

    def standardize_format(self, demo_format: str) -> str:
        """Map demo-format placeholders onto dialog-turn placeholders.

        ``{source}`` becomes ``{user}``, ``{target}`` becomes ``{system}``,
        and ``{target_prefix}`` is removed.
        """
        turn_format = demo_format.replace("{source}", "{user}")
        turn_format = turn_format.replace("{target}", "{system}")
        return turn_format.replace("{target_prefix}", "")

    def slice_first_turn(self, turn_format: str) -> str:
        """Drop everything before ``{user}``.

        Raises:
            ValueError: If ``{user}`` is not in ``turn_format``.
        """
        return turn_format[turn_format.index("{user}") :]

    def slice_last_turn(self, turn_format: str) -> str:
        """Drop everything after ``{system}``.

        Raises:
            ValueError: If ``{system}`` is not in ``turn_format``.
        """
        return turn_format[: turn_format.index("{system}") + len("{system}")]

    def slice_last_response(self, turn_format: str) -> str:
        """Drop everything after ``{user}``, omitting the system response.

        Raises:
            ValueError: If ``{user}`` is not in ``turn_format``.
        """
        return turn_format[: turn_format.index("{user}") + len("{user}")]

    def get_turn_format(self, turn_format: str, step: int, length: int) -> str:
        """Return the template for turn number ``step`` of a ``length``-turn dialog.

        The first turn is trimmed to start at ``{user}``; the last turn is
        trimmed to end at ``{system}`` (or at ``{user}`` when
        ``last_response_to_field`` is set). A single-turn dialog gets both.
        """
        if step == 0:
            turn_format = self.slice_first_turn(turn_format)
        if step == length - 1:
            turn_format = self.slice_last_turn(turn_format)
            # Only the final turn's response is withheld; it is stored in
            # last_response_to_field by process_instance_value instead.
            if self.last_response_to_field is not None:
                turn_format = self.slice_last_response(turn_format)
        return turn_format

    def get_general_turn_format(self, instance: Dict[str, Any]) -> str:
        """Return the standardized per-turn template for ``instance``.

        Uses ``self.format`` when set, otherwise the format object found at
        ``instance["recipe_metadata"]["format"]``.
        """
        general_format = (
            instance["recipe_metadata"]["format"]
            if self.format is None
            else self.format
        )
        return self.standardize_format(general_format.demo_format)

    def process_instance_value(
        self, structured_dialog: List[Dict[str, str]], instance: Dict[str, Any]
    ) -> str:
        """Serialize ``structured_dialog`` into a single string.

        Args:
            structured_dialog: Turns, each a dict providing at least the
                ``"user"`` and ``"system"`` keys used by the turn template.
            instance: The full instance; read for the context field and the
                fallback format, and written to when
                ``last_response_to_field`` is set.

        Returns:
            The optional context plus separator, followed by every rendered
            turn.

        Note:
            For an empty dialog, ``last_response_to_field`` is left unset,
            since there is no last response to store.
        """
        parts: List[str] = []
        if self.context_field is not None:
            parts.append(instance[self.context_field] + self.context_separator)

        general_turn_format = self.get_general_turn_format(instance)
        num_turns = len(structured_dialog)
        for i, turn in enumerate(structured_dialog):
            turn_format = self.get_turn_format(general_turn_format, i, num_turns)
            parts.append(turn_format.format(**turn))

        # Index the list rather than reusing the loop variable, which is
        # unbound when the dialog is empty.
        if self.last_response_to_field is not None and structured_dialog:
            instance[self.last_response_to_field] = structured_dialog[-1]["system"]
        return "".join(parts)
|