# chat helper
class ChatState:
    """Keep a turn-based chat history and render it into model prompts.

    Each supported chat template supplies the strings that open and close a
    system, user and model turn. User and model messages are wrapped in those
    tags and appended to ``history``. ``get_full_prompt`` concatenates the
    history, prepends the optional system prompt and opens a new model turn,
    so that the text the model generates next is its reply.

    Args:
        model: Object exposing ``generate(prompt, max_length=..., strip_prompt=...)``
            (presumably a KerasNLP/KerasHub ``CausalLM`` — confirm).
        system: Optional system prompt; left out of the prompt when empty.
        chat_template: Key into ``_CHAT_TEMPLATES``, or ``"auto"`` to use the
            model's class name (e.g. ``"GemmaCausalLM"``) as the key.

    Raises:
        ValueError: If ``chat_template`` does not name a known template.
    """

    # Turn tags for every supported template, keyed by template name.
    # "banner" is the line printed when the template is selected.
    _CHAT_TEMPLATES = {
        "Llama3CausalLM": {
            "start_system": "<|start_header_id|>system<|end_header_id|>\n\n",
            "start_user": "<|start_header_id|>user<|end_header_id|>\n\n",
            "start_model": "<|start_header_id|>assistant<|end_header_id|>\n\n",
            "end_system": "<|eot_id|>",
            "end_user": "<|eot_id|>",
            "end_model": "<|eot_id|>",
            "banner": "Using chat template for: Llama",
        },
        # Gemma's prompt-formatting guide delimits turns with the
        # <start_of_turn>/<end_of_turn> control tokens; without them the model
        # sees no turn boundaries. The system prompt is emitted as plain text
        # ahead of the first turn.
        "GemmaCausalLM": {
            "start_system": "",
            "start_user": "<start_of_turn>user\n",
            "start_model": "<start_of_turn>model\n",
            "end_system": "\n",
            "end_user": "<end_of_turn>\n",
            "end_model": "<end_of_turn>\n",
            "banner": "Using chat template for: Gemma",
        },
        "MistralCausalLM": {
            "start_system": "",
            "start_user": "[INST]",
            "start_model": "",
            "end_system": "",
            "end_user": "[/INST]",
            "end_model": "",
            "banner": "Using chat template for: Mistral",
        },
        "Vicuna": {
            "start_system": "",
            "start_user": "USER: ",
            "start_model": "ASSISTANT: ",
            "end_system": "\n\n",
            "end_user": "\n",
            "end_model": "\n",
            "banner": "Using chat template for : Vicuna",
        },
    }

    def __init__(self, model, system="", chat_template="auto"):
        if chat_template == "auto":
            chat_template = type(model).__name__
        tags = self._CHAT_TEMPLATES.get(chat_template)
        if tags is None:
            # Fail fast here; otherwise the missing tag attributes would only
            # surface later as an AttributeError on the first message.
            raise ValueError(
                "Unknown turn tags for this model class: "
                f"{chat_template!r}"
            )
        # Trailing double underscores mean these names are not mangled, so
        # they stay readable as plain instance attributes.
        self.__START_TURN_SYSTEM__ = tags["start_system"]
        self.__START_TURN_USER__ = tags["start_user"]
        self.__START_TURN_MODEL__ = tags["start_model"]
        self.__END_TURN_SYSTEM__ = tags["end_system"]
        self.__END_TURN_USER__ = tags["end_user"]
        self.__END_TURN_MODEL__ = tags["end_model"]
        print(tags["banner"])
        self.model = model
        self.system = system
        # Already-tagged turns, in conversation order.
        self.history = []

    def add_to_history_as_user(self, message):
        """Append ``message`` to the history, wrapped as a user turn."""
        self.history.append(
            self.__START_TURN_USER__ + message + self.__END_TURN_USER__
        )

    def add_to_history_as_model(self, message):
        """Append ``message`` to the history, wrapped as a model turn."""
        self.history.append(
            self.__START_TURN_MODEL__ + message + self.__END_TURN_MODEL__
        )

    def get_history(self):
        """Return all tagged turns concatenated into a single string."""
        return "".join(self.history)

    def get_full_prompt(self):
        """Return the prompt to send to the model.

        The result is the optional system block, then the history, then an
        opened (unterminated) model turn for the model to complete.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        if self.system:
            prompt = (
                self.__START_TURN_SYSTEM__
                + self.system
                + self.__END_TURN_SYSTEM__
                + prompt
            )
        return prompt

    def send_message(self, message, *, max_length=2048):
        """Record a user message, generate a reply and record it too.

        Args:
            message: The user's message.
            max_length: Passed through unchanged to ``model.generate``.

        Returns:
            A ``(message, response)`` tuple, where ``response`` is the text
            returned by ``model.generate`` (called with ``strip_prompt=True``).
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        response = self.model.generate(
            prompt, max_length=max_length, strip_prompt=True
        )
        self.add_to_history_as_model(response)
        return (message, response)