watsonchua committed on
Commit
647567b
1 Parent(s): 4bb2a2d

add app and requirements.txt

Browse files
Files changed (2) hide show
  1. app.py +77 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_client import Client
3
+ from transformers import AutoTokenizer
4
+
5
+
6
# Tokenizer matching the remote model's base; used locally ONLY to count tokens
# when budgeting the prompt (generation itself happens on the remote endpoint).
# NOTE(review): loading this downloads the tokenizer at import time — network side effect.
base_model_id = "alpindale/Mistral-7B-v0.2-hf"
eval_tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, add_special_tokens=False, trust_remote_code=True, use_fast=True)
8
+
9
+
10
def get_encoded_length(text):
    """Return the number of tokens the evaluation tokenizer produces for ``text``."""
    return len(eval_tokenizer.encode(text))
13
+
14
+
15
# Remote Gradio endpoint that hosts the fine-tuned model; created at import time
# (opens a connection — module cannot be imported offline).
client = Client("https://capdev.govtext.gov.sg:9092/")
16
+
17
+
18
# gradio_port = 8503

# Per-message wrapper used to serialize the chat into a single prompt string.
# NOTE(review): headers lack the pipes of Llama-3's `<|start_header_id|>` tokens —
# presumably this is the exact template the remote model was fine-tuned with; confirm.
prompt_template = "<start_header_id>{role}<end_header_id>{message}<|eot_id|>"

# Token budget for the whole formatted prompt sent to the endpoint.
input_max_length = 5500
23
+
24
def format_query(history, message):
    """Serialize chat history plus the new user message into a single prompt.

    Walks ``history`` (list of ``(user_msg, system_msg)`` pairs) from newest to
    oldest, prepending each formatted message while the running token count stays
    under ``input_max_length``, so the most recent turns are kept when truncating.

    Args:
        history: list of (user_msg, system_msg) tuples, oldest first
                 (the shape gradio's ChatInterface passes to its fn).
        message: the new user message to append last.

    Returns:
        The concatenated prompt string, terminated with a system header so the
        remote model continues as the "system" (assistant) role.
    """
    new_msg_formatted = prompt_template.format(role="user", message=message)
    # Running token total of everything accepted into the prompt so far.
    total_content_length = get_encoded_length(new_msg_formatted)

    formatted_messages_list = [new_msg_formatted]
    for user_msg, system_msg in reversed(history):
        system_msg_formatted = prompt_template.format(role="system", message=system_msg)
        system_msg_formatted_length = get_encoded_length(system_msg_formatted)
        if total_content_length + system_msg_formatted_length < input_max_length:
            formatted_messages_list.insert(0, system_msg_formatted)
            # BUG FIX: the budget must accumulate — the original never updated
            # total_content_length, so every message was only checked against
            # its own length and the prompt could exceed input_max_length.
            total_content_length += system_msg_formatted_length
        else:
            break

        user_msg_formatted = prompt_template.format(role="user", message=user_msg)
        user_msg_formatted_length = get_encoded_length(user_msg_formatted)
        if total_content_length + user_msg_formatted_length < input_max_length:
            formatted_messages_list.insert(0, user_msg_formatted)
            total_content_length += user_msg_formatted_length
        else:
            break

    # Trailing system header cues the model to generate the assistant reply.
    return "".join(formatted_messages_list) + "<start_header_id>system<end_header_id>"
51
+
52
+
53
+
54
+
55
def get_reply_from_chatbot(message, history):
    """Send the formatted conversation to the remote endpoint and return its reply.

    Args:
        message: the latest user message.
        history: prior (user, system) message pairs from gradio's ChatInterface.

    Returns:
        The model's reply text, stripped of template markers.
    """
    prompt = format_query(history, message)

    completion = client.predict(
        eval_prompt=prompt,
        temperature=0.7,
        max_new_tokens=100,
        api_name="/predict",
    )

    # The endpoint echoes the whole prompt; keep only the text after the last
    # system header, truncated at the first end-of-turn marker.
    reply = completion.split("<start_header_id>system<end_header_id>")[-1]
    reply = reply.split("<|eot_id|>")[0]
    return reply.strip()
72
+
73
# Chat UI wired to the remote-model reply function, with canned example prompts.
demo = gr.ChatInterface(fn=get_reply_from_chatbot, examples=["How is your day?", "What are you doing now?", "Dinner is ready."], title="Talk to Watson")
# demo.launch(server_name="0.0.0.0", server_port=gradio_port)

# for HF spaces
demo.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ transformers