File size: 2,477 Bytes
647567b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from gradio_client import Client
from transformers import AutoTokenizer


base_model_id = "alpindale/Mistral-7B-v0.2-hf"
# Tokenizer for the base model. In this file it is only used to count prompt
# tokens (see get_encoded_length); generation runs on the remote endpoint.
# NOTE(review): trust_remote_code=True lets the hub repo execute its own code
# at load time; a Mistral tokenizer probably does not need it — consider
# dropping it after confirming the tokenizer still loads.
eval_tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    add_special_tokens=False,
    trust_remote_code=True,
    use_fast=True,
)


def get_encoded_length(text):
    """Return the number of token ids ``eval_tokenizer`` produces for *text*.

    NOTE(review): ``encode`` is called with its default special-token handling,
    so the count presumably includes a BOS token (the tokenizer is loaded with
    ``add_bos_token=True``) — confirm if exact counts matter.
    """
    return len(eval_tokenizer.encode(text))


# Remote Gradio app that hosts the model; queried through its "/predict" API.
client = Client("https://capdev.govtext.gov.sg:9092/")

# Port for a self-hosted launch (unused when running on HF Spaces):
# gradio_port = 8503

# Template for a single chat turn in the model's prompt format.
# NOTE(review): the header markers have no "|" while the end-of-turn marker
# does; presumably this matches how the remote model was trained — confirm
# before changing.
prompt_template = "<start_header_id>{role}<end_header_id>{message}<|eot_id|>"

# Token budget for the formatted prompt, enforced by format_query.
input_max_length = 5_500

def format_query(history, message):
    """Build the model prompt from the chat history plus the new user message.

    The new message is always included, even if it alone exceeds the budget.
    Earlier turns are then added from newest to oldest: each bot reply first,
    then the user message that preceded it. This continues while the running
    token total stays below ``input_max_length``. The first message that would
    reach the limit stops the walk and drops everything older. The oldest
    included turn may therefore be a bot reply without its user message.

    Args:
        history: Earlier turns as ``(user_msg, system_msg)`` pairs, oldest
            first.
        message: The latest user message.

    Returns:
        The prompt in chronological order, ending with an open ``system``
        header for the model to complete.
    """
    new_msg_formatted = prompt_template.format(role="user", message=message)
    total_content_length = get_encoded_length(new_msg_formatted)

    # Collected newest-first (append is O(1), unlike insert(0, ...)) and
    # reversed into chronological order when joining.
    newest_first = [new_msg_formatted]
    for user_msg, system_msg in reversed(history):
        system_msg_formatted = prompt_template.format(role="system", message=system_msg)
        system_msg_formatted_length = get_encoded_length(system_msg_formatted)
        if total_content_length + system_msg_formatted_length >= input_max_length:
            break
        newest_first.append(system_msg_formatted)
        # Every accepted message counts against the budget; without this the
        # limit was never reached.
        total_content_length += system_msg_formatted_length

        user_msg_formatted = prompt_template.format(role="user", message=user_msg)
        user_msg_formatted_length = get_encoded_length(user_msg_formatted)
        if total_content_length + user_msg_formatted_length >= input_max_length:
            break
        newest_first.append(user_msg_formatted)
        total_content_length += user_msg_formatted_length

    return "".join(reversed(newest_first)) + "<start_header_id>system<end_header_id>"

    


def get_reply_from_chatbot(message, history):
    """Send the conversation to the remote model and return its reply text.

    Args:
        message: The latest user message from the chat UI.
        history: Earlier turns as ``(user_msg, system_msg)`` pairs, passed
            through to ``format_query``.

    Returns:
        The reply text, cut at the end-of-turn marker and stripped of
        surrounding whitespace.
    """
    reply_header = "<start_header_id>system<end_header_id>"
    end_of_turn = "<|eot_id|>"

    prompt = format_query(history, message)
    prediction = client.predict(
        eval_prompt=prompt,
        temperature=0.7,
        max_new_tokens=100,
        api_name="/predict",
    )

    # Take the text after the LAST system header (the whole output if there is
    # none), then cut at the first end-of-turn marker.
    # NOTE(review): if the model keeps generating past its own turn and opens
    # another system turn, this returns that later turn rather than the first
    # new reply — confirm this is intended.
    _, _, tail = prediction.rpartition(reply_header)
    reply, _, _ = tail.partition(end_of_turn)
    return reply.strip()

# Chat UI that routes every user message through get_reply_from_chatbot.
demo = gr.ChatInterface(
    fn=get_reply_from_chatbot,
    examples=[
        "How is your day?",
        "What are you doing now?",
        "Dinner is ready.",
    ],
    title="Talk to Watson",
)

# For a self-hosted run, launch with an explicit address instead:
# demo.launch(server_name="0.0.0.0", server_port=gradio_port)

# Default launch, as used on HF Spaces.
demo.launch()