import toml
import gradio as gr
from string import Template
from transformers import AutoTokenizer

from vid2persona.gen import tgi_openllm
from vid2persona.gen import local_openllm

tokenizer = None


def _get_system_prompt(
    personality_json_dict: dict,
    prompt_tpl_path: str
) -> str:
"""Assumes a single character is passed."""
prompt_tpl_path = f"{prompt_tpl_path}/llm.toml"
system_prompt = Template(toml.load(prompt_tpl_path)['conversation']['system'])
name = personality_json_dict["name"]
physcial_description = personality_json_dict["physicalDescription"]
personality_traits = [str(trait) for trait in personality_json_dict["personalityTraits"]]
likes = [str(like) for like in personality_json_dict["likes"]]
dislikes = [str(dislike) for dislike in personality_json_dict["dislikes"]]
background = [str(info) for info in personality_json_dict["background"]]
goals = [str(goal) for goal in personality_json_dict["goals"]]
relationships = [str(relationship) for relationship in personality_json_dict["relationships"]]
    # Keyword names below must match the placeholder names defined in the llm.toml template.
    system_prompt = system_prompt.substitute(
        name=name,
        physcial_description=physcial_description,
        personality_traits=', '.join(personality_traits),
        likes=', '.join(likes),
        dislikes=', '.join(dislikes),
        background=', '.join(background),
        goals=', '.join(goals),
        relationships=', '.join(relationships)
    )

    return system_prompt
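
# Illustrative sketch of the llm.toml shape this helper expects; the placeholder
# names are inferred from the substitute() call above, but the actual template
# text lives in <prompt_tpl_path>/llm.toml and may differ:
#
#   [conversation]
#   system = """You are $name. Appearance: $physcial_description.
#   Traits: $personality_traits. Likes: $likes. Dislikes: $dislikes.
#   Background: $background. Goals: $goals. Relationships: $relationships."""
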
async def chat(
    message: str,
    chat_history: list[tuple[str, str]],
    personality_json_dict: dict,
    prompt_tpl_path: str,
    model_id: str,
    max_input_token_length: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    hf_token: str,
):
    # Build the conversation in chat-completions message format: system prompt,
    # then alternating user/assistant turns from history, then the new user message.
    messages = []
    system_prompt = _get_system_prompt(personality_json_dict, prompt_tpl_path)
    messages.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        messages.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    messages.append({"role": "user", "content": message})

    parameters = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty
    }
    # Prefer the locally hosted model; if it fails, fall back to the remote TGI endpoint.
    try:
        for response in local_openllm.send_message(messages, model_id, max_input_token_length, parameters):
            yield response
    except Exception as e:
        gr.Warning(f"{e} ➡️ Switching to TGI remotely hosted model")
        async for response in tgi_openllm.send_messages(messages, model_id, hf_token, parameters):
            yield response
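
# --- Illustrative usage sketch (not part of the original module) ---
# `chat` is an async generator, so it can be consumed with `async for`.
# Every value below (personality fields, template path, model id, token,
# sampling parameters) is a placeholder assumption for demonstration only.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        personality = {
            "name": "Ava",
            "physicalDescription": "tall, with short silver hair",
            "personalityTraits": ["curious", "dry-humored"],
            "likes": ["rainy days"],
            "dislikes": ["small talk"],
            "background": ["former cargo pilot"],
            "goals": ["see the aurora"],
            "relationships": ["close to her younger sister"],
        }
        async for partial_response in chat(
            message="Hi, who are you?",
            chat_history=[],
            personality_json_dict=personality,
            prompt_tpl_path="prompts",  # directory assumed to contain llm.toml
            model_id="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
            max_input_token_length=4096,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.2,
            hf_token="hf_xxx",  # placeholder token
        ):
            print(partial_response)

    asyncio.run(_demo())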