cryscan committed on
Commit a33184e
1 Parent(s): 1834a63

Rework demo UI.

Files changed (1)
  1. app.py +96 -46
app.py CHANGED
@@ -110,22 +110,22 @@ Arrange the given numbers in ascending order.
     ["Simply put, the theory of relativity states that", 150, 1.0, 0.5, 0.2, 0.2],
 ]
 
-infer_interface = gr.Interface(
-    fn=infer,
-    description=f'''{desc} <b>Please try examples first (bottom of page)</b> (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
-    allow_flagging="never",
-    inputs=[
-        gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n"),  # prompt
-        gr.Slider(10, 200, step=10, value=150),  # token_count
-        gr.Slider(0.2, 2.0, step=0.1, value=1.0),  # temperature
-        gr.Slider(0.0, 1.0, step=0.05, value=0.7),  # top_p
-        gr.Slider(0.0, 1.0, step=0.1, value=0.2),  # presencePenalty
-        gr.Slider(0.0, 1.0, step=0.1, value=0.2),  # countPenalty
-    ],
-    outputs=gr.Textbox(label="Generated Output", lines=28),
-    examples=examples,
-    cache_examples=False,
-).queue()
+# infer_interface = gr.Interface(
+#     fn=infer,
+#     description=f'''{desc} <b>Please try examples first (bottom of page)</b> (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
+#     allow_flagging="never",
+#     inputs=[
+#         gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n"),  # prompt
+#         gr.Slider(10, 200, step=10, value=150),  # token_count
+#         gr.Slider(0.2, 2.0, step=0.1, value=1.0),  # temperature
+#         gr.Slider(0.0, 1.0, step=0.05, value=0.7),  # top_p
+#         gr.Slider(0.0, 1.0, step=0.1, value=0.2),  # presencePenalty
+#         gr.Slider(0.0, 1.0, step=0.1, value=0.2),  # countPenalty
+#     ],
+#     outputs=gr.Textbox(label="Generated Output", lines=28),
+#     examples=examples,
+#     cache_examples=False,
+# ).queue()
 
 ########################################################################################################
 
@@ -159,8 +159,12 @@ She also likes to tell {user} a lot about herself and her opinions, and she usua
 
 _, intro_state = model.forward(pipeline.encode(chat_intro), None)
 
+def user(user_message, chatbot):
+    chatbot = chatbot or []
+    return "", chatbot + [[user_message, None]]
+
 def chat(
-    message: str,
+    chatbot,
     history,
     token_count=10,
     temperature=1.0,
@@ -174,6 +178,7 @@ def chat(
         token_ban=[],  # ban the generation of some tokens
         token_stop=[])  # stop generation whenever you see any token here
 
+    message = chatbot[-1][0]
    message = message.strip(' ')
    message = message.replace('\n', '')
    ctx = f"{user}{interface} {message}\n\n{bot}{interface}"
@@ -181,9 +186,9 @@
    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
 
-    history = history or [[], intro_state, []]  # [chat, state, all_tokens]
+    history = history or [intro_state, []]  # [state, all_tokens]
 
-    [chat_log, state, all_tokens] = history
+    [state, all_tokens] = history
    out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
 
    begin = len(all_tokens)
@@ -230,35 +235,80 @@ def chat(
    gc.collect()
    torch.cuda.empty_cache()
 
-    chat_log.append((message, out_str.strip()))
-    history = [chat_log, state, all_tokens]
-    return chat_log, history
-
-chat_interface = gr.Interface(
-    fn=chat,
-    description=f'''You are {user}, bot is {bot}.''',
-    allow_flagging="never",
-    inputs=[
-        gr.Textbox(label="Message"),
-        "state",
-        gr.Slider(10, 1000, step=10, value=250),  # token_count
-        gr.Slider(0.2, 2.0, step=0.1, value=1.0),  # temperature
-        gr.Slider(0.0, 1.0, step=0.05, value=0.8),  # top_p
-        gr.Slider(0.0, 1.0, step=0.1, value=0.2),  # presence_penalty
-        gr.Slider(0.0, 1.0, step=0.1, value=0.2),  # count_penalty
-    ],
-    outputs=[
-        gr.Chatbot(label="Chat Log", color_map=("blue", "pink")),
-        "state"
-    ]
-).queue()
+    chatbot[-1][1] = out_str.strip()
+    history = [state, all_tokens]
+    return chatbot, history
+
+# chat_interface = gr.Interface(
+#     fn=chat,
+#     description=f'''You are {user}, bot is {bot}.''',
+#     allow_flagging="never",
+#     inputs=[
+#         gr.Textbox(label="Message"),
+#         "state",
+#         gr.Slider(10, 1000, step=10, value=250),  # token_count
+#         gr.Slider(0.2, 2.0, step=0.1, value=1.0),  # temperature
+#         gr.Slider(0.0, 1.0, step=0.05, value=0.8),  # top_p
+#         gr.Slider(0.0, 1.0, step=0.1, value=0.2),  # presence_penalty
+#         gr.Slider(0.0, 1.0, step=0.1, value=0.2),  # count_penalty
+#     ],
+#     outputs=[
+#         gr.Chatbot(label="Chat Log", color_map=("blue", "pink")),
+#         "state"
+#     ]
+# ).queue()
 
 ########################################################################################################
 
-demo = gr.TabbedInterface(
-    [infer_interface, chat_interface], ["Generative", "Chat"],
-    title=title,
-)
+# demo = gr.TabbedInterface(
+#     [infer_interface, chat_interface], ["Generative", "Chat"],
+#     title=title,
+# )
+
+# demo.queue(max_size=10)
+# demo.launch(share=True)
+
+with gr.Blocks() as demo:
+    with gr.Tab("Generative"):
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n")
+                token_count = gr.Slider(10, 1000, label="Max Token", step=10, value=250)
+                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
+                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.8)
+                presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.2)
+                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.2)
+            with gr.Column():
+                with gr.Row():
+                    submit = gr.Button("Submit")
+                    clear = gr.Button("Clear")
+                output = gr.Textbox(label="Generated Output", lines=28)
+        data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Prompts", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
+        submit.click(infer, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
+        clear.click(lambda: None, [], [output])
+        data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
+    with gr.Tab("Chat"):
+        with gr.Row():
+            with gr.Column():
+                chatbot = gr.Chatbot()
+                state = gr.State()
+                message = gr.Textbox(label="Message")
+                with gr.Row():
+                    send = gr.Button("Send")
+                    clear = gr.Button("Clear")
+            with gr.Column():
+                token_count = gr.Slider(10, 1000, label="Max Token", step=10, value=250)
+                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
+                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.8)
+                presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.2)
+                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.2)
+        message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(
+            chat, [chatbot, state, token_count, temperature, top_p, presence_penalty, count_penalty], [chatbot, state]
+        )
+        send.click(user, [message, chatbot], [message, chatbot], queue=False).then(
+            chat, [chatbot, state, token_count, temperature, top_p, presence_penalty, count_penalty], [chatbot, state]
+        )
+        clear.click(lambda: ([], None, ""), [], [chatbot, state, message])
 
 demo.queue(max_size=10)
-demo.launch(share=True)
+demo.launch(share=False)
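The reworked Chat tab relies on Blocks event chaining: user() appends the message to the chat log and clears the textbox immediately (queue=False), then .then() runs the slow model step, with the RNN state carried across turns in gr.State. Below is a minimal, runnable sketch of the same pattern; the echo-style bot() is a stand-in assumption for this app's real chat() call and its sampling parameters.

import gradio as gr

def user(user_message, chatbot):
    # Append the user's turn right away; leave the bot's slot empty for now.
    chatbot = chatbot or []
    return "", chatbot + [[user_message, None]]

def bot(chatbot):
    # Fill in the bot's half of the last exchange (dummy echo here, where
    # the real app runs model.forward() and samples tokens).
    chatbot[-1][1] = f"You said: {chatbot[-1][0]}"
    return chatbot

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    message = gr.Textbox(label="Message")
    # queue=False lets the user's turn render instantly; .then() chains the
    # (potentially slow) generation step afterwards.
    message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(
        bot, [chatbot], [chatbot]
    )

demo.queue(max_size=10)
demo.launch()

Splitting the handler this way is what makes the UI feel responsive: the textbox clears and the message appears before generation starts, instead of the whole exchange blocking on the model as in the old gr.Interface version.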