|
"""Dialog Serializers. |
|
|
|
Dialog serializers are the way to take dialog data and turn it into |
|
text that can be fed to the model. |
|
|
|
The format of the dialog is: |
|
|
|
dialog = [ |
|
{"user": "hello", "system": "hi"}, |
|
{"user": "kkk", "system": ""}, |
|
{"user": "kkk", "system": ""}, |
|
] |
|
""" |
|
from typing import Any, Dict, List, Optional |
|
|
|
from .formats import SystemFormat |
|
from .operators import InstanceFieldOperator |
|
|
|
|
|
class SerializeDialog(InstanceFieldOperator): |
|
"""Serializes dialog data for feeding into a model. |
|
|
|
This class takes structured dialog data and converts it into a text format |
|
according to a specified template. It allows for the inclusion or exclusion |
|
of system responses and can operate on a per-turn basis or aggregate the entire |
|
dialog. |
|
|
|
Attributes: |
|
field (str): The field in the input data that contains the dialog. |
|
to_field (Optional[str]): The field in the output data where the serialized dialog will be stored. |
|
last_user_turn_to_field (Optional[str]): Field to store the last user turn. |
|
last_system_turn_to_field (Optional[str]): Field to store the last system turn. |
|
context_field (Optional[str]): Field that contains additional context to be prepended to the dialog. |
|
""" |
|
|
|
format: Optional[SystemFormat] = None |
|
last_response_to_field: Optional[str] = None |
|
context_field: Optional[str] = None |
|
context_separator: str = " " |
|
|
|
def standardize_format(self, demo_format): |
|
turn_format = demo_format.replace("{source}", "{user}") |
|
turn_format = turn_format.replace("{target}", "{system}") |
|
return turn_format.replace("{target_prefix}", "") |
|
|
|
def slice_first_turn(self, turn_format): |
|
return turn_format[turn_format.index("{user}") :] |
|
|
|
def slice_last_turn(self, turn_format): |
|
return turn_format[: turn_format.index("{system}") + len("{system}")] |
|
|
|
def slice_last_response(self, turn_format): |
|
return turn_format[: turn_format.index("{user}") + len("{user}")] |
|
|
|
def get_turn_format(self, turn_format, step, length): |
|
if step == 0: |
|
turn_format = self.slice_first_turn(turn_format) |
|
if step == length - 1: |
|
turn_format = self.slice_last_turn(turn_format) |
|
if self.last_response_to_field is not None: |
|
turn_format = self.slice_last_response(turn_format) |
|
return turn_format |
|
|
|
def get_general_turn_format(self, instance): |
|
general_format = ( |
|
instance["recipe_metadata"]["format"] |
|
if self.format is None |
|
else self.format |
|
) |
|
return self.standardize_format(general_format.demo_format) |
|
|
|
def process_instance_value( |
|
self, structured_dialog: List[Dict[str, str]], instance: Dict[str, Any] |
|
): |
|
dialog = ( |
|
"" |
|
if self.context_field is None |
|
else instance[self.context_field] + self.context_separator |
|
) |
|
general_turn_format = self.get_general_turn_format(instance) |
|
for i, turn in enumerate(structured_dialog): |
|
turn_format = self.get_turn_format( |
|
general_turn_format, i, len(structured_dialog) |
|
) |
|
dialog += turn_format.format(**turn) |
|
if self.last_response_to_field is not None: |
|
instance[self.last_response_to_field] = turn["system"] |
|
return dialog |
|
|