"""Dialog Serializers.

Dialog serializers are the way to take dialog data and turn it into text
that can be fed to the model.

The format of the dialog is:

    dialog = [
        {"user": "hello", "system": "hi"},
        {"user": "kkk", "system": ""},
        {"user": "kkk", "system": ""},
    ]
"""

from typing import Any, Dict, List, Optional

from .formats import SystemFormat
from .operators import InstanceFieldOperator


class SerializeDialog(InstanceFieldOperator):
    """Serializes dialog data for feeding into a model.

    Each turn of a structured dialog (a list of ``{"user": ..., "system": ...}``
    dicts) is rendered with a turn template derived from a format's
    ``demo_format``, and the rendered turns are concatenated into one string.

    Attributes:
        field (str): The field in the input data that contains the dialog
            (not declared in this class; expected to come from the
            ``InstanceFieldOperator`` base).
        to_field (Optional[str]): The field in the output data where the
            serialized dialog will be stored (likewise expected from the base).
        format (Optional[SystemFormat]): Format whose ``demo_format`` is used as
            the turn template. When ``None``, the format stored in
            ``instance["recipe_metadata"]["format"]`` is used instead.
        last_response_to_field (Optional[str]): When set, the final turn is
            rendered only up to (and including) its ``{user}`` placeholder, so
            the last system response is left out of the serialized text, and
            that response (``turn["system"]`` of the last turn) is stored in
            this field of the instance.
        context_field (Optional[str]): Field that contains additional context
            to be prepended to the dialog.
        context_separator (str): String placed between the context and the
            first turn. Only used when ``context_field`` is set.
    """

    format: Optional[SystemFormat] = None
    last_response_to_field: Optional[str] = None
    context_field: Optional[str] = None
    context_separator: str = " "

    def standardize_format(self, demo_format):
        """Turn a demo-format template into a dialog-turn template.

        ``{source}`` becomes ``{user}``, ``{target}`` becomes ``{system}``,
        and ``{target_prefix}`` is removed.
        """
        turn_format = demo_format.replace("{source}", "{user}")
        turn_format = turn_format.replace("{target}", "{system}")
        return turn_format.replace("{target_prefix}", "")

    def slice_first_turn(self, turn_format):
        """Drop everything that precedes ``{user}`` in the template."""
        return turn_format[turn_format.index("{user}") :]

    def slice_last_turn(self, turn_format):
        """Drop everything that follows ``{system}`` in the template."""
        return turn_format[: turn_format.index("{system}") + len("{system}")]

    def slice_last_response(self, turn_format):
        """Keep the template only up to and including ``{user}``."""
        return turn_format[: turn_format.index("{user}") + len("{user}")]

    def get_turn_format(self, turn_format, step, length):
        """Return the template to use for the turn at index ``step``.

        The first turn is trimmed of any leading text before ``{user}``. The
        last turn is trimmed of any trailing text after ``{system}``, and
        additionally cut right after ``{user}`` when ``last_response_to_field``
        is set (so the final system response is not rendered).
        """
        if step == 0:
            turn_format = self.slice_first_turn(turn_format)
        if step == length - 1:
            turn_format = self.slice_last_turn(turn_format)
            if self.last_response_to_field is not None:
                turn_format = self.slice_last_response(turn_format)
        return turn_format

    def get_general_turn_format(self, instance):
        """Return the standardized turn template for ``instance``.

        Uses ``self.format`` when set, otherwise the format found under
        ``instance["recipe_metadata"]["format"]``.
        """
        general_format = (
            instance["recipe_metadata"]["format"]
            if self.format is None
            else self.format
        )
        return self.standardize_format(general_format.demo_format)

    def process_instance_value(
        self, structured_dialog: List[Dict[str, str]], instance: Dict[str, Any]
    ):
        """Serialize ``structured_dialog`` into a single string.

        Args:
            structured_dialog: Ordered turns, each a dict whose keys fill the
                turn template (at least ``"user"``, and ``"system"``).
            instance: The full instance. It is read for the context field and
                the fallback format, and written to when
                ``last_response_to_field`` is set.

        Returns:
            The optional context plus separator, followed by every rendered turn.

        Raises:
            ValueError: If ``last_response_to_field`` is set but the dialog has
                no turns (there is no last response to extract).
        """
        prefix = (
            ""
            if self.context_field is None
            else instance[self.context_field] + self.context_separator
        )
        general_turn_format = self.get_general_turn_format(instance)
        num_turns = len(structured_dialog)

        # Collect pieces and join once instead of repeated string concatenation.
        rendered_turns = [
            self.get_turn_format(general_turn_format, step, num_turns).format(**turn)
            for step, turn in enumerate(structured_dialog)
        ]

        if self.last_response_to_field is not None:
            # Previously this read the loop variable after the loop, which
            # raised an opaque UnboundLocalError for an empty dialog.
            if num_turns == 0:
                raise ValueError(
                    f"Cannot set '{self.last_response_to_field}' from an empty "
                    "dialog: at least one turn is required."
                )
            instance[self.last_response_to_field] = structured_dialog[-1]["system"]

        return prefix + "".join(rendered_turns)