pseudotensor committed on
Commit
6dd6b04
1 Parent(s): 196f3c7

Update with h2oGPT hash da43063f5ead136baee5bd29201f79db6e26d2a2

Files changed (3)
  1. generate.py +120 -55
  2. gradio_runner.py +106 -83
  3. utils.py +34 -0
generate.py CHANGED
@@ -1,14 +1,15 @@
  import functools
+ import queue
  import sys
  import os
+ import time
  import traceback
  import typing
- from threading import Thread
  from datetime import datetime
  import filelock
  import psutil

- from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
+ from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread

  SEED = 1236
  set_seed(SEED)
@@ -107,7 +108,7 @@ def main(
  admin_pass = os.getenv("ADMIN_PASS")
  # will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
  # but becomes unrecoverable sometimes if raise, so just be silent for now
- raise_generate_gpu_exceptions = not is_public
+ raise_generate_gpu_exceptions = True

  # allow set token directly
  use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
@@ -223,9 +224,10 @@ def main(
  eval_filename = os.path.join(scoring_path, eval_filename)

  # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
- context_class = NullContext() if n_gpus > 1 or n_gpus == 0 else torch.device("cuda")
+ device = 'cpu' if n_gpus == 0 else 'cuda'
+ context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device

- with context_class:
+ with context_class(device):
  # ensure was set right above before examples generated
  assert not stream_output, "stream_output=True does not make sense with example loop"
  import time
@@ -240,7 +242,8 @@ def main(
  fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir, is_low_mem=is_low_mem,
  raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
  chat_context=chat_context,
- concurrency_count=concurrency_count)
+ concurrency_count=concurrency_count,
+ lora_weights=lora_weights)
  else:
  assert eval_sharegpt_prompts_only > 0

@@ -288,7 +291,7 @@ def main(
  truncation=True,
  max_length=cutoff_len)
  try:
- score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
+ score = torch.sigmoid(smodel(**inputs).logits[0].float()).cpu().detach().numpy()[0]
  except torch.cuda.OutOfMemoryError as e:
  print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
  traceback.print_exc()
@@ -655,6 +658,7 @@ def evaluate(
  is_low_mem=None,
  raise_generate_gpu_exceptions=None,
  chat_context=None,
+ lora_weights=None,
  ):
  # ensure passed these
  assert concurrency_count is not None
@@ -829,55 +833,115 @@ def evaluate(
  )

  with torch.no_grad():
- # protection for gradio not keeping track of closed users,
- # else hit bitsandbytes lack of thread safety:
- # https://github.com/h2oai/h2ogpt/issues/104
- # but only makes sense if concurrency_count == 1
- context_class = NullContext #if concurrency_count > 1 else filelock.FileLock
- print('Pre-Generate: %s' % str(datetime.now()), flush=True)
- decoded_output = None
- with context_class("generate.lock"):
- print('Generate: %s' % str(datetime.now()), flush=True)
- # decoded tokenized prompt can deviate from prompt due to special characters
- inputs_decoded = decoder(input_ids[0])
- inputs_decoded_raw = decoder_raw(input_ids[0])
- if inputs_decoded == prompt:
- # normal
- pass
- elif inputs_decoded.lstrip() == prompt.lstrip():
- # sometimes extra space in front, make prompt same for prompt removal
- prompt = inputs_decoded
- elif inputs_decoded_raw == prompt:
- # some models specify special tokens that are part of normal prompt, so can't skip them
- inputs_decoded_raw = inputs_decoded
- decoder = decoder_raw
- else:
- print("WARNING: Special characters in prompt", flush=True)
- if stream_output:
- skip_prompt = False
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
- gen_kwargs.update(dict(streamer=streamer))
- target_func = generate_with_exceptions
- target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
- raise_generate_gpu_exceptions, **gen_kwargs)
- thread = Thread(target=target)
- thread.start()
- outputs = ""
- for new_text in streamer:
- outputs += new_text
+ context_class_cast = NullContext if device == 'cpu' or lora_weights else torch.autocast
+ with context_class_cast(device):
+ # protection for gradio not keeping track of closed users,
+ # else hit bitsandbytes lack of thread safety:
+ # https://github.com/h2oai/h2ogpt/issues/104
+ # but only makes sense if concurrency_count == 1
+ context_class = NullContext #if concurrency_count > 1 else filelock.FileLock
+ print('Pre-Generate: %s' % str(datetime.now()), flush=True)
+ decoded_output = None
+ with context_class("generate.lock"):
+ print('Generate: %s' % str(datetime.now()), flush=True)
+ # decoded tokenized prompt can deviate from prompt due to special characters
+ inputs_decoded = decoder(input_ids[0])
+ inputs_decoded_raw = decoder_raw(input_ids[0])
+ if inputs_decoded == prompt:
+ # normal
+ pass
+ elif inputs_decoded.lstrip() == prompt.lstrip():
+ # sometimes extra space in front, make prompt same for prompt removal
+ prompt = inputs_decoded
+ elif inputs_decoded_raw == prompt:
+ # some models specify special tokens that are part of normal prompt, so can't skip them
+ inputs_decoded_raw = inputs_decoded
+ decoder = decoder_raw
+ else:
+ print("WARNING: Special characters in prompt", flush=True)
+ if stream_output:
+ skip_prompt = False
+ streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False)
+ gen_kwargs.update(dict(streamer=streamer))
+ target_func = generate_with_exceptions
+ target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
+ raise_generate_gpu_exceptions, **gen_kwargs)
+ bucket = queue.Queue()
+ thread = EThread(target=target, kwargs=dict(streamer=streamer), bucket=bucket)
+ thread.start()
+ outputs = ""
+ try:
+ for new_text in streamer:
+ if bucket.qsize() > 0 or thread.exc:
+ thread.join()
+ outputs += new_text
+ yield prompter.get_response(outputs, prompt=inputs_decoded,
+ sanitize_bot_response=sanitize_bot_response)
+ except BaseException:
+ # if any exception, raise that exception if was from thread, first
+ if thread.exc:
+ raise thread.exc
+ raise
+ finally:
+ # in case no exception and didn't join with thread yet, then join
+ if not thread.exc:
+ thread.join()
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
+ if thread.exc:
+ raise thread.exc
+ decoded_output = outputs
+ else:
+ outputs = model.generate(**gen_kwargs)
+ outputs = [decoder(s) for s in outputs.sequences]
  yield prompter.get_response(outputs, prompt=inputs_decoded,
  sanitize_bot_response=sanitize_bot_response)
- decoded_output = outputs
- else:
- outputs = model.generate(**gen_kwargs)
- outputs = [decoder(s) for s in outputs.sequences]
- yield prompter.get_response(outputs, prompt=inputs_decoded,
- sanitize_bot_response=sanitize_bot_response)
- if outputs and len(outputs) >= 1:
- decoded_output = prompt + outputs[0]
- if save_dir and decoded_output:
- save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
- print('Post-Generate: %s decoded_output: %s' % (str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
+ if outputs and len(outputs) >= 1:
+ decoded_output = prompt + outputs[0]
+ if save_dir and decoded_output:
+ save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
+ print('Post-Generate: %s decoded_output: %s' % (str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
+
+
+ class H2OTextIteratorStreamer(TextIteratorStreamer):
+ """
+ normally, timeout required for now to handle exceptions, else get()
+ but with H2O version of TextIteratorStreamer, loop over block to handle
+ """
+ def __init__(self, tokenizer, skip_prompt: bool = False, timeout: typing.Optional[float] = None,
+ block=True, **decode_kwargs):
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+ self.text_queue = queue.Queue()
+ self.stop_signal = None
+ self.do_stop = False
+ self.timeout = timeout
+ self.block = block
+
+ def on_finalized_text(self, text: str, stream_end: bool = False):
+ """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
+ self.text_queue.put(text, timeout=self.timeout)
+ if stream_end:
+ self.text_queue.put(self.stop_signal, timeout=self.timeout)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ while True:
+ try:
+ value = self.stop_signal  # value looks unused in pycharm, not true
+ if self.do_stop:
+ print("hit stop", flush=True)
+ # could raise or break, maybe best to raise and make parent see if any exception in thread
+ raise StopIteration()
+ #break
+ value = self.text_queue.get(block=self.block, timeout=self.timeout)
+ break
+ except queue.Empty:
+ time.sleep(0.01)
+ if value == self.stop_signal:
+ raise StopIteration()
+ else:
+ return value


  def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
@@ -908,7 +972,8 @@ def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_ex
  return
  else:
  clear_torch_cache()
- raise
+ if raise_generate_gpu_exceptions:
+ raise


  def get_generate_params(model_lower, chat,
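
Note on the streaming change above: generation now runs on a background EThread while the handler consumes tokens from the streamer, and the bucket queue plus thread.exc let the consumer surface exceptions raised inside the thread. Below is a minimal, self-contained sketch of that producer/consumer pattern; CapturingThread, stream_with_exceptions and produce are illustrative stand-ins (not code from the commit) for EThread, the streaming loop in evaluate(), and model.generate respectively.

    import queue
    import threading
    import time


    class CapturingThread(threading.Thread):
        """Stand-in for EThread: remember any exception raised by the target."""
        def __init__(self, target, bucket):
            super().__init__(target=target)
            self.bucket = bucket
            self.exc = None

        def run(self):
            try:
                super().run()
            except BaseException as e:
                self.exc = e
                self.bucket.put(e)


    def stream_with_exceptions(produce_chunks):
        """Yield accumulated text from a producer thread; re-raise its errors in the caller."""
        text_queue = queue.Queue()
        bucket = queue.Queue()

        def worker():
            for chunk in produce_chunks():
                text_queue.put(chunk)
            text_queue.put(None)  # stop signal, like H2OTextIteratorStreamer's stop_signal

        thread = CapturingThread(target=worker, bucket=bucket)
        thread.start()
        outputs = ""
        try:
            while True:
                if bucket.qsize() > 0 or thread.exc:
                    thread.join()
                    raise thread.exc
                try:
                    chunk = text_queue.get(block=False)
                except queue.Empty:
                    time.sleep(0.01)  # non-blocking poll, like block=False in the streamer
                    continue
                if chunk is None:
                    break
                outputs += chunk
                yield outputs
        finally:
            thread.join()
            if thread.exc:
                raise thread.exc


    # usage: partial outputs stream out until the producer finishes or raises
    def produce():
        for tok in ["Hello", ", ", "world"]:
            yield tok


    for text in stream_with_exceptions(produce):
        print(text)

The design point is the same as in the diff: the consumer never blocks forever on the queue, so a failure in the producer thread can be detected and re-raised promptly instead of hanging the request.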
gradio_runner.py CHANGED
@@ -1,3 +1,4 @@
+ import copy
  import functools
  import inspect
  import os
@@ -246,7 +247,11 @@ def go_gradio(**kwargs):
  value=kwargs['top_k'], label="Top k",
  info='Num. tokens to sample from'
  )
- max_beams = 8 if not is_low_mem else 1
+ # FIXME: https://github.com/h2oai/h2ogpt/issues/106
+ if os.getenv('TESTINGFAIL'):
+ max_beams = 8 if not (is_low_mem or is_public) else 1
+ else:
+ max_beams = 1
  num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
  value=min(max_beams, kwargs['num_beams']), label="Beams",
  info="Number of searches for optimal overall probability. "
@@ -262,7 +267,9 @@ def go_gradio(**kwargs):
  )
  early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
  value=kwargs['early_stopping'])
- max_max_time = 60 * 5 if not is_low_mem else 60
+ max_max_time = 60 * 5 if not is_public else 60 * 2
+ if is_hf:
+ max_max_time = min(max_max_time, 60 * 1)
  max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
  value=min(max_max_time, kwargs['max_time']), label="Max. time",
  info="Max. time to search optimal output.")
@@ -309,9 +316,10 @@ def go_gradio(**kwargs):
  model_gpu = gr.Dropdown(n_gpus_list,
  label="GPU ID 2 [-1 = all GPUs, if Choose is enabled]",
  value=kwargs['gpu_id'])
- model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
+ model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'],
+ interactive=False)
  lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
- visible=kwargs['show_lora'])
+ visible=kwargs['show_lora'], interactive=False)
  with gr.Row():
  with gr.Column(scale=50):
  new_model = gr.Textbox(label="New Model HF name/path")
@@ -354,15 +362,15 @@ def go_gradio(**kwargs):
  with gr.Column():
  with gr.Row():
  system_btn = gr.Button(value='Get System Info')
- system_text = gr.Textbox(label='System Info')
+ system_text = gr.Textbox(label='System Info', interactive=False)

  with gr.Row():
  zip_btn = gr.Button("Zip")
- zip_text = gr.Textbox(label="Zip file name")
+ zip_text = gr.Textbox(label="Zip file name", interactive=False)
  file_output = gr.File()
  with gr.Row():
  s3up_btn = gr.Button("S3UP")
- s3up_text = gr.Textbox(label='S3UP result')
+ s3up_text = gr.Textbox(label='S3UP result', interactive=False)

  # Get flagged data
  zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
@@ -395,12 +403,15 @@ def go_gradio(**kwargs):
  dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
  size="sm",
  )
+ # FIXME: Could add exceptions for non-chat but still streaming
+ exception_text = gr.Textbox(value="", visible=kwargs['chat'], label='Chat Exceptions', interactive=False)
  dark_mode_btn.click(
  None,
  None,
  None,
  _js=get_dark_js(),
  api_name="dark" if allow_api else None,
+ queue=False,
  )

  # Control chat and non-chat blocks, which can be independently used by chat checkbox swap
@@ -415,7 +426,8 @@ def go_gradio(**kwargs):

  chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox" if allow_api else None) \
  .then(col_chat_fun, chat, col_chat) \
- .then(context_fun, chat, context)
+ .then(context_fun, chat, context) \
+ .then(col_chat_fun, chat, exception_text)

  # examples after submit or any other buttons for chat or no chat
  if kwargs['examples'] is not None and kwargs['show_examples']:
@@ -514,6 +526,10 @@ def go_gradio(**kwargs):
  if sanitize_user_prompt:
  from better_profanity import profanity
  user_message1 = profanity.censor(user_message1)
+ if user_message1 in ['']:
+ # e.g. when user just hits enter in textbox,
+ # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual
+ user_message1 = '\n'

  history = args_list[-1]
  if undo and history:
@@ -541,15 +557,17 @@ def go_gradio(**kwargs):
  :param retry:
  :return:
  """
- args_list = list(args).copy()
+ args_list = copy.deepcopy(list(args))
  history = args_list[-1] # model_state is -2
  if retry and history:
  history.pop()
  if not history:
  print("No history", flush=True)
+ history = [['', None]]
+ yield history, ''
  return
  # ensure output will be unique to models
- history = history.copy()
+ history = copy.deepcopy(history)
  instruction1 = history[-1][0]
  context1 = ''
  if kwargs['chat_history'] > 0:
@@ -571,6 +589,8 @@ def go_gradio(**kwargs):
  args_list[2] = context1[-kwargs['chat_history']:]
  model_state1 = args_list[-2]
  if model_state1[0] is None or model_state1[0] == no_model_str:
+ history = [['', None]]
+ yield history, ''
  return
  args_list = args_list[:-2]
  fun1 = partial(evaluate,
@@ -580,19 +600,25 @@ def go_gradio(**kwargs):
  for output in fun1(*tuple(args_list)):
  bot_message = output
  history[-1][1] = bot_message
- yield history
+ yield history, ''
  except StopIteration:
- yield history
+ yield history, ''
  except RuntimeError as e:
  if "generator raised StopIteration" in str(e):
  # assume last entry was bad, undo
  history.pop()
- yield history
- raise
+ yield history, ''
+ else:
+ if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None:
+ history[-1][1] = ''
+ yield history, str(e)
+ raise
  except Exception as e:
  # put error into user input
- history[-1][0] = "Exception: %s" % str(e)
- yield history
+ ex = "Exception: %s" % str(e)
+ if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None:
+ history[-1][1] = ''
+ yield history, ex
  raise
  return

@@ -603,11 +629,11 @@ def go_gradio(**kwargs):
  )
  bot_args = dict(fn=bot,
  inputs=inputs_list + [model_state] + [text_output],
- outputs=text_output,
+ outputs=[text_output, exception_text],
  )
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
  inputs=inputs_list + [model_state] + [text_output],
- outputs=text_output,
+ outputs=[text_output, exception_text],
  )
  undo_user_args = dict(fn=functools.partial(user, undo=True),
  inputs=inputs_list + [text_output],
@@ -621,11 +647,11 @@ def go_gradio(**kwargs):
  )
  bot_args2 = dict(fn=bot,
  inputs=inputs_list + [model_state2] + [text_output2],
- outputs=text_output2,
+ outputs=[text_output2, exception_text],
  )
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
  inputs=inputs_list + [model_state2] + [text_output2],
- outputs=text_output2,
+ outputs=[text_output2, exception_text],
  )
  undo_user_args2 = dict(fn=functools.partial(user, undo=True),
  inputs=inputs_list + [text_output2],
@@ -636,67 +662,61 @@ def go_gradio(**kwargs):
  return gr.Textbox.update(value='')

  if kwargs['auto_score']:
- # in case 2nd model, consume instruction first, so can clear quickly
- # bot doesn't consume instruction itself, just history from user, so why works
- submit_event = instruction.submit(**user_args, queue=queue,
- api_name='instruction' if allow_api else None) \
- .then(**user_args2, api_name='instruction2' if allow_api else None) \
- .then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput) \
- .then(**bot_args, api_name='instruction_bot' if allow_api else None, queue=queue) \
- .then(**score_args, api_name='instruction_bot_score' if allow_api else None, queue=queue) \
- .then(**bot_args2, api_name='instruction_bot2' if allow_api else None, queue=queue) \
- .then(**score_args2, api_name='instruction_bot_score2' if allow_api else None, queue=queue) \
- .then(clear_torch_cache)
- submit_event2 = submit.click(**user_args, api_name='submit' if allow_api else None) \
- .then(**user_args2, api_name='submit2' if allow_api else None) \
- .then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput) \
- .then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) \
- .then(**score_args, api_name='submit_bot_score' if allow_api else None, queue=queue) \
- .then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) \
- .then(**score_args2, api_name='submit_bot_score2' if allow_api else None, queue=queue) \
- .then(clear_torch_cache)
- submit_event3 = retry.click(**user_args, api_name='retry' if allow_api else None) \
- .then(**user_args2, api_name='retry2' if allow_api else None) \
- .then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput) \
- .then(**retry_bot_args, api_name='retry_bot' if allow_api else None, queue=queue) \
- .then(**score_args, api_name='retry_bot_score' if allow_api else None, queue=queue) \
- .then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, queue=queue) \
- .then(**score_args2, api_name='retry_bot_score2' if allow_api else None, queue=queue) \
- .then(clear_torch_cache)
- submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
- .then(**undo_user_args2, api_name='undo2' if allow_api else None) \
- .then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput) \
- .then(**score_args, api_name='undo_score' if allow_api else None) \
- .then(**score_args2, api_name='undo_score2' if allow_api else None)
+ score_args_submit = score_args
+ score_args2_submit = score_args2
  else:
- submit_event = instruction.submit(**user_args,
- api_name='instruction' if allow_api else None) \
- .then(**user_args2, api_name='instruction2' if allow_api else None) \
- .then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput) \
- .then(**bot_args, api_name='instruction_bot' if allow_api else None, queue=queue) \
- .then(**bot_args2, api_name='instruction_bot2' if allow_api else None, queue=queue) \
- .then(clear_torch_cache)
- submit_event2 = submit.click(**user_args, api_name='submit' if allow_api else None) \
- .then(**user_args2, api_name='submit2' if allow_api else None) \
- .then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput) \
- .then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) \
- .then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) \
- .then(clear_torch_cache)
- submit_event3 = retry.click(**user_args, api_name='retry' if allow_api else None) \
- .then(**user_args2, api_name='retry2' if allow_api else None) \
- .then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput) \
- .then(**retry_bot_args, api_name='retry_bot' if allow_api else None, queue=queue) \
- .then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, queue=queue) \
- .then(clear_torch_cache)
- submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
- .then(**undo_user_args2, api_name='undo2' if allow_api else None)
+ score_args_submit = dict(fn=lambda: None, inputs=None, outputs=None)
+ score_args2_submit = dict(fn=lambda: None, inputs=None, outputs=None)
+
+ # in case 2nd model, consume instruction first, so can clear quickly
+ # bot doesn't consume instruction itself, just history from user, so why works
+ submit_event1a = instruction.submit(**user_args, queue=queue,
+ api_name='instruction' if allow_api else None)
+ submit_event1b = submit_event1a.then(**user_args2, api_name='instruction2' if allow_api else None)
+ submit_event1c = submit_event1b.then(clear_instruct, None, instruction) \
+ .then(clear_instruct, None, iinput)
+ submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None,
+ queue=queue)
+ submit_event1e = submit_event1d.then(**score_args_submit, api_name='instruction_bot_score' if allow_api else None,
+ queue=queue)
+ submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None,
+ queue=queue)
+ submit_event1g = submit_event1f.then(**score_args2_submit,
+ api_name='instruction_bot_score2' if allow_api else None, queue=queue)
+ submit_event1h = submit_event1g.then(clear_torch_cache)
+
+ submit_event2a = submit.click(**user_args, api_name='submit' if allow_api else None)
+ submit_event2b = submit_event2a.then(**user_args2, api_name='submit2' if allow_api else None)
+ submit_event2c = submit_event2b.then(clear_instruct, None, instruction) \
+ .then(clear_instruct, None, iinput)
+ submit_event2d = submit_event2c.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue)
+ submit_event2e = submit_event2d.then(**score_args_submit, api_name='submit_bot_score' if allow_api else None,
+ queue=queue)
+ submit_event2f = submit_event2e.then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue)
+ submit_event2g = submit_event2f.then(**score_args2_submit, api_name='submit_bot_score2' if allow_api else None,
+ queue=queue)
+ submit_event2h = submit_event2g.then(clear_torch_cache)
+
+ submit_event3a = retry.click(**user_args, api_name='retry' if allow_api else None)
+ submit_event3b = submit_event3a.then(**user_args2, api_name='retry2' if allow_api else None)
+ submit_event3c = submit_event3b.then(clear_instruct, None, instruction) \
+ .then(clear_instruct, None, iinput)
+ submit_event3d = submit_event3c.then(**retry_bot_args, api_name='retry_bot' if allow_api else None,
+ queue=queue)
+ submit_event3e = submit_event3d.then(**score_args_submit, api_name='retry_bot_score' if allow_api else None,
+ queue=queue)
+ submit_event3f = submit_event3e.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None,
+ queue=queue)
+ submit_event3g = submit_event3f.then(**score_args2_submit, api_name='retry_bot_score2' if allow_api else None,
+ queue=queue)
+ submit_event3h = submit_event3g.then(clear_torch_cache)
+
+ submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
+ .then(**undo_user_args2, api_name='undo2' if allow_api else None) \
+ .then(clear_instruct, None, instruction) \
+ .then(clear_instruct, None, iinput) \
+ .then(**score_args_submit, api_name='undo_score' if allow_api else None) \
+ .then(**score_args2_submit, api_name='undo_score2' if allow_api else None)

  # does both models
  clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
@@ -864,9 +884,12 @@ def go_gradio(**kwargs):
  api_name='system_info' if allow_api else None, queue=False)

  # don't pass text_output, don't want to clear output, just stop it
- # FIXME: have to click once to stop output and second time to stop GPUs going
+ # cancel only stops outer generation, not inner generation or non-generation
  stop_btn.click(lambda: None, None, None,
- cancels=[submit_event_nochat, submit_event, submit_event2, submit_event3],
+ cancels=[submit_event1d, submit_event1f,
+ submit_event2d, submit_event2f,
+ submit_event3d, submit_event3f,
+ submit_event_nochat],
  queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
  demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)

@@ -888,7 +911,7 @@ def go_gradio(**kwargs):

  input_args_list = ['model_state']
  inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
- 'raise_generate_gpu_exceptions', 'chat_context', 'concurrency_count']
+ 'raise_generate_gpu_exceptions', 'chat_context', 'concurrency_count', 'lora_weights']


  def get_inputs_list(inputs_dict, model_lower):
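
Note on the event wiring above: the single chained submit_event is split into named steps (submit_event1a ... submit_event1h, and similarly for submit/retry) so the Stop button can cancel just the bot-generation steps via cancels=[...], and bot() now yields (history, error_text) pairs that feed the new exception_text box. A minimal sketch of the chained .then()/cancels idea with gradio 3.x Blocks (the components and functions below are illustrative, not the app's real layout):

    import gradio as gr


    def user_turn(msg, history):
        # append the user message with an empty bot slot
        return "", history + [[msg, None]]


    def bot_turn(history):
        # stream the bot reply into the last history entry
        reply = "echo: " + history[-1][0]
        for i in range(1, len(reply) + 1):
            history[-1][1] = reply[:i]
            yield history


    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        stop = gr.Button("Stop")

        # keep a handle on each step so later steps (or a stop button) can refer to it
        step_user = msg.submit(user_turn, [msg, chatbot], [msg, chatbot])
        step_bot = step_user.then(bot_turn, chatbot, chatbot)

        # cancel only the streaming bot step, not the user step
        stop.click(lambda: None, None, None, cancels=[step_bot], queue=False)

    demo.queue()
    if __name__ == "__main__":
        demo.launch()

Keeping a named handle per step is what makes the selective cancels list in the diff possible; a single fluent chain only exposes the first event object.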
utils.py CHANGED
@@ -259,3 +259,37 @@ def wrapped_partial(func, *args, **kwargs):
  partial_func = functools.partial(func, *args, **kwargs)
  functools.update_wrapper(partial_func, func)
  return partial_func
+
+
+ class ThreadException(Exception):
+ pass
+
+
+ class EThread(threading.Thread):
+ # Function that raises the custom exception
+ def __init__(self, group=None, target=None, name=None,
+ args=(), kwargs=None, *, daemon=None, bucket=None):
+ self.bucket = bucket
+ self.streamer = kwargs.get('streamer')
+ self.exc = None
+ super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
+
+ def run(self):
+ # Variable that stores the exception, if raised by someFunction
+ try:
+ super().run()
+ except BaseException as e:
+ print("thread exception: %s" % str(sys.exc_info()))
+ self.bucket.put(sys.exc_info())
+ self.exc = e
+ if self.streamer:
+ print("make stop: %s" % str(sys.exc_info()), flush=True)
+ self.streamer.do_stop = True
+
+ def join(self, timeout=None):
+ threading.Thread.join(self)
+ # Since join() returns in caller thread
+ # we re-raise the caught exception
+ # if any was caught
+ if self.exc:
+ raise self.exc
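
A short usage sketch for the EThread helper added above, assuming utils.py is importable and already has threading and sys in scope as the class body requires; the worker function here is illustrative. The bucket queue lets a caller poll for failures while the thread runs, and join() re-raises any exception captured in run().

    import queue

    from utils import EThread


    def flaky_worker(streamer=None):
        # stand-in for the generate call; always fails, to show exception propagation
        raise ValueError("boom")


    bucket = queue.Queue()
    # kwargs must be a dict here: EThread.__init__ reads kwargs.get('streamer')
    thread = EThread(target=flaky_worker, kwargs=dict(streamer=None), bucket=bucket)
    thread.start()
    try:
        thread.join()  # re-raises the ValueError captured inside run()
    except ValueError as e:
        print("caught from worker thread: %s" % e)
    print("bucket holds sys.exc_info():", not bucket.empty())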