import gradio as gr
from gradio_client import Client
from transformers import AutoTokenizer
base_model_id = "alpindale/Mistral-7B-v0.2-hf"
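# the tokenizer is only used locally to count prompt tokens; generation happens on the remote endpoint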
eval_tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    add_special_tokens=False,
    trust_remote_code=True,
    use_fast=True,
)
def get_encoded_length(text):
    """Return the number of tokens in `text` according to the base model's tokenizer."""
    encoded = eval_tokenizer.encode(text)
    return len(encoded)
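# remote Gradio app that performs the actual text generation, called via its /predict API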
client = Client("https://capdev.govtext.gov.sg:9092/")
# gradio_port = 8503
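# chat markup expected by the serving endpoint: header tags mark the speaker, <|eot_id|> ends each turn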
prompt_template = "<start_header_id>{role}<end_header_id>{message}<|eot_id|>"
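# budget, in tokens, for the assembled prompt (checked with get_encoded_length)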
input_max_length = 5500
def format_query(history, message):
    """Build the prompt from the new message plus as much recent history as fits the token budget."""
    # always include the new user message
    new_msg_formatted = prompt_template.format(role="user", message=message)
    total_content_length = get_encoded_length(new_msg_formatted)
    formatted_messages_list = [new_msg_formatted]
    # walk the history from newest to oldest, prepending turns while they still fit
    for user_msg, system_msg in reversed(history):
        system_msg_formatted = prompt_template.format(role="system", message=system_msg)
        system_msg_formatted_length = get_encoded_length(system_msg_formatted)
        if total_content_length + system_msg_formatted_length < input_max_length:
            formatted_messages_list.insert(0, system_msg_formatted)
            total_content_length += system_msg_formatted_length
        else:
            break
        user_msg_formatted = prompt_template.format(role="user", message=user_msg)
        user_msg_formatted_length = get_encoded_length(user_msg_formatted)
        if total_content_length + user_msg_formatted_length < input_max_length:
            formatted_messages_list.insert(0, user_msg_formatted)
            total_content_length += user_msg_formatted_length
        else:
            break
    # print(formatted_messages_list)
    # prime the model to answer as "system", matching the split in get_reply_from_chatbot
    return "".join(formatted_messages_list) + "<start_header_id>system<end_header_id>"
def get_reply_from_chatbot(message, history):
    """gr.ChatInterface callback: format the prompt, query the remote endpoint, extract the new reply."""
    query_formatted = format_query(history, message)
    # print(query_formatted)
    result = client.predict(
        eval_prompt=query_formatted,
        temperature=0.7,
        max_new_tokens=100,
        api_name="/predict"
    )
    # find the last generated message and trim it at the end-of-turn token
    response_str = result.split("<start_header_id>system<end_header_id>")[-1]
    response_str = response_str.split("<|eot_id|>")[0].strip()
    return response_str
demo = gr.ChatInterface(
    fn=get_reply_from_chatbot,
    examples=["How is your day?", "What are you doing now?", "Dinner is ready."],
    title="Talk to Watson",
)
# demo.launch(server_name="0.0.0.0", server_port=gradio_port)
# for HF spaces
demo.launch()