# Captured from a Hugging Face Space whose status page showed: "Runtime error".
import gradio as gr
from gradio_client import Client
from transformers import AutoTokenizer
# Tokenizer loaded locally; in this file it is only used to count prompt
# tokens (via get_encoded_length) when trimming chat history.
# NOTE(review): trust_remote_code=True lets the model repo execute code on
# load and looks unnecessary for a Mistral tokenizer - confirm before keeping.
# NOTE(review): confirm this tokenizer matches the model served by the remote
# endpoint; otherwise the token counts are only approximate.
base_model_id = "alpindale/Mistral-7B-v0.2-hf"
eval_tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    add_special_tokens=False,
    trust_remote_code=True,
    use_fast=True,
)
def get_encoded_length(text):
    """Return the number of token ids ``eval_tokenizer.encode`` yields for ``text``."""
    return len(eval_tokenizer.encode(text))
# Remote Gradio app hosting the model; queried through its "/predict" API.
client = Client("https://capdev.govtext.gov.sg:9092/")

# Port for self-hosted runs only (not used on Hugging Face Spaces): 8503.

# Per-message prompt layout: role header, message text, end-of-turn marker.
# The roles used in this file are "user" and "system" (the bot's replies).
prompt_template = "<start_header_id>{role}<end_header_id>{message}<|eot_id|>"

# Token budget for the prompt; older history is dropped to stay below it.
input_max_length = 5500
def format_query(history, message):
    """Build the model prompt from prior chat turns plus the new user message.

    The new user message is always included. Earlier messages are then added
    from the most recent backwards while the running token total stays below
    ``input_max_length``. The first message that would reach the budget stops
    the walk, so the oldest messages are the ones dropped. The trailing open
    ``system`` header is not counted against the budget.

    Args:
        history: Prior turns as ``(user_msg, system_msg)`` pairs, ordered
            oldest to newest (presumably Gradio ChatInterface's tuple-style
            history - confirm against the installed Gradio version).
        message: The new user message text.

    Returns:
        The concatenated prompt, ending with an open ``system`` header so the
        model continues with the bot's reply.
    """
    # Collected newest-first and reversed once at the end, instead of calling
    # list.insert(0, ...) (O(n) each) for every message.
    newest_first = [prompt_template.format(role="user", message=message)]
    total_content_length = get_encoded_length(newest_first[0])

    # Walk history backwards. Within a turn the reply follows the user message,
    # so going backwards the reply ("system") is considered first. If a reply
    # fits but its user message does not, the reply is kept on its own.
    older_messages = (
        (role, text)
        for user_msg, system_msg in reversed(history)
        for role, text in (("system", system_msg), ("user", user_msg))
    )
    for role, text in older_messages:
        formatted = prompt_template.format(role=role, message=text)
        formatted_length = get_encoded_length(formatted)
        if total_content_length + formatted_length >= input_max_length:
            break
        newest_first.append(formatted)
        # Keep a running total: without it, each message was only compared
        # against the new message's length and the prompt could greatly
        # exceed input_max_length.
        total_content_length += formatted_length

    return "".join(reversed(newest_first)) + "<start_header_id>system<end_header_id>"
def get_reply_from_chatbot(message, history):
    """Return the bot's reply to ``message`` given the prior ``history``.

    The prompt built by ``format_query`` is sent to the remote ``/predict``
    endpoint. From the returned text, the part after the last ``system``
    header and before the next ``<|eot_id|>`` is returned, whitespace-stripped.
    """
    prompt = format_query(history, message)
    generated = client.predict(
        eval_prompt=prompt,
        temperature=0.7,
        max_new_tokens=100,
        api_name="/predict",
    )
    # Presumably the endpoint returns prompt + completion (confirm); keep only
    # the text following the final system header.
    _, _, last_turn = generated.rpartition("<start_header_id>system<end_header_id>")
    reply, _, _ = last_turn.partition("<|eot_id|>")
    return reply.strip()
demo = gr.ChatInterface(
    fn=get_reply_from_chatbot,
    examples=[
        "How is your day?",
        "What are you doing now?",
        "Dinner is ready.",
    ],
    title="Talk to Watson",
)

# On Hugging Face Spaces the default launch settings are used. For a
# self-hosted run, bind explicitly instead, e.g.:
#     demo.launch(server_name="0.0.0.0", server_port=8503)
demo.launch()