# talk_to_watson/app.py
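"""Talk to Watson: a Gradio chat UI that forwards the conversation to a remote model endpoint.

The conversation history is packed into a chat-template prompt, trimmed to a token budget
measured with the Mistral-7B-v0.2 tokenizer, and sent to a gradio_client endpoint for generation.
"""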
import gradio as gr
from gradio_client import Client
from transformers import AutoTokenizer

# Tokenizer is used only to count tokens so the prompt stays within the input budget.
base_model_id = "alpindale/Mistral-7B-v0.2-hf"
eval_tokenizer = AutoTokenizer.from_pretrained(
    base_model_id, add_bos_token=True, add_special_tokens=False, trust_remote_code=True, use_fast=True
)


def get_encoded_length(text):
    """Return the token count of `text` under the evaluation tokenizer."""
    encoded = eval_tokenizer.encode(text)
    return len(encoded)


# Remote model endpoint served via gradio_client.
client = Client("https://capdev.govtext.gov.sg:9092/")
# gradio_port = 8503

# Per-turn chat template; the same markers are used later to split the model's reply.
prompt_template = "<start_header_id>{role}<end_header_id>{message}<|eot_id|>"
input_max_length = 5500  # token budget for the full prompt


def format_query(history, message):
    """Build the prompt: the new user message plus as much recent history as fits the token budget."""
    new_msg_formatted = prompt_template.format(role="user", message=message)
    total_content_length = get_encoded_length(new_msg_formatted)
    formatted_messages_list = [new_msg_formatted]

    # Walk the history from most recent to oldest, prepending turns until the budget is exhausted.
    for user_msg, system_msg in reversed(history):
        system_msg_formatted = prompt_template.format(role="system", message=system_msg)
        system_msg_formatted_length = get_encoded_length(system_msg_formatted)
        if total_content_length + system_msg_formatted_length < input_max_length:
            formatted_messages_list.insert(0, system_msg_formatted)
            total_content_length += system_msg_formatted_length
        else:
            break

        user_msg_formatted = prompt_template.format(role="user", message=user_msg)
        user_msg_formatted_length = get_encoded_length(user_msg_formatted)
        if total_content_length + user_msg_formatted_length < input_max_length:
            formatted_messages_list.insert(0, user_msg_formatted)
            total_content_length += user_msg_formatted_length
        else:
            break

    # print(formatted_messages_list)
    return "".join(formatted_messages_list) + "<start_header_id>system<end_header_id>"


def get_reply_from_chatbot(message, history):
    """gr.ChatInterface callback: format the prompt, query the remote model, and extract its reply."""
    query_formatted = format_query(history, message)
    # print(query_formatted)
    result = client.predict(
        eval_prompt=query_formatted,
        temperature=0.7,
        max_new_tokens=100,
        api_name="/predict",
    )
    # Keep only the text generated after the final system header, up to the end-of-turn marker.
    response_str = result.split("<start_header_id>system<end_header_id>")[-1]
    response_str = response_str.split("<|eot_id|>")[0].strip()
    return response_str


demo = gr.ChatInterface(
    fn=get_reply_from_chatbot,
    examples=["How is your day?", "What are you doing now?", "Dinner is ready."],
    title="Talk to Watson",
)
# demo.launch(server_name="0.0.0.0", server_port=gradio_port)
demo.launch()  # launched with defaults for HF Spaces