layerdiffusion committed on
Commit
a1c7119
1 Parent(s): a9cb4cf
Files changed (2) hide show
  1. app.py +16 -1
  2. chat_interface.py +16 -10
app.py CHANGED
@@ -28,6 +28,7 @@ from diffusers.models.attention_processor import AttnProcessor2_0
28
  from transformers import CLIPTextModel, CLIPTokenizer
29
  from lib_omost.pipeline import StableDiffusionXLOmostPipeline
30
  from chat_interface import ChatInterface
 
31
 
32
  import lib_omost.canvas as omost_canvas
33
 
@@ -130,9 +131,23 @@ def chat_fn(message: str, history: list, seed:int, temperature: float, top_p: fl
130
 
131
  streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  generate_kwargs = dict(
134
  input_ids=input_ids,
135
  streamer=streamer,
 
136
  max_new_tokens=max_new_tokens,
137
  do_sample=True,
138
  temperature=temperature,
@@ -148,7 +163,7 @@ def chat_fn(message: str, history: list, seed:int, temperature: float, top_p: fl
148
  for text in streamer:
149
  outputs.append(text)
150
  # print(outputs)
151
- yield "".join(outputs)
152
 
153
  print(f'Chat end at {time.time() - time_stamp:.2f} seconds:', message)
154
  return
 
28
  from transformers import CLIPTextModel, CLIPTokenizer
29
  from lib_omost.pipeline import StableDiffusionXLOmostPipeline
30
  from chat_interface import ChatInterface
31
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
32
 
33
  import lib_omost.canvas as omost_canvas
34
 
 
131
 
132
  streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
133
 
134
+ def interactive_stopping_criteria(input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
135
+ if getattr(streamer, 'user_interrupted', False):
136
+ print('User stopped generation')
137
+ return True
138
+ else:
139
+ return False
140
+
141
+ stopping_criteria = StoppingCriteriaList([interactive_stopping_criteria])
142
+
143
+ def interrupter():
144
+ streamer.user_interrupted = True
145
+ return
146
+
147
  generate_kwargs = dict(
148
  input_ids=input_ids,
149
  streamer=streamer,
150
+ stopping_criteria=stopping_criteria,
151
  max_new_tokens=max_new_tokens,
152
  do_sample=True,
153
  temperature=temperature,
 
163
  for text in streamer:
164
  outputs.append(text)
165
  # print(outputs)
166
+ yield "".join(outputs), interrupter
167
 
168
  print(f'Chat end at {time.time() - time_stamp:.2f} seconds:', message)
169
  return
chat_interface.py CHANGED
@@ -20,7 +20,7 @@ from gradio.components import (
20
  State,
21
  Textbox,
22
  get_component_instance,
23
- Dataset
24
  )
25
  from gradio.events import Dependency, on
26
  from gradio.helpers import special_args
@@ -103,6 +103,8 @@ class ChatInterface(Blocks):
103
  self.pre_fn = pre_fn
104
  self.pre_fn_kwargs = pre_fn_kwargs
105
 
 
 
106
  self.multimodal = multimodal
107
  self.concurrency_limit = concurrency_limit
108
  self.fn = fn
@@ -287,7 +289,7 @@ class ChatInterface(Blocks):
287
  .then(
288
  submit_fn,
289
  [self.saved_input, self.chatbot_state] + self.additional_inputs,
290
- [self.chatbot, self.chatbot_state],
291
  show_api=False,
292
  concurrency_limit=cast(
293
  Union[int, Literal["default"], None], self.concurrency_limit
@@ -395,6 +397,11 @@ class ChatInterface(Blocks):
395
  def _setup_stop_events(
396
  self, event_triggers: list[Callable], event_to_cancel: Dependency
397
  ) -> None:
 
 
 
 
 
398
  if self.stop_btn and self.is_generator:
399
  if self.submit_btn:
400
  for event_trigger in event_triggers:
@@ -434,9 +441,8 @@ class ChatInterface(Blocks):
434
  queue=False,
435
  )
436
  self.stop_btn.click(
437
- None,
438
- None,
439
- None,
440
  cancels=event_to_cancel,
441
  show_api=False,
442
  )
@@ -545,7 +551,7 @@ class ChatInterface(Blocks):
545
  )
546
  generator = SyncToAsyncIterator(generator, self.limiter)
547
  try:
548
- first_response = await async_iteration(generator)
549
  if self.multimodal and isinstance(message, dict):
550
  for x in message["files"]:
551
  history.append([(x,), None])
@@ -553,21 +559,21 @@ class ChatInterface(Blocks):
553
  yield update, update
554
  else:
555
  update = history + [[message, first_response]]
556
- yield update, update
557
  except StopIteration:
558
  if self.multimodal and isinstance(message, dict):
559
  self._append_multimodal_history(message, None, history)
560
  yield history, history
561
  else:
562
  update = history + [[message, None]]
563
- yield update, update
564
- async for response in generator:
565
  if self.multimodal and isinstance(message, dict):
566
  update = history + [[message["text"], response]]
567
  yield update, update
568
  else:
569
  update = history + [[message, response]]
570
- yield update, update
571
 
572
  async def _api_submit_fn(
573
  self, message: str, history: list[list[str | None]], request: Request, *args
 
20
  State,
21
  Textbox,
22
  get_component_instance,
23
+ Dataset,
24
  )
25
  from gradio.events import Dependency, on
26
  from gradio.helpers import special_args
 
103
  self.pre_fn = pre_fn
104
  self.pre_fn_kwargs = pre_fn_kwargs
105
 
106
+ self.interrupter = State(None)
107
+
108
  self.multimodal = multimodal
109
  self.concurrency_limit = concurrency_limit
110
  self.fn = fn
 
289
  .then(
290
  submit_fn,
291
  [self.saved_input, self.chatbot_state] + self.additional_inputs,
292
+ [self.chatbot, self.chatbot_state, self.interrupter],
293
  show_api=False,
294
  concurrency_limit=cast(
295
  Union[int, Literal["default"], None], self.concurrency_limit
 
397
  def _setup_stop_events(
398
  self, event_triggers: list[Callable], event_to_cancel: Dependency
399
  ) -> None:
400
+ def perform_interrupt(ipc):
401
+ if ipc is not None:
402
+ ipc()
403
+ return
404
+
405
  if self.stop_btn and self.is_generator:
406
  if self.submit_btn:
407
  for event_trigger in event_triggers:
 
441
  queue=False,
442
  )
443
  self.stop_btn.click(
444
+ fn=perform_interrupt,
445
+ inputs=[self.interrupter],
 
446
  cancels=event_to_cancel,
447
  show_api=False,
448
  )
 
551
  )
552
  generator = SyncToAsyncIterator(generator, self.limiter)
553
  try:
554
+ first_response, first_interrupter = await async_iteration(generator)
555
  if self.multimodal and isinstance(message, dict):
556
  for x in message["files"]:
557
  history.append([(x,), None])
 
559
  yield update, update
560
  else:
561
  update = history + [[message, first_response]]
562
+ yield update, update, first_interrupter
563
  except StopIteration:
564
  if self.multimodal and isinstance(message, dict):
565
  self._append_multimodal_history(message, None, history)
566
  yield history, history
567
  else:
568
  update = history + [[message, None]]
569
+ yield update, update, first_interrupter
570
+ async for response, interrupter in generator:
571
  if self.multimodal and isinstance(message, dict):
572
  update = history + [[message["text"], response]]
573
  yield update, update
574
  else:
575
  update = history + [[message, response]]
576
+ yield update, update, interrupter
577
 
578
  async def _api_submit_fn(
579
  self, message: str, history: list[list[str | None]], request: Request, *args