onuri committed on
Commit
e397ecb
1 Parent(s): a491f2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +313 -25
app.py CHANGED
@@ -1,28 +1,316 @@
 
 
1
  import gradio as gr
 
2
  from text_generation import Client, InferenceAPIClient
3
 
4
- model_name = "OpenAssistant/oasst-sft-1-pythia-12b"
5
- prompt = "Hello, how are you doing today?"
6
-
7
- # instantiate an InferenceAPIClient
8
- client = InferenceAPIClient(model_name)
9
-
10
- def chatbot(input_text):
11
- global prompt
12
-
13
- # concatenate the prompt and the input_text
14
- prompt += input_text
15
-
16
- # generate the response
17
- response = client.generate(prompt)
18
-
19
- # update the prompt with the generated response
20
- # prompt += response["generated_text"]
21
-
22
- # return the response
23
- return response["generated_text"]
24
-
25
- io = gr.Interface(fn=chatbot, inputs=gr.inputs.Textbox(), outputs=gr.outputs.Textbox(label="Chatty"))
26
-
27
- io.launch()
28
- s
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
  import gradio as gr
4
+
5
  from text_generation import Client, InferenceAPIClient
6
 
7
# Hard-coded conversational pre-prompt for the OpenChatKit model
# (togethercomputer/GPT-NeoXT-Chat-Base-20B). It is prepended to the chat
# history in predict() to prime the bot's persona.
openchat_preprompt = (
    "\n<human>: Hi!\n<bot>: My name is Bot, model version is 0.15, part of an open-source kit for "
    "fine-tuning new bots! I was created by Together, LAION, and Ontocord.ai and the open-source "
    "community. I am not human, not evil and not alive, and thus have no thoughts and feelings, "
    "but I am programmed to be helpful, polite, honest, and friendly.\n"
)
13
+
14
+
15
+ def get_client(model: str):
16
+ if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
17
+ return Client(os.getenv("OPENCHAT_API_URL"))
18
+ return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
19
+
20
+
21
def get_usernames(model: str):
    """Return the prompt scaffolding for *model*.

    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        # OpenChatKit uses explicit human/bot tags plus a persona pre-prompt.
        return openchat_preprompt, "<human>: ", "<bot>: ", "\n"
    if model == "OpenAssistant/oasst-sft-1-pythia-12b":
        # OpenAssistant uses its own special tokens and no pre-prompt.
        return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
    # Generic fallback for all other instruction-tuned models.
    return "", "User: ", "Assistant: ", "\n"
31
+
32
+
33
def _strip_suffix(text: str, suffix: str) -> str:
    # Remove *suffix* from the end of *text* if (and only if) it is present.
    # str.rstrip(chars) strips a *character set*, not a suffix, so it cannot
    # be used for this without corrupting the generated text.
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text


def predict(
    model: str,
    inputs: str,
    typical_p: float,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """Stream a chat completion for *inputs* from the selected *model*.

    Rebuilds the full conversational prompt from ``chatbot`` (the widget's
    list of (user, bot) message pairs), the model-specific pre-prompt and the
    role markers from get_usernames(), then streams tokens from the
    text-generation backend.

    Yields:
        (chat, history): ``chat`` is the list of (user, bot) pairs for the
        Chatbot widget; ``history`` is the flat list of raw utterances kept
        in gr.State.
    """
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    history.append(inputs)

    # Re-serialize prior turns, ensuring each utterance carries its role
    # marker exactly once.
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data

        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    partial_words = ""

    if model == "OpenAssistant/oasst-sft-1-pythia-12b":
        # OpenAssistant uses typical-p sampling and its own special tokens,
        # so only a reduced parameter set is forwarded.
        iterator = client.generate_stream(
            total_inputs,
            typical_p=typical_p,
            truncate=1000,
            watermark=watermark,
            max_new_tokens=500,
        )
    else:
        iterator = client.generate_stream(
            total_inputs,
            top_p=top_p if top_p < 1.0 else None,
            top_k=top_k,
            truncate=1000,
            repetition_penalty=repetition_penalty,
            watermark=watermark,
            temperature=temperature,
            max_new_tokens=500,
            stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
        )

    for i, response in enumerate(iterator):
        if response.token.special:
            continue

        partial_words = partial_words + response.token.text
        # BUG FIX: the original used partial_words.rstrip(user_name.rstrip()),
        # but str.rstrip() strips any trailing characters drawn from that set
        # rather than the literal suffix, which could eat real output text.
        partial_words = _strip_suffix(partial_words, user_name.rstrip())
        partial_words = _strip_suffix(partial_words, assistant_name.rstrip())

        if i == 0:
            history.append(" " + partial_words)
        elif response.token.text not in user_name:
            history[-1] = partial_words

        # Pair up the flat history into (user, bot) tuples for the widget.
        chat = [
            (history[j].strip(), history[j + 1].strip())
            for j in range(0, len(history) - 1, 2)
        ]
        yield chat, history
109
+
110
+
111
def reset_textbox():
    """Clear the input textbox after a message has been submitted."""
    return gr.update(value="")
113
+
114
+
115
def radio_on_change(
    value: str,
    disclaimer,
    typical_p,
    top_p,
    top_k,
    temperature,
    repetition_penalty,
    watermark,
):
    """Reconfigure the sampling widgets when the model radio changes.

    Shows/hides each slider and resets its default to per-model values,
    and toggles the OpenChatKit disclaimer.

    Returns:
        tuple: gradio update payloads for (disclaimer, typical_p, top_p,
        top_k, temperature, repetition_penalty, watermark), matching the
        ``outputs`` list of the model.change() binding.
    """
    if value == "OpenAssistant/oasst-sft-1-pythia-12b":
        # OpenAssistant: only typical-p sampling is exposed.
        typical_p = typical_p.update(value=0.2, visible=True)
        top_p = top_p.update(visible=False)
        top_k = top_k.update(visible=False)
        temperature = temperature.update(visible=False)
        disclaimer = disclaimer.update(visible=False)
        repetition_penalty = repetition_penalty.update(visible=False)
        # CONSISTENCY FIX: pass value as a keyword like every other
        # .update() call (the original used a bare positional False).
        watermark = watermark.update(value=False)
    elif value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        typical_p = typical_p.update(visible=False)
        top_p = top_p.update(value=0.25, visible=True)
        top_k = top_k.update(value=50, visible=True)
        temperature = temperature.update(value=0.6, visible=True)
        repetition_penalty = repetition_penalty.update(value=1.01, visible=True)
        watermark = watermark.update(value=False)
        disclaimer = disclaimer.update(visible=True)
    else:
        # Generic Inference-API models: full sampling controls, watermark on.
        typical_p = typical_p.update(visible=False)
        top_p = top_p.update(value=0.95, visible=True)
        top_k = top_k.update(value=4, visible=True)
        temperature = temperature.update(value=0.5, visible=True)
        repetition_penalty = repetition_penalty.update(value=1.03, visible=True)
        watermark = watermark.update(value=True)
        disclaimer = disclaimer.update(visible=False)
    return (
        disclaimer,
        typical_p,
        top_p,
        top_k,
        temperature,
        repetition_penalty,
        watermark,
    )
158
+
159
+
160
# --- Static page copy -------------------------------------------------------
title = """<h1 align="center">Large Language Model Chat API</h1>"""
description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
```
User: <utterance>
Assistant: <utterance>
User: <utterance>
Assistant: <utterance>
...
```
In this app, you can explore the outputs of multiple LLMs when prompted in this way.
"""

text_generation_inference = """
<div align="center">Powered by: <a href=https://github.com/huggingface/text-generation-inference>Text Generation Inference</a></div>
"""

openchat_disclaimer = """
<div align="center">Checkout the official <a href=https://huggingface.co/spaces/togethercomputer/OpenChatKit>OpenChatKit feedback app</a> for the full experience.</div>
"""

# --- UI layout and event wiring ---------------------------------------------
with gr.Blocks(
    css="""#col_container {margin-left: auto; margin-right: auto;}
    #chatbot {height: 520px; overflow: auto;}"""
) as demo:
    gr.HTML(title)
    gr.Markdown(text_generation_inference, visible=True)
    with gr.Column(elem_id="col_container"):
        # Model selector; default is OpenAssistant. The OpenChatKit entry is
        # intentionally commented out of the visible choices.
        model = gr.Radio(
            value="OpenAssistant/oasst-sft-1-pythia-12b",
            choices=[
                "OpenAssistant/oasst-sft-1-pythia-12b",
                # "togethercomputer/GPT-NeoXT-Chat-Base-20B",
                "google/flan-t5-xxl",
                "google/flan-ul2",
                "bigscience/bloom",
                "bigscience/bloomz",
                "EleutherAI/gpt-neox-20b",
            ],
            label="Model",
            interactive=True,
        )

        chatbot = gr.Chatbot(elem_id="chatbot")
        inputs = gr.Textbox(
            placeholder="Hi there!", label="Type an input and press Enter"
        )
        # Only shown for the OpenChatKit model (toggled by radio_on_change).
        disclaimer = gr.Markdown(openchat_disclaimer, visible=False)
        # Flat list of raw utterances; predict() keeps it in step with chatbot.
        state = gr.State([])
        b1 = gr.Button()

        # Sampling controls; visibility/defaults are per-model, managed by
        # radio_on_change(). Initial values match the default model.
        with gr.Accordion("Parameters", open=False):
            typical_p = gr.Slider(
                minimum=-0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                interactive=True,
                label="Typical P mass",
            )
            top_p = gr.Slider(
                minimum=-0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
                visible=False,
            )
            temperature = gr.Slider(
                minimum=-0,
                maximum=5.0,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperature",
                visible=False,
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
                visible=False,
            )
            repetition_penalty = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.03,
                step=0.01,
                interactive=True,
                label="Repetition Penalty",
                visible=False,
            )
            watermark = gr.Checkbox(value=False, label="Text watermarking")

        # Reconfigure the parameter widgets whenever the model changes.
        model.change(
            lambda value: radio_on_change(
                value,
                disclaimer,
                typical_p,
                top_p,
                top_k,
                temperature,
                repetition_penalty,
                watermark,
            ),
            inputs=model,
            outputs=[
                disclaimer,
                typical_p,
                top_p,
                top_k,
                temperature,
                repetition_penalty,
                watermark,
            ],
        )

        # Both Enter-in-textbox and the button trigger the same streaming
        # predict() call with identical inputs/outputs.
        inputs.submit(
            predict,
            [
                model,
                inputs,
                typical_p,
                top_p,
                temperature,
                top_k,
                repetition_penalty,
                watermark,
                chatbot,
                state,
            ],
            [chatbot, state],
        )
        b1.click(
            predict,
            [
                model,
                inputs,
                typical_p,
                top_p,
                temperature,
                top_k,
                repetition_penalty,
                watermark,
                chatbot,
                state,
            ],
            [chatbot, state],
        )
        # Clear the textbox after either submit path fires.
        b1.click(reset_textbox, [], [inputs])
        inputs.submit(reset_textbox, [], [inputs])

    gr.Markdown(description)
    # Queueing is required for streaming (generator) handlers.
    demo.queue(concurrency_count=16).launch(debug=True)