ubermenchh committed on
Commit
b4882b6
1 Parent(s): d77021e

Create app.py

Files changed (1)
  1. app.py +242 -0
app.py ADDED
@@ -0,0 +1,242 @@
+ import os
+ from typing import Iterator
+
+ import gradio as gr
+ from text_generation import Client
+
+ model_id = '01-ai/Yi-6B'
+ API_URL = 'https://api-inference.huggingface.co/models/' + model_id
+ HF_TOKEN = os.environ.get('HF_READ_TOKEN', None)  # read token for the HF Inference API
+
+ client = Client(
+     API_URL,
+     headers={'Authorization': f'Bearer {HF_TOKEN}'}
+ )
+ EOS_STRING = '</s>'
+ EOT_STRING = '<EOT>'
+
+ def get_prompt(message, chat_history, system_prompt):
+     # Render the system prompt and chat history into a Llama-style [INST] template.
+     texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
+
+     do_strip = False
+     for user_input, response in chat_history:
+         user_input = user_input.strip() if do_strip else user_input
+         do_strip = True
+         texts.append(f"{user_input} [/INST] {response.strip()} </s><s>[INST] ")
+     message = message.strip() if do_strip else message
+     texts.append(f"{message} [/INST]")
+     return ''.join(texts)
+
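+ # Illustrative example (not used at runtime): for one prior exchange,
+ #   get_prompt('And 3 + 3?', [('What is 2 + 2?', 'It is 4.')], 'You are Yi.')
+ # returns
+ #   "<s>[INST] <<SYS>>\nYou are Yi.\n<</SYS>>\n\nWhat is 2 + 2? [/INST] It is 4. </s><s>[INST] And 3 + 3? [/INST]"
+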
+ def run(message, chat_history, system_prompt, max_new_tokens=1024, temperature=0.3, top_p=0.9, top_k=50) -> Iterator[str]:
+     prompt = get_prompt(message, chat_history, system_prompt)
+
+     generate_kwargs = dict(
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature
+     )
+     stream = client.generate_stream(prompt, **generate_kwargs)
+     output = ''
+     for response in stream:
+         # Stop as soon as an end-of-sequence or end-of-turn marker shows up.
+         if any(end_token in response.token.text for end_token in (EOS_STRING, EOT_STRING)):
+             return output
+         output += response.token.text
+         yield output
+     return output
+
+ DEFAULT_SYSTEM_PROMPT = """
+ You are Yi. You are an AI assistant; you are moderately polite and give only true information.
+ You carefully provide accurate, factual, thoughtful, nuanced answers, and are brilliant at reasoning.
+ If you think there might not be a correct answer, you say so. Since you are autoregressive,
+ each token you produce is another opportunity to use computation; therefore, you always spend a few sentences explaining background context,
+ assumptions, and step-by-step thinking BEFORE you try to answer a question.
+ """
+ MAX_MAX_NEW_TOKENS = 4096
+ DEFAULT_MAX_NEW_TOKENS = 256
+ MAX_INPUT_TOKEN_LENGTH = 4000
+
+ DESCRIPTION = "# [Yi-6B](https://huggingface.co/01-ai/Yi-6B)"
+
+ def clear_and_save_textbox(message):
+     # Clear the textbox and stash the submitted message in `saved_input`.
+     return '', message
+
+ def display_input(message, history):
+     # Echo the user's message into the chat with an empty placeholder response.
+     history.append((message, ''))
+     return history
+
+ def delete_prev_fn(history):
+     # Remove the most recent exchange; used by the Retry and Undo buttons.
+     try:
+         message, _ = history.pop()
+     except IndexError:
+         message = ''
+     return history, message or ''
+
+ def generate(message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
+     if max_new_tokens > MAX_MAX_NEW_TOKENS:
+         raise ValueError(f'max_new_tokens must not exceed {MAX_MAX_NEW_TOKENS}')
+
+     history = history_with_input[:-1]
+     generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
+     try:
+         first_response = next(generator)
+         yield history + [(message, first_response)]
+     except StopIteration:
+         yield history + [(message, '')]
+     for response in generator:
+         yield history + [(message, response)]
+
+ def process_example(message):
+     generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
+     # Drain the stream; keep only the final chat state.
+     for x in generator:
+         pass
+     return '', x
+
+ def check_input_token_length(message, chat_history, system_prompt):
+     # Cheap proxy for token count: message length in characters plus the number of past turns.
+     input_token_length = len(message) + len(chat_history)
+     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+         raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
+
+ with gr.Blocks(theme='Taithrah/Minimal') as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Group():
+         chatbot = gr.Chatbot(label='Yi-6B')
+         with gr.Row():
+             textbox = gr.Textbox(
+                 container=False,
+                 show_label=False,
+                 placeholder='Hi, Yi',
+                 scale=10
+             )
+             submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)
+
+     with gr.Row():
+         retry_button = gr.Button('Retry', variant='secondary')
+         undo_button = gr.Button('Undo', variant='secondary')
+         clear_button = gr.Button('Clear', variant='secondary')
+
+     saved_input = gr.State()
+
+     with gr.Accordion(label='Advanced options', open=False):
+         system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
+         max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+         temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
+         top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+         top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=10)
+
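+     # Submit flow: stash and clear the textbox, echo the message into the chat,
+     # validate the input length, then stream the model's reply.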
131
+ fn=clear_and_save_textbox,
132
+ inputs=textbox,
133
+ outputs=[textbox, saved_input],
134
+ api_name=False,
135
+ queue=False,
136
+ ).then(
137
+ fn=display_input,
138
+ inputs=[saved_input, chatbot],
139
+ outputs=chatbot,
140
+ api_name=False,
141
+ queue=False,
142
+ ).then(
143
+ fn=check_input_token_length,
144
+ inputs=[saved_input, chatbot, system_prompt],
145
+ api_name=False,
146
+ queue=False,
147
+ ).success(
148
+ fn=generate,
149
+ inputs=[
150
+ saved_input,
151
+ chatbot,
152
+ system_prompt,
153
+ max_new_tokens,
154
+ temperature,
155
+ top_p,
156
+ top_k,
157
+ ],
158
+ outputs=chatbot,
159
+ api_name=False,
160
+ )
161
+
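+     # The Submit button runs the same chain as pressing Enter in the textbox.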
+     button_event_preprocess = submit_button.click(
+         fn=clear_and_save_textbox,
+         inputs=textbox,
+         outputs=[textbox, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=check_input_token_length,
+         inputs=[saved_input, chatbot, system_prompt],
+         api_name=False,
+         queue=False,
+     ).success(
+         fn=generate,
+         inputs=[
+             saved_input,
+             chatbot,
+             system_prompt,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+         ],
+         outputs=chatbot,
+         api_name=False,
+     )
+
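+     # Retry: drop the last exchange, re-display the same message, and regenerate.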
+     retry_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=generate,
+         inputs=[
+             saved_input,
+             chatbot,
+             system_prompt,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+         ],
+         outputs=chatbot,
+         api_name=False,
+     )
+
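+     # Undo: remove the last exchange and put the message back in the textbox.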
+     undo_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=lambda x: x,
+         inputs=[saved_input],
+         outputs=textbox,
+         api_name=False,
+         queue=False,
+     )
+
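+     # Clear: reset both the chat history and the saved input.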
+     clear_button.click(
+         fn=lambda: ([], ''),
+         outputs=[chatbot, saved_input],
+         queue=False,
+         api_name=False,
+     )
+
+ demo.queue(max_size=32).launch(show_api=False)