ybelkada committed
Commit eb25f8a • 1 Parent(s): fad0ab5

Create app.py

Files changed (1)
  1. app.py +297 -0
app.py ADDED
@@ -0,0 +1,297 @@
+ import os
+
+ import gradio as gr
+ from huggingface_hub import Repository
+ from text_generation import Client
+
+ # from dialogues import DialogueTemplate
+ from share_btn import (community_icon_html, loading_icon_html, share_btn_css,
+                        share_js)
+
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ API_TOKEN = os.environ.get("API_TOKEN", None)
+ # Default to the hosted Guanaco 33B endpoint unless API_URL overrides it
+ API_URL = os.environ.get("API_URL", "https://api-inference.huggingface.co/models/timdettmers/guanaco-33b-merged")
+
+ client = Client(
+     API_URL,
+     headers={"Authorization": f"Bearer {API_TOKEN}"},
+ )
+
+ repo = None
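+
+ # Quick sanity check (hypothetical, not used by the app itself): a one-off,
+ # non-streaming request to the same endpoint would be
+ # print(client.generate("### Human: Hi ### Assistant:", max_new_tokens=20).generated_text)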
+
+
+ def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
+     # Rebuild the full prompt: the preprompt, then every past (user, assistant)
+     # turn with its speaker prefix and separator, then the new input
+     past = []
+     for data in chatbot:
+         user_data, model_data = data
+
+         if not user_data.startswith(user_name):
+             user_data = user_name + user_data
+         if not model_data.startswith(sep + assistant_name):
+             model_data = sep + assistant_name + model_data
+
+         past.append(user_data + model_data.rstrip() + sep)
+
+     if not inputs.startswith(user_name):
+         inputs = user_name + inputs
+
+     total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()
+
+     return total_inputs
+
+
+ def has_no_history(chatbot, history):
+     return not chatbot and not history
+
+
+ header = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
+ prompt_template = "### Human: {query} ### Assistant:{response}"
+
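+ # With these, a single-turn prompt assembled in `generate` below looks like
+ # (illustrative): the header text followed by
+ # "### Human: <your question> ### Assistant:"
+ # Multi-turn prompts insert one filled-in template per past exchange before
+ # the final, response-less template for the new message.
+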
+ def generate(
+     system_message,
+     user_message,
+     chatbot,
+     history,
+     temperature,
+     top_k,
+     top_p,
+     max_new_tokens,
+     repetition_penalty,
+     do_save=True,
+ ):
+     # do_save mirrors the "Store data" checkbox; storage itself is not wired up yet
+     # Don't return a meaningless message when the input is empty
+     if not user_message:
+         print("Empty input")
+
+     history.append(user_message)
+
+     # Flatten the Gradio chatbot pairs into an alternating user/assistant list
+     past_messages = []
+     for data in chatbot:
+         user_data, model_data = data
+
+         past_messages.extend(
+             [{"role": "user", "content": user_data}, {"role": "assistant", "content": model_data.rstrip()}]
+         )
+
+     # Replay every past exchange through the template, then append the new
+     # query with an empty response slot for the model to fill
+     if len(past_messages) < 1:
+         prompt = header + prompt_template.format(query=user_message, response="")
+     else:
+         prompt = header
+         for i in range(0, len(past_messages), 2):
+             intermediate_prompt = prompt_template.format(query=past_messages[i]["content"], response=past_messages[i + 1]["content"])
+             print("intermediate: ", intermediate_prompt)
+             prompt = prompt + intermediate_prompt
+
+         prompt = prompt + prompt_template.format(query=user_message, response="")
+
+     # Clamp temperature away from zero; the endpoint rejects non-positive values
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         top_k=top_k,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         truncate=999,
+         seed=42,
+     )
+
+     stream = client.generate_stream(
+         prompt,
+         **generate_kwargs,
+     )
+
+     # Accumulate streamed tokens into the last history entry and re-yield the
+     # paired chat so Gradio repaints the conversation incrementally
+     output = ""
+     for idx, response in enumerate(stream):
+         print(f'step {idx} - {response.token.text}')
+         if response.token.text == '':
+             break
+
+         if response.token.special:
+             continue
+         output += response.token.text
+         if idx == 0:
+             history.append(" " + output)
+         else:
+             history[-1] = output
+
+         chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)]
+
+         yield chat, history, user_message, ""
+
+     return chat, history, user_message, ""
+
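+ # Minimal sketch of driving `generate` outside Gradio (hypothetical values,
+ # mirroring the UI defaults; the empty lists stand in for chatbot/history state):
+ # for chat, history, last_msg, _ in generate("", "Hi there!", [], [], 0.2, 50, 0.95, 512, 1.2):
+ #     print(chat[-1][1])  # assistant reply so far, updated as tokens stream in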
+
+ examples = [
+     "A Llama entered my garden, what should I do?"
+ ]
+
+
+ def clear_chat():
+     return [], []
+
+
+ def process_example(args):
+     # Drain the generator with the UI's default sampling settings and keep the
+     # final (chat, history) pair; only runs when gr.Examples caches examples
+     for x, y, _, _ in generate("", args, [], [], 0.2, 50, 0.95, 512, 1.2):
+         pass
+     return [x, y]
+
+
+ title = """<h1 align="center">Guanaco Playground 💬</h1>"""
+ custom_css = """
+ #banner-image {
+     display: block;
+     margin-left: auto;
+     margin-right: auto;
+ }
+ #chat-message {
+     font-size: 14px;
+     min-height: 300px;
+ }
+ """
+
+ with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
+     gr.HTML(title)
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown(
+                 """
+                 💻 This demo showcases the Guanaco 33B model, released together with the paper [QLoRA](https://arxiv.org/abs/2305.14314).
+                 """
+             )
+
+     with gr.Row():
+         do_save = gr.Checkbox(
+             value=True,
+             label="Store data",
+             info="You agree to the storage of your prompt and generated text for research and development purposes.",
+         )
+     with gr.Accordion(label="System Prompt", open=False, elem_id="parameters-accordion"):
+         system_message = gr.Textbox(
+             elem_id="system-message",
+             placeholder="Below is a conversation between a human user and a helpful AI coding assistant.",
+             show_label=False,
+         )
+     with gr.Row():
+         with gr.Box():
+             output = gr.Markdown()
+             chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
+             with gr.Row():
+                 send_button = gr.Button("Send", elem_id="send-btn", visible=True)
+
+                 clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)
+
+     with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
+         temperature = gr.Slider(
+             label="Temperature",
+             value=0.2,
+             minimum=0.0,
+             maximum=1.0,
+             step=0.1,
+             interactive=True,
+             info="Higher values produce more diverse outputs",
+         )
+         top_k = gr.Slider(
+             label="Top-k",
+             value=50,
+             minimum=0.0,
+             maximum=100,
+             step=1,
+             interactive=True,
+             info="Sample from a shortlist of top-k tokens",
+         )
+         top_p = gr.Slider(
+             label="Top-p (nucleus sampling)",
+             value=0.95,
+             minimum=0.0,
+             maximum=1,
+             step=0.05,
+             interactive=True,
+             info="Higher values sample more low-probability tokens",
+         )
+         max_new_tokens = gr.Slider(
+             label="Max new tokens",
+             value=512,
+             minimum=0,
+             maximum=1024,
+             step=4,
+             interactive=True,
+             info="The maximum number of new tokens to generate",
+         )
+         repetition_penalty = gr.Slider(
+             label="Repetition Penalty",
+             value=1.2,
+             minimum=0.0,
+             maximum=10,
+             step=0.1,
+             interactive=True,
+             info="The parameter for repetition penalty. 1.0 means no penalty.",
+         )
+     with gr.Row():
+         gr.Examples(
+             examples=examples,
+             inputs=[user_message],
+             cache_examples=False,
+             fn=process_example,
+             outputs=[output],
+         )
+
+     history = gr.State([])
+     # Stores the last prompt so the "message" textbox can be cleared after
+     # submit and the message regenerated later
+     last_user_message = gr.State("")
+
+     user_message.submit(
+         generate,
+         inputs=[
+             system_message,
+             user_message,
+             chatbot,
+             history,
+             temperature,
+             top_k,
+             top_p,
+             max_new_tokens,
+             repetition_penalty,
+             do_save,
+         ],
+         outputs=[chatbot, history, last_user_message, user_message],
+     )
+
+     send_button.click(
+         generate,
+         inputs=[
+             system_message,
+             user_message,
+             chatbot,
+             history,
+             temperature,
+             top_k,
+             top_p,
+             max_new_tokens,
+             repetition_penalty,
+             do_save,
+         ],
+         outputs=[chatbot, history, last_user_message, user_message],
+     )
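+     # Note: the Enter key and the Send button intentionally route through the
+     # same `generate` handler with identical inputs and outputs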
+
+     clear_chat_button.click(clear_chat, outputs=[chatbot, history])
+     # share_button.click(None, [], [], _js=share_js)
+
+ # Queue requests so up to 16 streaming generations can run concurrently
+ demo.queue(concurrency_count=16).launch(debug=True)