ffreemt committed c90a13a (1 parent: 54a54e6)

Update user hot for streaming

Files changed (1): app.py (+58 -131)

app.py CHANGED
@@ -3,6 +3,7 @@
 # ruff: noqa: E501
 import os
 import platform
+import random
 import time
 from dataclasses import asdict, dataclass
 from pathlib import Path
@@ -108,9 +109,9 @@ LLM = None

 try:
     model_loc, file_size = dl_hf_model(url)
-except Exception as exc:
-    logger.erorr(exc)
-    raise SystemExit(1) from exc
+except Exception as exc_:
+    logger.error(exc_)
+    raise SystemExit(1) from exc_

 LLM = AutoModelForCausalLM.from_pretrained(
     model_loc,
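
For context: `dl_hf_model` fetches the GGML weights file from the Hub and `AutoModelForCausalLM` here is the ctransformers loader. A minimal sketch of this loading step, with the URL and `model_type` as hypothetical placeholders rather than values taken from this commit:

```python
# Sketch only: the URL and model_type below are assumptions, not from this commit.
from ctransformers import AutoModelForCausalLM
from dl_hf_model import dl_hf_model

url = "https://huggingface.co/SomeOrg/some-model-GGML/blob/main/some-model.q4_0.bin"  # hypothetical
model_loc, file_size = dl_hf_model(url)  # local path to the downloaded file, plus its size

llm = AutoModelForCausalLM.from_pretrained(
    model_loc,
    model_type="llama",  # assumption; must match the checkpoint family
)
```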
@@ -150,19 +151,16 @@ class GenerationConfig:
 def generate(
     question: str,
     llm=LLM,
-    generation_config: GenerationConfig = GenerationConfig(),
+    config: GenerationConfig = GenerationConfig(),
 ):
     """Run model inference, will return a Generator if streaming is true."""
     # _ = prompt_template.format(question=question)
     # print(_)

-    config = GenerationConfig(reset=True)  # rid of OOM?
-
     prompt = prompt_template.format(question=question)

     return llm(
         prompt,
-        # **asdict(generation_config),
         **asdict(config),
     )

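The `**asdict(config)` pattern above relies on `GenerationConfig` being a dataclass: `asdict` converts it to a plain dict whose keys become keyword arguments of the ctransformers call. A minimal, self-contained sketch of the pattern (the fields below are examples, not the full config defined in app.py):

```python
from dataclasses import asdict, dataclass


@dataclass
class GenerationConfig:
    # example fields only; app.py's real dataclass has more options
    temperature: float = 0.2
    top_k: int = 40
    top_p: float = 0.9
    max_new_tokens: int = 512
    stream: bool = True


def fake_llm(prompt, **kwargs):
    # stand-in for the ctransformers model call; just shows what it receives
    return f"{prompt!r} called with {kwargs}"


print(fake_llm("Hello", **asdict(GenerationConfig(temperature=0.1))))
```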
@@ -170,107 +168,64 @@ def generate(
 logger.debug(f"{asdict(GenerationConfig())=}")


-def predict_str(prompt, bot):  # bot is in fact bot_history
-    # logger.debug(f"{prompt=}, {bot=}, {timeout=}")
-
-    if bot is None:
-        bot = []
-
-    logger.debug(f"{prompt=}, {bot=}")
-
-    try:
-        # user_prompt = prompt
-        generator = generate(
-            prompt,
-        )
-
-        ns.generator = generator  # for .then
-
-    except Exception as exc:
-        logger.error(exc)
-
-    # bot.append([prompt, f"{response} {_}"])
-    # return prompt, bot
-
-    _ = bot + [[prompt, None]]
-    logger.debug(f"{prompt=}, {_=}")
-
-    return prompt, _
-
-
-def bot_str(bot):
-    if bot:
-        bot[-1][1] = ""
-    else:
-        bot = [["Something is wrong", ""]]
-
-    response = ""
-
-    flag = 1
-    then = time.time()
-    for word in ns.generator:
-        # record first response time
-        if flag:
-            logger.debug(f"\t {time.time() - then:.1f}s")
-            flag = 0
-        print(word, end="", flush=True)
-        # print(word, flush=True)  # vertical stream
-        response += word
-        bot[-1][1] = response
-        yield bot
-
-
-def predict(prompt, bot):
-    # logger.debug(f"{prompt=}, {bot=}, {timeout=}")
-    logger.debug(f"{prompt=}, {bot=}")
-
-    ns.response = ""
-    then = time.time()
-    with about_time() as atime:  # type: ignore
-        try:
-            # user_prompt = prompt
-            generator = generate(
-                prompt,
-            )
-
-            ns.generator = generator  # for .then
-
-            print("--", end=" ", flush=True)
-
-            response = ""
-
-            flag = 1
-            for word in generator:
-                # record first response time
-                if flag:
-                    fisrt_arr = f"{time.time() - then:.1f}s"
-                    logger.debug(f"\t 1st arrival: {fisrt_arr}")
-                    flag = 0
-                print(word, end="", flush=True)
-                # print(word, flush=True)  # vertical stream
-                response += word
-                ns.response = f"({fisrt_arr}){response}"
-            print("")
-            logger.debug(f"{response=}")
-        except Exception as exc:
-            logger.error(exc)
-            response = f"{exc=}"
-
-    # bot = {"inputs": [response]}
+def user(user_message, history):
+    # return user_message, history + [[user_message, None]]
+    history.append([user_message, None])
+    return user_message, history
+
+
+def bot_(history):
+    user_message = history[-1][0]
+    resp = random.choice(["How are you?", "I love you", "I'm very hungry"])
+    bot_message = user_message + ": " + resp
+    history[-1][1] = ""
+    for character in bot_message:
+        history[-1][1] += character
+        ns.buff = history[-1][1]
+        time.sleep(0.02)
+        yield history
+
+    history[-1][1] = resp
+    yield history
+
+
+def bot(history):
+    user_message = history[-1][0]
+    response = []
+
+    logger.debug(f"{user_message=}")
+
+    with about_time() as atime:  # type: ignore
+        flag = 1
+        prefix = ""
+        then = time.time()
+
+        logger.debug("about to generate")
+
+        config = GenerationConfig(reset=True)
+        for elm in generate(user_message, config=config):
+            if flag == 1:
+                logger.debug("in the loop")
+                prefix = f"({time.time() - then:.2f}s) "
+                flag = 0
+                print(prefix, end="", flush=True)
+            print(elm, end="", flush=True)
+
+            response.append(elm)
+            history[-1][1] = prefix + "".join(response)
+            yield history
+
     _ = (
         f"(time elapsed: {atime.duration_human}, "  # type: ignore
-        f"{atime.duration/(len(prompt) + len(response)):.1f}s/char)"  # type: ignore
+        f"{atime.duration/len(''.join(response)):.1f}s/char)"  # type: ignore
     )

-    ns.response = ""
-    bot.append([prompt, f"{response} \n{_}"])
-
-    return prompt, bot
+    history[-1][1] = "".join(response)
+    yield history


 def predict_api(prompt):
     logger.debug(f"{prompt=}")
-    ns.response = ""
     try:
         # user_prompt = prompt
         _ = GenerationConfig(
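
The new `bot` consumes `generate(...)` as a token stream and records first-token latency in `prefix`. The same loop, reduced to a console-only sketch (assuming `generate` and `GenerationConfig` as defined above, with `stream=True` so the call yields tokens):

```python
import time


def stream_answer(question: str) -> str:
    # console-only version of the streaming loop in bot()
    then = time.time()
    pieces = []
    for i, token in enumerate(generate(question, config=GenerationConfig(reset=True))):
        if i == 0:
            # first-token latency, printed once as a prefix
            print(f"({time.time() - then:.2f}s) ", end="", flush=True)
        print(token, end="", flush=True)
        pieces.append(token)
    print()
    return "".join(pieces)
```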
@@ -280,8 +235,8 @@ def predict_api(prompt):
             repetition_penalty=1.0,
             max_new_tokens=512,  # adjust as needed
             seed=42,
-            reset=False,  # reset history (cache)
-            stream=True,
+            reset=True,  # reset history (cache)
+            stream=False,
             threads=cpu_count,
             # stop=prompt_prefix[1:2],
         )
@@ -294,7 +249,6 @@ def predict_api(prompt):
         for word in generator:
             print(word, end="", flush=True)
             response += word
-            ns.response = response
         print("")
         logger.debug(f"{response=}")
     except Exception as exc:
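
For the API path, the config now uses `reset=True, stream=False`. With ctransformers, `stream=True` makes the call yield tokens one by one (what the chat UI wants), while `stream=False` should return the finished string; the `for word in generator` loop above still works in that case only because iterating a string yields its characters. A hedged sketch of a plain non-streaming call, assuming the `generate`/`GenerationConfig` helpers above:

```python
def predict_once(prompt: str) -> str:
    # non-streaming variant: expect a complete string back from the model call
    config = GenerationConfig(
        max_new_tokens=512,
        reset=True,
        stream=False,
    )
    return generate(prompt, config=config)
```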
@@ -306,10 +260,6 @@ def predict_api(prompt):
     return response


-def update_buff():
-    return ns.response
-
-
 css = """
 .importantButton {
     background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
@@ -368,6 +318,7 @@ with gr.Blocks(
     theme=gr.themes.Soft(text_size="sm", spacing_size="sm"),
     css=css,
 ) as block:
+    # buff_var = gr.State("")
     with gr.Accordion("🎈 Info", open=False):
         # gr.HTML(
         #     """<center><a href="https://huggingface.co/spaces/mikeee/mpt-30b-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate"></a> and spin a CPU UPGRADE to avoid the queue</center>"""
@@ -382,7 +333,9 @@ with gr.Blocks(

     # chatbot = gr.Chatbot().style(height=700)  # 500
     chatbot = gr.Chatbot(height=500)
-    buff = gr.Textbox(show_label=False, visible=True)
+
+    # buff = gr.Textbox(show_label=False, visible=True)
+
     with gr.Row():
         with gr.Column(scale=5):
             msg = gr.Textbox(
@@ -433,52 +386,25 @@
         "biased, or otherwise offensive outputs.",
         elem_classes=["disclaimer"],
     )
-    # _ = """
-    msg.submit(
-        # fn=conversation.user_turn,
-        fn=predict,
-        inputs=[msg, chatbot],
-        outputs=[msg, chatbot],
-        # queue=True,
-        show_progress="full",
-        api_name="predict",
-    )
-    submit.click(
-        fn=lambda x, y: ("",) + predict(x, y)[1:],  # clear msg
-        inputs=[msg, chatbot],
-        outputs=[msg, chatbot],
-        queue=True,
-        show_progress="full",
-    )
-    # """

-    _ = """
     msg.submit(
         # fn=conversation.user_turn,
-        fn=predict_str,
+        fn=user,
         inputs=[msg, chatbot],
         outputs=[msg, chatbot],
-        queue=True,
+        # queue=True,
         show_progress="full",
-        api_name="predict",
-    ).then(bot_str, chatbot, chatbot)
+    ).then(bot, chatbot, chatbot)
     submit.click(
-        fn=lambda x, y: ("",) + predict_str(x, y)[1:],  # clear msg
+        fn=lambda x, y: ("",) + user(x, y)[1:],  # clear msg
         inputs=[msg, chatbot],
         outputs=[msg, chatbot],
-        queue=True,
+        # queue=True,
         show_progress="full",
-    ).then(bot_str, chatbot, chatbot)
-    # """
+    ).then(bot, chatbot, chatbot)

     clear.click(lambda: None, None, chatbot, queue=False)

-    # update buff Textbox, every: units in seconds)
-    # https://huggingface.co/spaces/julien-c/nvidia-smi/discussions
-    # does not work
-    # AttributeError: 'Blocks' object has no attribute 'run_forever'
-    # block.run_forever(lambda: ns.response, None, [buff], every=1)
-
     with gr.Accordion("For Chat/Translation API", open=False, visible=False):
         input_text = gr.Text()
         api_btn = gr.Button("Go", variant="primary")
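
The rewired events follow Gradio's two-step streaming pattern: `user` stages the turn immediately, and the chained `.then(bot, chatbot, chatbot)` runs the generator that keeps re-rendering the Chatbot. A minimal, self-contained Blocks sketch of the same wiring (layout details omitted; `user` and `bot` as defined in this commit):

```python
import gradio as gr

with gr.Blocks() as block:
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox()
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")

    msg.submit(user, [msg, chatbot], [msg, chatbot]).then(bot, chatbot, chatbot)
    submit.click(
        lambda x, y: ("",) + user(x, y)[1:],  # clear the textbox, then stage the turn
        [msg, chatbot],
        [msg, chatbot],
    ).then(bot, chatbot, chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)

# block.queue().launch()
```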
@@ -491,7 +417,8 @@ with gr.Blocks(
         api_name="api",
     )

-    block.load(update_buff, [], buff, every=1)
+    # block.load(update_buff, [], buff, every=1)
+    # block.load(update_buff, [buff_var], [buff_var, buff], every=1)

 # concurrency_count=5, max_size=20
 # max_size=36, concurrency_count=14
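
The commented-out `block.load(..., every=1)` calls refer to Gradio's polling mechanism, which the earlier code used to refresh a `buff` textbox from a shared namespace once per second; with the generator-based `bot` the Chatbot updates itself, so polling is no longer needed. A rough sketch of what that polling looked like (component names follow the removed code, not this commit):

```python
def update_buff():
    # return whatever the generation loop last wrote, as in the removed update_buff()
    return ns.response

# buff = gr.Textbox(show_label=False)
# block.load(update_buff, inputs=None, outputs=buff, every=1)  # re-run once per second
```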
 