import toml
import gradio as gr
from string import Template
from transformers import AutoTokenizer

from vid2persona.gen import tgi_openllm
from vid2persona.gen import local_openllm

tokenizer = None


def _get_system_prompt(
    personality_json_dict: dict,
    prompt_tpl_path: str
) -> str:
    """Assumes a single character is passed."""
    prompt_tpl_path = f"{prompt_tpl_path}/llm.toml"
    system_prompt = Template(toml.load(prompt_tpl_path)['conversation']['system'])
    name = personality_json_dict["name"]
    physical_description = personality_json_dict["physicalDescription"]
    personality_traits = [str(trait) for trait in personality_json_dict["personalityTraits"]]
    likes = [str(like) for like in personality_json_dict["likes"]]
    dislikes = [str(dislike) for dislike in personality_json_dict["dislikes"]]
    background = [str(info) for info in personality_json_dict["background"]]
    goals = [str(goal) for goal in personality_json_dict["goals"]]
    relationships = [str(relationship) for relationship in personality_json_dict["relationships"]]
    # Keyword names must match the $-placeholders in llm.toml; dislikes was
    # extracted above, so it is substituted along with the other fields.
    system_prompt = system_prompt.substitute(
        name=name,
        physical_description=physical_description,
        personality_traits=', '.join(personality_traits),
        likes=', '.join(likes),
        dislikes=', '.join(dislikes),
        background=', '.join(background),
        goals=', '.join(goals),
        relationships=', '.join(relationships)
    )
    return system_prompt
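
# A minimal sketch of the llm.toml layout this function assumes; the actual
# template text ships with the project, so the wording below is hypothetical:
#
#   [conversation]
#   system = """You are $name. Appearance: $physical_description.
#   Traits: $personality_traits. Likes: $likes. Dislikes: $dislikes.
#   Background: $background. Goals: $goals. Relationships: $relationships."""
#
# With that layout, _get_system_prompt(persona_dict, "vid2persona/prompts")
# returns the fully substituted system prompt ("vid2persona/prompts" is a
# hypothetical path to the directory containing llm.toml).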


async def chat(
    message: str,
    chat_history: list[tuple[str, str]],
    personality_json_dict: dict,
    prompt_tpl_path: str,
    model_id: str,
    max_input_token_length: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    hf_token: str,
):
    # Build an OpenAI-style message list: the persona system prompt first,
    # then the prior turns from chat_history, then the new user message.
    messages = []
    system_prompt = _get_system_prompt(personality_json_dict, prompt_tpl_path)
    messages.append({"role": "system", "content": system_prompt})

    for user, assistant in chat_history:
        messages.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    messages.append({"role": "user", "content": message})
    parameters = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty
    }
    try:
        # Try the locally hosted model first.
        for response in local_openllm.send_message(messages, model_id, max_input_token_length, parameters):
            yield response
    except Exception as e:
        # Fall back to the remotely hosted TGI endpoint only when local
        # generation fails; otherwise the local stream is the final answer.
        gr.Warning(f"{e} ➡️ Switching to TGI remotely hosted model")
        async for response in tgi_openllm.send_messages(messages, model_id, hf_token, parameters):
            yield response
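
# A minimal driver sketch for the async generator above (assumes both
# backends yield progressively longer response strings; the persona dict,
# model id, and token below are hypothetical placeholders):
#
#   import asyncio
#
#   async def demo():
#       persona = {
#           "name": "Ada", "physicalDescription": "tall, silver-haired",
#           "personalityTraits": ["curious"], "likes": ["puzzles"],
#           "dislikes": ["small talk"], "background": ["retired engineer"],
#           "goals": ["build a robot"], "relationships": ["lives alone"],
#       }
#       async for partial in chat(
#           "Hello!", [], persona, "vid2persona/prompts",
#           model_id="some-org/some-chat-model", max_input_token_length=4096,
#           max_new_tokens=256, temperature=0.7, top_p=0.9, top_k=50,
#           repetition_penalty=1.2, hf_token="hf_xxx",
#       ):
#           print(partial)
#
#   asyncio.run(demo())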