vpcom commited on
Commit
8c6dcfa
1 Parent(s): 8b23b39

fix: wrong place; use the stream_fn instead

Browse files
Files changed (1) hide show
  1. app.py +36 -3
app.py CHANGED
@@ -507,21 +507,54 @@ class ChatInterface(gr.ChatInterface):
507
  clear_btn = clear_btn,
508
  autofocus = autofocus,
509
  )
510
-
511
- async def _submit_fn(
512
  self,
513
  message: str,
514
  history_with_input: list[list[str | None]],
515
  request: Request,
516
  *args,
517
- ) -> tuple[list[list[str | None]], list[list[str | None]]]:
518
  history = history_with_input[:-1]
 
519
  print(f'Message is {message}')
520
  if len(message)==0:
521
  message = random.choice(["ا","ب","پ","ت","ث","ج","چ","ح","خ","ل","م","ن","و",
522
  "د","ذ","ر","ز","ژ","س","ش","ص","ض","ط","ظ","ع","غ",
523
  "ف","ق","ه","ی",
524
  ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  inputs, _, _ = special_args(
526
  self.fn, inputs=[message, history, *args], request=request
527
  )
 
507
  clear_btn = clear_btn,
508
  autofocus = autofocus,
509
  )
510
+
511
+ async def _stream_fn(
512
  self,
513
  message: str,
514
  history_with_input: list[list[str | None]],
515
  request: Request,
516
  *args,
517
+ ) -> AsyncGenerator:
518
  history = history_with_input[:-1]
519
+
520
  print(f'Message is {message}')
521
  if len(message)==0:
522
  message = random.choice(["ا","ب","پ","ت","ث","ج","چ","ح","خ","ل","م","ن","و",
523
  "د","ذ","ر","ز","ژ","س","ش","ص","ض","ط","ظ","ع","غ",
524
  "ف","ق","ه","ی",
525
  ])
526
+
527
+ inputs, _, _ = special_args(
528
+ self.fn, inputs=[message, history, *args], request=request
529
+ )
530
+
531
+ if self.is_async:
532
+ generator = self.fn(*inputs)
533
+ else:
534
+ generator = await anyio.to_thread.run_sync(
535
+ self.fn, *inputs, limiter=self.limiter
536
+ )
537
+ generator = SyncToAsyncIterator(generator, self.limiter)
538
+ try:
539
+ first_response = await async_iteration(generator)
540
+ update = history + [[message, first_response]]
541
+ yield update, update
542
+ except StopIteration:
543
+ update = history + [[message, None]]
544
+ yield update, update
545
+ async for response in generator:
546
+ update = history + [[message, response]]
547
+ yield update, update
548
+
549
+ async def _submit_fn(
550
+ self,
551
+ message: str,
552
+ history_with_input: list[list[str | None]],
553
+ request: Request,
554
+ *args,
555
+ ) -> tuple[list[list[str | None]], list[list[str | None]]]:
556
+ history = history_with_input[:-1]
557
+
558
  inputs, _, _ = special_args(
559
  self.fn, inputs=[message, history, *args], request=request
560
  )