keras-chatbot-arena / chatstate.py
martin-gorner's picture
initial commit
2ca0c5e
raw
history blame
3.39 kB
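# Chat helper: keeps the conversation history and wraps each turn in the
# model-specific chat-template tags before prompting the model.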
class ChatState:
    """Multi-turn conversation helper for a text-generating model.

    The class stores the running conversation as a list of already-tagged
    turns and builds the full prompt (optional system prompt + history +
    opening tag of the next model turn) that is handed to ``model.generate``.

    Supported ``chat_template`` names are the keys of ``_CHAT_TEMPLATES``:
    ``"Llama3CausalLM"``, ``"GemmaCausalLM"``, ``"MistralCausalLM"`` and
    ``"Vicuna"``.
    """

    # Turn tags and the banner printed at construction, per template name.
    # NOTE(review): the Mistral entry terminates the system prompt with "<s>";
    # confirm this matches the prompt format the Mistral checkpoint expects.
    _CHAT_TEMPLATES = {
        "Llama3CausalLM": {
            "start_system": "<|start_header_id|>system<|end_header_id|>\n\n",
            "start_user": "<|start_header_id|>user<|end_header_id|>\n\n",
            "start_model": "<|start_header_id|>assistant<|end_header_id|>\n\n",
            "end_system": "<|eot_id|>",
            "end_user": "<|eot_id|>",
            "end_model": "<|eot_id|>",
            "banner": "Using chat template for: Llama",
        },
        "GemmaCausalLM": {
            "start_system": "",
            "start_user": "<start_of_turn>user\n",
            "start_model": "<start_of_turn>model\n",
            "end_system": "\n",
            "end_user": "<end_of_turn>\n",
            "end_model": "<end_of_turn>\n",
            "banner": "Using chat template for: Gemma",
        },
        "MistralCausalLM": {
            "start_system": "",
            "start_user": "[INST]",
            "start_model": "",
            "end_system": "<s>",
            "end_user": "[/INST]",
            "end_model": "</s>",
            "banner": "Using chat template for: Mistral",
        },
        "Vicuna": {
            "start_system": "",
            "start_user": "USER: ",
            "start_model": "ASSISTANT: ",
            "end_system": "\n\n",
            "end_user": "\n",
            "end_model": "</s>\n",
            "banner": "Using chat template for : Vicuna",
        },
    }

    def __init__(self, model, system="", chat_template="auto"):
        """Create an empty conversation bound to ``model``.

        Args:
            model: Object exposing ``generate(prompt, max_length=...,
                strip_prompt=...)`` that returns the generated text as a
                string. With ``chat_template="auto"`` its class name selects
                the template.
            system: Optional system prompt; when empty, no system turn is
                added to the prompt.
            chat_template: Template name, or ``"auto"`` to use
                ``type(model).__name__``.

        Raises:
            ValueError: If the (resolved) template name is not supported.
        """
        if chat_template == "auto":
            chat_template = type(model).__name__

        tags = self._CHAT_TEMPLATES.get(chat_template)
        if tags is None:
            # The original `assert (0, "...")` tested a non-empty tuple, which
            # is always true, so unknown templates were silently accepted and
            # only failed later with an AttributeError. Fail loudly here.
            raise ValueError(
                "Unknown turn tags for this model class: "
                f"{chat_template!r} (supported: "
                f"{', '.join(sorted(self._CHAT_TEMPLATES))})"
            )

        # These names end in two underscores, so Python does not name-mangle
        # them; they are kept as-is for any code that reads them directly.
        self.__START_TURN_SYSTEM__ = tags["start_system"]
        self.__START_TURN_USER__ = tags["start_user"]
        self.__START_TURN_MODEL__ = tags["start_model"]
        self.__END_TURN_SYSTEM__ = tags["end_system"]
        self.__END_TURN_USER__ = tags["end_user"]
        self.__END_TURN_MODEL__ = tags["end_model"]
        print(tags["banner"])

        self.model = model
        self.system = system
        self.history = []

    def add_to_history_as_user(self, message):
        """Append ``message`` to the history wrapped in the user turn tags."""
        self.history.append(
            self.__START_TURN_USER__ + message + self.__END_TURN_USER__
        )

    def add_to_history_as_model(self, message):
        """Append ``message`` to the history wrapped in the model turn tags."""
        self.history.append(
            self.__START_TURN_MODEL__ + message + self.__END_TURN_MODEL__
        )

    def get_history(self):
        """Return all recorded turns concatenated into one string."""
        return "".join(self.history)

    def get_full_prompt(self):
        """Return the prompt to feed the model for its next turn.

        The prompt is the history followed by the opening tag of a model
        turn, preceded by the tagged system prompt when one is set.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        if self.system:
            prompt = (
                self.__START_TURN_SYSTEM__
                + self.system
                + self.__END_TURN_SYSTEM__
                + prompt
            )
        return prompt

    def send_message(self, message):
        """Send a user message to the model and record both turns.

        Args:
            message: The user's message.

        Returns:
            A ``(message, response)`` tuple: the user's message and the
            text returned by ``model.generate``.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        response = self.model.generate(
            prompt, max_length=1024, strip_prompt=True
        )
        self.add_to_history_as_model(response)
        return (message, response)