ubermenchh committed
Commit de5be0f
1 Parent(s): ab4e177

Create app.py

Files changed (1): app.py +239 -0
app.py ADDED
@@ -0,0 +1,239 @@
+ import os
+ from typing import Iterator
+ import gradio as gr
+ from text_generation import Client
+
+ HF_TOKEN = os.environ.get('HF_READ_TOKEN')
+ EOS_STRING = '</s>'
+ EOT_STRING = '<EOT>'
+
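+ # Build a Llama-2-style chat prompt: the system prompt inside <<SYS>> tags,
+ # followed by alternating [INST] user / assistant turns.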
+ def get_prompt(message, chat_history, system_prompt):
+     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+     # The first user turn keeps its leading whitespace; every later turn is stripped.
+     do_strip = False
+     for user_input, response in chat_history:
+         user_input = user_input.strip() if do_strip else user_input
+         do_strip = True
+         texts.append(f"{user_input} [/INST] {response.strip()} </s><s>[INST] ")
+     message = message.strip() if do_strip else message
+     texts.append(f"{message} [/INST]")
+     return ''.join(texts)
+
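+ # Stream generated tokens from the Hugging Face Inference API for the selected model.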
+ def run(model_id, message, chat_history, system_prompt, max_new_tokens=1024, temperature=0.3, top_p=0.9, top_k=50) -> Iterator[str]:
+     API_URL = "https://api-inference.huggingface.co/models/" + model_id
+     client = Client(API_URL, headers={'Authorization': f"Bearer {HF_TOKEN}"})
+     prompt = get_prompt(message, chat_history, system_prompt)
+     generate_kwargs = dict(
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+     )
+     stream = client.generate_stream(prompt, **generate_kwargs)
+     output = ''
+     for response in stream:
+         # Stop as soon as an end-of-sequence token shows up in the stream.
+         if any(end_token in response.token.text for end_token in (EOS_STRING, EOT_STRING)):
+             return output
+         else:
+             output += response.token.text
+             yield output
+     return output
+
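+ # Default system prompt, shown read-only under Advanced Options.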
+ DEFAULT_SYSTEM_PROMPT = """
+ You are Jarvis. You are an AI assistant; you are moderately polite and give only true information.
+ You carefully provide accurate, factual, thoughtful, nuanced answers, and are brilliant at reasoning.
+ If you think there might not be a correct answer, you say so. Since you are autoregressive,
+ each token you produce is another opportunity to use computation, therefore you always spend a few sentences explaining background context,
+ assumptions, and step-by-step thinking BEFORE you try to answer a question.
+ """
+ MAX_MAX_NEW_TOKENS = 4096
+ DEFAULT_MAX_NEW_TOKENS = 256
+ MAX_INPUT_TOKEN_LENGTH = 4000
+
+ DESCRIPTION = "# He's just Jarvis. ;)"
+
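+ # Small callbacks used by the event wiring below.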
+ def clear_and_save_textbox(message):
+     return '', message
+
+ def display_input(message, history=[]):
+     history.append((message, ''))
+     return history
+
+ def delete_prev_fn(history=[]):
+     try:
+         message, _ = history.pop()
+     except IndexError:
+         message = ''
+     return history, message or ''
+
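+ # Main UI callback: validates the settings, then streams the updated chat history.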
+ def generate(model_id, message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
+     if max_new_tokens > MAX_MAX_NEW_TOKENS:
+         raise ValueError(f'max_new_tokens must not exceed {MAX_MAX_NEW_TOKENS}')
+
+     history = history_with_input[:-1]
+     generator = run(model_id, message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
+
+     try:
+         first_response = next(generator)
+         yield history + [(message, first_response)]
+     except StopIteration:
+         yield history + [(message, '')]
+     for response in generator:
+         yield history + [(message, response)]
+
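+ # Drain the generator and return only the final output (non-streaming helper).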
+ def process_example(model_id, message):
+     generator = generate(model_id, message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
+     for x in generator:
+         pass
+     return '', x
+
+ def check_input_token_length(message, chat_history, system_prompt):
+     # Character count is used as a cheap proxy for token length.
+     input_token_length = len(message) + len(chat_history)
+     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+         raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
+
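+ # Assemble the UI: chatbot, model dropdown, sampling controls, and event chains.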
+ with gr.Blocks(theme='Taithrah/Minimal') as demo:
+     gr.Markdown(DESCRIPTION)
+     with gr.Group():
+         chatbot = gr.Chatbot(label='Jarvis')
+         with gr.Row():
+             textbox = gr.Textbox(container=False, show_label=False, placeholder='Hey, Jarvis', scale=7)
+             model_id = gr.Dropdown(label='LLM',
+                                    choices=[
+                                        'mistralai/Mistral-7B-Instruct-v0.1',
+                                        'HuggingFaceH4/zephyr-7b-beta',
+                                        'meta-llama/Llama-2-7b-chat-hf'
+                                    ],
+                                    value='mistralai/Mistral-7B-Instruct-v0.1', scale=3)
+             submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)
+
+     with gr.Row():
+         retry_button = gr.Button('Retry', variant='secondary')
+         undo_button = gr.Button('Undo', variant='secondary')
+         clear_button = gr.Button('Clear', variant='secondary')
+
+     saved_input = gr.State()
+
+     with gr.Accordion(label='Advanced Options', open=False):
+         system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
+         max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+         temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
+         top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+         top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=10)
+
+     textbox.submit(
+         fn=clear_and_save_textbox,
+         inputs=textbox,
+         outputs=[textbox, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=check_input_token_length,
+         inputs=[saved_input, chatbot, system_prompt],
+         api_name=False,
+         queue=False,
+     ).success(
+         fn=generate,
+         inputs=[
+             model_id,
+             saved_input,
+             chatbot,
+             system_prompt,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+         ],
+         outputs=chatbot,
+         api_name=False,
+     )
+
+     button_event_preprocess = submit_button.click(
+         fn=clear_and_save_textbox,
+         inputs=textbox,
+         outputs=[textbox, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=check_input_token_length,
+         inputs=[saved_input, chatbot, system_prompt],
+         api_name=False,
+         queue=False,
+     ).success(
+         fn=generate,
+         inputs=[
+             model_id,
+             saved_input,
+             chatbot,
+             system_prompt,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+         ],
+         outputs=chatbot,
+         api_name=False,
+     )
+
+     retry_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=generate,
+         inputs=[
+             model_id,
+             saved_input,
+             chatbot,
+             system_prompt,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+         ],
+         outputs=chatbot,
+         api_name=False,
+     )
+
+     undo_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=lambda x: x,
+         inputs=[saved_input],
+         outputs=textbox,
+         api_name=False,
+         queue=False,
+     )
+
+     clear_button.click(
+         fn=lambda: ([], ''),
+         outputs=[chatbot, saved_input],
+         queue=False,
+         api_name=False,
+     )
+
+ demo.queue(max_size=32).launch(show_api=False)