"""Gradio chat UI that forwards the conversation to a remote generation endpoint.

The chat history is flattened into a single prompt string, truncated to a
token budget measured with a local tokenizer, and sent to the ``/predict``
API of a remote Gradio app through ``gradio_client``.
"""

from collections import deque
from typing import Iterator, Sequence, Tuple

import gradio as gr
from gradio_client import Client
from transformers import AutoTokenizer

base_model_id = "alpindale/Mistral-7B-v0.2-hf"

# The tokenizer is only used locally to measure prompt length; the text
# generation itself is done by the remote endpoint behind ``client``.
eval_tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    add_special_tokens=False,
    trust_remote_code=True,
    use_fast=True,
)


def get_encoded_length(text: str) -> int:
    """Return the number of token ids ``eval_tokenizer.encode`` yields for *text*."""
    return len(eval_tokenizer.encode(text))


client = Client("https://capdev.govtext.gov.sg:9092/")
# gradio_port = 8503

# Every turn is rendered as "<role><message><|eot_id|>".
prompt_template = "{role}{message}<|eot_id|>"
# Upper bound (exclusive) on the summed per-message token counts of a prompt.
input_max_length = 5500


def _turns_newest_first(
    history: Sequence[Tuple[str, str]],
) -> Iterator[Tuple[str, str]]:
    """Yield ``(role, text)`` for each past message, most recent first."""
    for user_msg, system_msg in reversed(history):
        # Walking backwards in time: the reply comes before the user message
        # that triggered it.
        yield "system", system_msg
        yield "user", user_msg


def format_query(history: Sequence[Tuple[str, str]], message: str) -> str:
    """Build the prompt for the remote model from the chat history.

    The new user ``message`` is always included. Older messages are then
    added from newest to oldest for as long as the running token total stays
    below ``input_max_length``; the first message that does not fit stops the
    walk, so everything older is dropped as well.

    Args:
        history: Past turns as ``(user_message, system_reply)`` pairs, oldest
            first.
        message: The new user message to answer.

    Returns:
        The formatted turns in chronological order, followed by a trailing
        ``"system"`` role tag that cues the model to write the next reply.
    """
    new_msg_formatted = prompt_template.format(role="user", message=message)
    total_content_length = get_encoded_length(new_msg_formatted)
    # deque gives O(1) prepends while we walk the history backwards.
    formatted_messages = deque([new_msg_formatted])

    for role, text in _turns_newest_first(history):
        formatted = prompt_template.format(role=role, message=text)
        formatted_length = get_encoded_length(formatted)
        if total_content_length + formatted_length >= input_max_length:
            break
        formatted_messages.appendleft(formatted)
        # Accumulate the budget; without this the limit would only ever be
        # checked against the newest message and never truncate the history.
        total_content_length += formatted_length

    return "".join(formatted_messages) + "system"


def _extract_reply(generated: str, prompt: str) -> str:
    """Return the first generated message contained in *generated*.

    When the endpoint echoes *prompt* inside its output, the reply is whatever
    follows the prompt. Otherwise fall back to taking the text after the last
    ``"system"`` role tag. In both cases the reply ends at the first
    ``<|eot_id|>`` marker and is stripped of surrounding whitespace.
    """
    _, echoed_prompt, completion = generated.partition(prompt)
    if not echoed_prompt:
        # Prompt not found verbatim in the output, so locate the reply by its
        # role tag instead.
        completion = generated.split("system")[-1]
    return completion.split("<|eot_id|>")[0].strip()


def get_reply_from_chatbot(message: str, history: Sequence[Tuple[str, str]]) -> str:
    """Chat callback: send the conversation to the remote model, return its reply.

    Args:
        message: The new user message.
        history: Past turns as ``(user_message, system_reply)`` pairs.

    Returns:
        The model's reply text with prompt echo and end-of-turn marker removed.
    """
    query_formatted = format_query(history, message)
    result = client.predict(
        eval_prompt=query_formatted,
        temperature=0.7,
        max_new_tokens=100,
        api_name="/predict",
    )
    return _extract_reply(result, query_formatted)


demo = gr.ChatInterface(
    fn=get_reply_from_chatbot,
    examples=["How is your day?", "What are you doing now?", "Dinner is ready."],
    title="Talk to Watson",
)

# demo.launch(server_name="0.0.0.0", server_port=gradio_port)
# for HF spaces
demo.launch()