dusieq committed on
Commit
02c4ff6
1 Parent(s): 5526141

Update app.py

Files changed (1): app.py (+215, -1)
app.py CHANGED
@@ -1,3 +1,217 @@
+ import os
+ import time
+
  import gradio as gr
-
- gr.Interface.load("models/uni-tianyan/Uni-TianYan").launch()
+ from mcli import predict
+
+
+ # fail fast if the deployment URL or the MosaicML API key is missing
+ URL = os.environ.get("URL")
+ if URL is None:
+     raise ValueError("URL environment variable must be set")
+ if os.environ.get("MOSAICML_API_KEY") is None:
+     raise ValueError("MOSAICML_API_KEY environment variable must be set")
+
+
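+ # Chat bundles the ChatML-style prompt templates with the Gradio callbacks
+ # that drive the conversation loop.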
+ class Chat:
+     default_system_prompt = "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
+     system_format = "<|im_start|>system\n{}<|im_end|>\n"
+
+     def __init__(self, system: str = None, user: str = None, assistant: str = None) -> None:
+         if system is not None:
+             self.set_system_prompt(system)
+         else:
+             self.reset_system_prompt()
+         self.user = user if user else "<|im_start|>user\n{}<|im_end|>\n"
+         self.assistant = assistant if assistant else "<|im_start|>assistant\n{}<|im_end|>\n"
+         self.response_prefix = self.assistant.split("{}")[0]
+
+     def set_system_prompt(self, system_prompt):
+         # the active system prompt lives in the UI textbox, not on the object;
+         # returning it here updates that textbox via the click handler
+         return system_prompt
+
+     def reset_system_prompt(self):
+         return self.set_system_prompt(self.default_system_prompt)
+
+     def history_as_formatted_str(self, system, history) -> str:
+         system = self.system_format.format(system)
+         text = system + "".join(
+             [
+                 "\n".join(
+                     [
+                         self.user.format(item[0]),
+                         self.assistant.format(item[1]),
+                     ]
+                 )
+                 for item in history[:-1]
+             ]
+         )
+         text += self.user.format(history[-1][0])
+         text += self.response_prefix
+         # stopgap solution for too-long sequences
+         if len(text) > 4500:
+             # delete turns from the middle, between <|im_start|> and <|im_end|>;
+             # the offset 139 is roughly the length of the formatted default
+             # system prompt, so the search skips past the system section
+             start = text.find("<|im_start|>", 139)
+             end = text.find("<|im_end|>", 139)
+             while end < len(text) and len(text) > 4500:
+                 end = text.find("<|im_end|>", end + 1)
+                 text = text[:start] + text[end + 1 :]
+             if len(text) > 4500:
+                 # the nice way didn't work, so just truncate from the beginning
+                 text = text[-4500:]
+
+         return text
+
+     def clear_history(self, history):
+         return []
+
+     def turn(self, system: str, user_input: str, history):
+         # convenience wrapper (unused by the UI): one full user + assistant turn
+         _, history = self.user_turn(user_input, history)
+         return self.bot_turn(system, history)
+
+     def user_turn(self, user_input: str, history):
+         history.append([user_input, ""])
+         return user_input, history
+
+     def bot_turn(self, system, history):
+         conversation = self.history_as_formatted_str(system, history)
+         assistant_response = call_inf_server(conversation)
+         history[-1][-1] = assistant_response
+         print(system)
+         print(history)
+         return "", history
+
+
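+ # Send the formatted conversation to the inference endpoint; on failure,
+ # wait a second and retry once before giving up.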
+ def call_inf_server(prompt):
+     try:
+         response = predict(
+             URL,
+             {"inputs": [prompt], "temperature": 0.2, "top_p": 0.9, "output_len": 512},
+             timeout=70,
+         )
+         response = response["outputs"][0]
+         # the response echoes the prompt without the special tokens,
+         # so strip those tokens before slicing the echo off the front
+         spl_tokens = ["<|im_start|>", "<|im_end|>"]
+         clean_prompt = prompt.replace(spl_tokens[0], "").replace(spl_tokens[1], "")
+         return response[len(clean_prompt) :]  # remove the prompt
+     except Exception as e:
+         # assume it is our error: just wait and try one more time
+         print(e)
+         time.sleep(1)
+         response = predict(
+             URL,
+             {"inputs": [prompt], "temperature": 0.2, "top_p": 0.9, "output_len": 512},
+             timeout=70,
+         )
+         response = response["outputs"][0]
+         clean_prompt = prompt.replace("<|im_start|>", "").replace("<|im_end|>", "")
+         return response[len(clean_prompt) :]  # remove the prompt
+
+
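+ # Build the Gradio UI: header, chat window, message box, control buttons,
+ # and an advanced-options accordion holding the system prompt.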
+ with gr.Blocks(
+     theme=gr.themes.Soft(),
+     css=".disclaimer {font-variant-caps: all-small-caps;}",
+ ) as demo:
+     gr.Markdown(
+         """<h1><center>MosaicML MPT-30B-Chat</center></h1>
+         This is a demo of [MPT-30B-Chat](https://huggingface.co/mosaicml/mpt-30b-chat). It is based on [MPT-30B](https://huggingface.co/mosaicml/mpt-30b), fine-tuned on approximately 300,000 turns of high-quality conversations, and is powered by [MosaicML Inference](https://www.mosaicml.com/inference).
+         If you're interested in [training](https://www.mosaicml.com/training) and [deploying](https://www.mosaicml.com/inference) your own MPT or LLMs, [sign up](https://forms.mosaicml.com/demo?utm_source=huggingface&utm_medium=referral&utm_campaign=mpt-30b) for the MosaicML platform.
+         """
+     )
+     conversation = Chat()
+     chatbot = gr.Chatbot().style(height=500)
+     with gr.Row():
+         with gr.Column():
+             msg = gr.Textbox(
+                 label="Chat Message Box",
+                 placeholder="Chat Message Box",
+                 show_label=False,
+             ).style(container=False)
+         with gr.Column():
+             with gr.Row():
+                 submit = gr.Button("Submit")
+                 stop = gr.Button("Stop")
+                 clear = gr.Button("Clear")
+     with gr.Row():
+         with gr.Accordion("Advanced Options:", open=False):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     system = gr.Textbox(
+                         label="System Prompt",
+                         value=Chat.default_system_prompt,
+                         show_label=False,
+                     ).style(container=False)
+                 with gr.Column():
+                     with gr.Row():
+                         change = gr.Button("Change System Prompt")
+                         reset = gr.Button("Reset System Prompt")
+     with gr.Row():
+         gr.Markdown(
+             "Disclaimer: MPT-30B can produce factually incorrect output, and should not be relied on to produce "
+             "factually accurate information. MPT-30B was trained on various public datasets; while great efforts "
+             "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
+             "biased, or otherwise offensive outputs.",
+             elem_classes=["disclaimer"],
+         )
+     with gr.Row():
+         gr.Markdown(
+             "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
+             elem_classes=["disclaimer"],
+         )
+
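+     # Wire up events: each submission first appends the user message to the
+     # history (user_turn, unqueued), then calls the model (bot_turn, queued
+     # so the Stop button can cancel it).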
+     submit_event = msg.submit(
+         fn=conversation.user_turn,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).then(
+         fn=conversation.bot_turn,
+         inputs=[system, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+     )
+     submit_click_event = submit.click(
+         fn=conversation.user_turn,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).then(
+         fn=conversation.bot_turn,
+         inputs=[system, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+     )
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[submit_event, submit_click_event],
+         queue=False,
+     )
+     clear.click(lambda: None, None, chatbot, queue=False).then(
+         fn=conversation.clear_history,
+         inputs=[chatbot],
+         outputs=[chatbot],
+         queue=False,
+     )
+     change.click(
+         fn=conversation.set_system_prompt,
+         inputs=[system],
+         outputs=[system],
+         queue=False,
+     )
+     reset.click(
+         fn=conversation.reset_system_prompt,
+         inputs=[],
+         outputs=[system],
+         queue=False,
+     )
+
+
+ demo.queue(max_size=36, concurrency_count=14).launch(debug=True)