nxphi47 commited on
Commit
c821309
1 Parent(s): ae33a24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -525
app.py CHANGED
@@ -18,6 +18,11 @@ import filelock
18
  import glob
19
  import json
20
  import time
 
 
 
 
 
21
 
22
  from gradio_client.documentation import document, set_documentation_group
23
 
@@ -278,455 +283,6 @@ path_markdown = """
278
 
279
 
280
 
281
-
282
- def custom_hf_model_weights_iterator(
283
- model_name_or_path: str,
284
- cache_dir: Optional[str] = None,
285
- use_np_cache: bool = False,
286
- ) -> Iterator[Tuple[str, torch.Tensor]]:
287
- # ! if use vllm==0.1.4, use this to augment hf_model_weights_iterator loader
288
- from vllm.model_executor.weight_utils import Disabledtqdm
289
- # Prepare file lock directory to prevent multiple processes from
290
- # downloading the same model weights at the same time.
291
- lock_dir = cache_dir if cache_dir is not None else "/tmp"
292
- lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
293
- lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
294
-
295
- # Download model weights from huggingface.
296
- is_local = os.path.isdir(model_name_or_path)
297
- if not is_local:
298
- with lock:
299
- hf_folder = snapshot_download(model_name_or_path,
300
- allow_patterns="*.bin",
301
- cache_dir=cache_dir,
302
- local_files_only=True,
303
- tqdm_class=Disabledtqdm)
304
- else:
305
- hf_folder = model_name_or_path
306
-
307
- hf_bin_files = [
308
- x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
309
- if not x.endswith("training_args.bin")
310
- ]
311
- hf_safetensors_files = [
312
- x for x in glob.glob(os.path.join(hf_folder, "*model*.safetensors"))
313
- if not x.endswith("training_args.bin")
314
- ]
315
-
316
- if use_np_cache:
317
- # Convert the model weights from torch tensors to numpy arrays for
318
- # faster loading.
319
- np_folder = os.path.join(hf_folder, "np")
320
- os.makedirs(np_folder, exist_ok=True)
321
- weight_names_file = os.path.join(np_folder, "weight_names.json")
322
- with lock:
323
- if not os.path.exists(weight_names_file):
324
- weight_names = []
325
- for bin_file in hf_bin_files:
326
- state = torch.load(bin_file, map_location="cpu")
327
- for name, param in state.items():
328
- param_path = os.path.join(np_folder, name)
329
- with open(param_path, "wb") as f:
330
- np.save(f, param.cpu().detach().numpy())
331
- weight_names.append(name)
332
- with open(weight_names_file, "w") as f:
333
- json.dump(weight_names, f)
334
-
335
- with open(weight_names_file, "r") as f:
336
- weight_names = json.load(f)
337
-
338
- for name in weight_names:
339
- param_path = os.path.join(np_folder, name)
340
- with open(param_path, "rb") as f:
341
- param = np.load(f)
342
- yield name, torch.from_numpy(param)
343
- else:
344
- if len(hf_bin_files) > 0:
345
- print(F'Load bin files: {hf_bin_files}')
346
- for bin_file in hf_bin_files:
347
- state = torch.load(bin_file, map_location="cpu")
348
- for name, param in state.items():
349
- yield name, param
350
- del state
351
- torch.cuda.empty_cache()
352
- elif len(hf_safetensors_files) > 0:
353
- print(F'Load safetensor files: {hf_safetensors_files}')
354
- from safetensors.torch import load_file
355
- for safe_file in hf_safetensors_files:
356
- # state = torch.load(bin_file, map_location="cpu")
357
- state = load_file(safe_file)
358
- for name, param in state.items():
359
- yield name, param
360
- del state
361
- torch.cuda.empty_cache()
362
- else:
363
- raise ValueError(f'no files available either bin or safe')
364
-
365
-
366
- def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
367
- """convert PySafeSlice object from safetensors to torch.Tensor
368
-
369
- PySafeSlice object supports indexing, which is done before loading the
370
- actual tensor and can reduce the amount of memory being read into the
371
- memory. However, it does not support more advanced functionalities
372
- like `.view()` or `.t()`. Therefore, if we need to modify the loaded
373
- tensor with these more complicated operators, we need to convert to
374
- tensor first.
375
- """
376
- if not isinstance(x, torch.Tensor):
377
- x = x[:]
378
- return x
379
-
380
-
381
- def load_padded_tensor_parallel_vocab(
382
- param: torch.Tensor,
383
- loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
384
- tensor_model_parallel_rank: int,
385
- ) -> None:
386
- shard_size = param.shape[0]
387
- start_idx = tensor_model_parallel_rank * shard_size
388
- end_idx = (tensor_model_parallel_rank + 1) * shard_size
389
- loaded_weight = loaded_weight[start_idx:end_idx]
390
- loaded_weight = convert_pyslice_to_tensor(loaded_weight)
391
- param[:loaded_weight.shape[0]].copy_(loaded_weight)
392
-
393
-
394
- def llama_load_weights(
395
- self,
396
- model_name_or_path: str,
397
- cache_dir: Optional[str] = None,
398
- use_np_cache: bool = False,
399
- load_format: str = "auto",
400
- revision: Optional[str] = None
401
- ):
402
- # if use vllm==0.1.4
403
- from vllm.model_executor.weight_utils import (
404
- load_tensor_parallel_weights
405
- )
406
- from vllm.model_executor.parallel_utils.parallel_state import (
407
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
408
- tp_size = get_tensor_model_parallel_world_size()
409
- tensor_model_parallel_rank = get_tensor_model_parallel_rank()
410
-
411
- q_proj_shard_size = (self.config.hidden_size // tp_size)
412
- kv_proj_shard_size = (self.config.hidden_size //
413
- self.config.num_attention_heads *
414
- getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) // tp_size)
415
- attention_weight_specs = [
416
- # (weight_name, shard_size, offset)
417
- ("q_proj", q_proj_shard_size, 0),
418
- ("k_proj", kv_proj_shard_size, q_proj_shard_size),
419
- ("v_proj", kv_proj_shard_size,
420
- q_proj_shard_size + kv_proj_shard_size),
421
- ]
422
- state_dict = self.state_dict()
423
- need_to_load = len(state_dict)
424
- loaded = 0
425
- iterator = custom_hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
426
-
427
- for name, loaded_weight in iterator:
428
- if "rotary_emb.inv_freq" in name:
429
- continue
430
-
431
- if "embed_tokens" in name or "lm_head" in name:
432
- param = state_dict[name]
433
- # Consider padding in the vocab size.
434
- padded_vocab_size = (param.shape[0] * tp_size)
435
- # num_extra_rows = padded_vocab_size - self.config.vocab_size
436
- num_extra_rows = padded_vocab_size - loaded_weight.size(0)
437
- load_size = loaded_weight.size()
438
- extra_rows = torch.empty(num_extra_rows,
439
- loaded_weight.shape[1])
440
- extra_rows = extra_rows.to(loaded_weight)
441
- loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
442
- if num_extra_rows > 0:
443
- print(f'Add empty to {num_extra_rows} extra row for {name}')
444
- print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')
445
-
446
- is_attention_weight = False
447
- for weight_name, shard_size, offset in attention_weight_specs:
448
- if weight_name not in name or "qkv_proj" in name:
449
- continue
450
- param = state_dict[name.replace(weight_name, "qkv_proj")]
451
-
452
- loaded_weight = loaded_weight[
453
- shard_size * tensor_model_parallel_rank:shard_size *
454
- (tensor_model_parallel_rank + 1)]
455
- param_slice = param.data[offset:offset + shard_size]
456
- assert param_slice.shape == loaded_weight.shape
457
-
458
- param_slice.copy_(loaded_weight)
459
- loaded += 1.0 / 3
460
- is_attention_weight = True
461
- break
462
- if is_attention_weight:
463
- continue
464
-
465
- # ! qkv_proj is sharded differently if concatenated into qkv
466
- # qkv: qqqq kkkk vvvv
467
- # lweight: qq0qq1 kk0kk1 vv0vv1
468
- # q_shard_size: hidden_size // tp_size = qq
469
- # qkv_s0: qq0_kk0_vv0
470
- # qkv_s1: qq1_kk1_vv1
471
- if "qkv_proj" in name:
472
- param = state_dict[name]
473
- # loaded_weight
474
- qsize = self.config.hidden_size
475
- kvsize = self.config.hidden_size // self.config.num_attention_heads * getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
476
- q_offsets = (
477
- q_proj_shard_size * tensor_model_parallel_rank,
478
- q_proj_shard_size * (tensor_model_parallel_rank + 1)
479
- )
480
- k_offsets = (
481
- qsize + kv_proj_shard_size * tensor_model_parallel_rank,
482
- qsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
483
- )
484
- v_offsets = (
485
- qsize + kvsize + kv_proj_shard_size * tensor_model_parallel_rank,
486
- qsize + kvsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
487
- )
488
- _loaded_weight = torch.cat(
489
- [
490
- loaded_weight[q_offsets[0]:q_offsets[1]],
491
- loaded_weight[k_offsets[0]:k_offsets[1]],
492
- loaded_weight[v_offsets[0]:v_offsets[1]],
493
- ], 0
494
- )
495
- assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
496
- param.data.copy_(_loaded_weight)
497
- loaded += 1.0
498
- is_attention_weight = True
499
- if is_attention_weight:
500
- continue
501
-
502
-
503
- is_gate_up_weight = False
504
- for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
505
- if weight_name not in name or "gate_up_proj" in name:
506
- continue
507
- param = state_dict[name.replace(weight_name, "gate_up_proj")]
508
- shard_size = param.shape[0] // 2
509
- loaded_weight = loaded_weight[
510
- shard_size * tensor_model_parallel_rank:shard_size *
511
- (tensor_model_parallel_rank + 1)]
512
- param_slice = param.data[shard_size * stride_id:shard_size *
513
- (stride_id + 1)]
514
- assert param_slice.shape == loaded_weight.shape
515
- param_slice.copy_(loaded_weight)
516
- loaded += 1.0 / 2
517
- is_gate_up_weight = True
518
- break
519
- if is_gate_up_weight:
520
- continue
521
-
522
- if "gate_up_proj" in name:
523
- param = state_dict[name]
524
- shard_size = param.shape[0] // 2
525
- intermediate_size = self.config.intermediate_size
526
- g_offsets = (
527
- shard_size * tensor_model_parallel_rank,
528
- shard_size * (tensor_model_parallel_rank + 1)
529
- )
530
- u_offsets = (
531
- intermediate_size + shard_size * tensor_model_parallel_rank,
532
- intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
533
- )
534
- _loaded_weight = torch.cat(
535
- [
536
- loaded_weight[g_offsets[0]:g_offsets[1]],
537
- loaded_weight[u_offsets[0]:u_offsets[1]],
538
- ], 0
539
- )
540
- assert param.shape == _loaded_weight.shape
541
- param.data.copy_(_loaded_weight)
542
- loaded += 1.0
543
- is_gate_up_weight = True
544
- if is_gate_up_weight:
545
- continue
546
-
547
-
548
- param = state_dict[name]
549
- load_tensor_parallel_weights(param, loaded_weight, name,
550
- self._column_parallel_weights,
551
- self._row_parallel_weights,
552
- tensor_model_parallel_rank)
553
- loaded += 1
554
-
555
- if np.abs(loaded - need_to_load) < 0.01:
556
- print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
557
- else:
558
- print(f'Loaded all {loaded} params loaded out of {need_to_load}')
559
-
560
-
561
- def new_llama_load_weights(
562
- self,
563
- model_name_or_path: str,
564
- cache_dir: Optional[str] = None,
565
- load_format: str = "auto",
566
- revision: Optional[str] = None
567
- ):
568
- # If use newest vllm, not been thoroughly tested yet.
569
- from vllm.model_executor.weight_utils import (
570
- load_tensor_parallel_weights, hf_model_weights_iterator
571
- )
572
- from vllm.model_executor.parallel_utils.parallel_state import (
573
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
574
-
575
- if self.quant_config is None:
576
- weight_suffixes = ["weight"]
577
- else:
578
- weight_suffixes = self.quant_config.get_tp_tensor_names()
579
-
580
- column_parallel_weights: List[str] = []
581
- for layer in self._column_parallel_layers:
582
- for suffix in weight_suffixes:
583
- column_parallel_weights.append(f"{layer}.{suffix}")
584
- row_parallel_weights: List[str] = []
585
- for layer in self._row_parallel_layers:
586
- for suffix in weight_suffixes:
587
- row_parallel_weights.append(f"{layer}.{suffix}")
588
-
589
- tp_size = get_tensor_model_parallel_world_size()
590
- tp_rank = get_tensor_model_parallel_rank()
591
- assert tp_size == 1, f'tensorparallel >=2 not allowed. {tp_size}'
592
- q_proj_shard_size = (self.config.hidden_size // tp_size)
593
- num_kv_heads_replicas = max(1,
594
- tp_size // self.config.num_key_value_heads)
595
- num_kv_heads_per_gpu = max(1,
596
- self.config.num_key_value_heads // tp_size)
597
- kv_proj_shard_size = (self.config.hidden_size //
598
- self.config.num_attention_heads *
599
- num_kv_heads_per_gpu)
600
- attention_weight_specs = [
601
- # (weight_name, shard_size, offset)
602
- ("q_proj", q_proj_shard_size, 0),
603
- ("k_proj", kv_proj_shard_size, q_proj_shard_size),
604
- ("v_proj", kv_proj_shard_size,
605
- q_proj_shard_size + kv_proj_shard_size),
606
- ]
607
- state_dict = self.state_dict()
608
- need_to_load = len(state_dict)
609
- loaded = 0
610
-
611
- for name, loaded_weight in hf_model_weights_iterator(
612
- model_name_or_path, cache_dir, load_format, revision):
613
- if "rotary_emb.inv_freq" in name:
614
- continue
615
-
616
- is_packed = False
617
- is_transposed = False
618
- if self.quant_config is not None:
619
- is_packed = self.quant_config.is_packed(name)
620
- is_transposed = self.quant_config.is_transposed(name)
621
- if is_transposed:
622
- loaded_weight = convert_pyslice_to_tensor(loaded_weight)
623
- loaded_weight = loaded_weight.T
624
-
625
- is_attention_weight = False
626
- for weight_name, shard_size, offset in attention_weight_specs:
627
- if weight_name not in name or "qkv_proj" in name:
628
- continue
629
- param = state_dict[name.replace(weight_name, "qkv_proj")]
630
- if is_transposed:
631
- param = param.T
632
-
633
- if is_packed:
634
- shard_size //= self.quant_config.pack_factor
635
- offset //= self.quant_config.pack_factor
636
-
637
- if weight_name in ["k_proj", "v_proj"]:
638
- shard_id = tp_rank // num_kv_heads_replicas
639
- else:
640
- shard_id = tp_rank
641
- loaded_weight = loaded_weight[shard_size *
642
- shard_id:shard_size *
643
- (shard_id + 1)]
644
- param_slice = param.data[offset:offset + shard_size]
645
- assert param_slice.shape == loaded_weight.shape
646
-
647
- param_slice.copy_(loaded_weight)
648
- loaded += 1.0 / 3
649
- is_attention_weight = True
650
- break
651
- if is_attention_weight:
652
- continue
653
-
654
- # TODO: need to figure out to do sharding with qkv_proj fused
655
-
656
- is_gate_up_weight = False
657
- for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
658
- if weight_name not in name or "gate_up_proj" in name:
659
- continue
660
- param = state_dict[name.replace(weight_name, "gate_up_proj")]
661
- if is_transposed:
662
- param = param.T
663
-
664
- shard_size = param.shape[0] // 2
665
- loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
666
- (tp_rank + 1)]
667
- param_slice = param.data[shard_size * stride_id:shard_size *
668
- (stride_id + 1)]
669
- assert param_slice.shape == loaded_weight.shape
670
- param_slice.copy_(loaded_weight)
671
- loaded += 1.0 / 2
672
- is_gate_up_weight = True
673
- break
674
- if is_gate_up_weight:
675
- continue
676
-
677
- # TODO: need to figure out to do sharding with gate_up_proj fused
678
-
679
- param = state_dict[name]
680
- if is_transposed:
681
- param = param.T
682
-
683
- if "embed_tokens" in name or "lm_head" in name:
684
- load_padded_tensor_parallel_vocab(param, loaded_weight,
685
- tp_rank)
686
- loaded += 1
687
- continue
688
-
689
- load_tensor_parallel_weights(param, loaded_weight, name,
690
- column_parallel_weights,
691
- row_parallel_weights, tp_rank)
692
- loaded += 1
693
-
694
- if np.abs(loaded - need_to_load) < 0.01:
695
- print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
696
- else:
697
- print(f'Loaded all {loaded} params loaded out of {need_to_load}')
698
-
699
-
700
- # Reassign LlamaForCausalLM.load_weights with llama_load_weights
701
- if not DEBUG:
702
-
703
- try:
704
- import vllm
705
- from vllm.model_executor.model_loader import _MODEL_REGISTRY
706
- from vllm.model_executor.models import LlamaForCausalLM
707
-
708
- _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
709
- if vllm.__version__ == "0.1.4":
710
- LlamaForCausalLM.load_weights = llama_load_weights
711
- else:
712
- LlamaForCausalLM.load_weights = new_llama_load_weights
713
-
714
- if DTYPE == "bfloat16":
715
- try:
716
- compute_capability = torch.cuda.get_device_capability()
717
- if compute_capability[0] < 8:
718
- gpu_name = torch.cuda.get_device_name()
719
- print(
720
- "Bfloat16 is only supported on GPUs with compute capability "
721
- f"of at least 8.0. Your {gpu_name} GPU has compute capability "
722
- f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16")
723
- DTYPE = "float16"
724
- except Exception as e:
725
- print(f'Unable to obtain compute_capability: {e}')
726
- except Exception as e:
727
- print(f'Failing import and reconfigure VLLM: {str(e)}')
728
-
729
-
730
  # ! ==================================================================
731
 
732
  set_documentation_group("component")
@@ -734,41 +290,6 @@ set_documentation_group("component")
734
 
735
  RES_PRINTED = False
736
 
737
- def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
738
- return f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {text} {E_INST}"
739
-
740
-
741
- def llama_chat_multiturn_sys_input_seq_constructor(
742
- message: str,
743
- history: List[Tuple[str, str]],
744
- sys_prompt=SYSTEM_PROMPT_1,
745
- bos_token=BOS_TOKEN,
746
- eos_token=EOS_TOKEN,
747
- include_end_instruct=True,
748
- ):
749
- """
750
- ```
751
- <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
752
- <bos>[INST] Prompt [/INST] Answer <eos>
753
- <bos>[INST] Prompt [/INST]
754
- ```
755
- """
756
- text = ''
757
- end_instr = f" {E_INST}" if include_end_instruct else ""
758
- for i, (prompt, res) in enumerate(history):
759
- if i == 0:
760
- text += f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt}{end_instr}"
761
- else:
762
- text += f"{bos_token}{B_INST} {prompt}{end_instr}"
763
-
764
- if res is not None:
765
- text += f" {res} {eos_token} "
766
- if len(history) == 0 or text.strip() == '':
767
- text = f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message}{end_instr}"
768
- else:
769
- text += f"{bos_token}{B_INST} {message}{end_instr}"
770
- return text
771
-
772
 
773
  @document()
774
  class ChatBot(gr.Chatbot):
@@ -966,29 +487,63 @@ def _setup_events(self) -> None:
966
  )
967
 
968
  # Reconfigure clear_btn to stop and clear text box
969
- # if self.clear_btn:
970
- # self.clear_btn.click(
971
- # lambda: ([], [], None),
972
- # None,
973
- # [self.chatbot, self.chatbot_state, self.saved_input],
974
- # queue=False,
975
- # api_name=False,
976
- # cancels=submit_event,
977
- # )
978
 
979
 
980
  def _display_input(
981
- self, message: str, history: list[list[str | None]]
982
- ) -> tuple[list[list[str | None]], list[list[str | None]]]:
983
  if message is not None and message.strip() != "":
984
  history.append([message, None])
985
  return history, history
986
 
987
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
988
  # replace
989
  gr.ChatInterface._setup_stop_events = _setup_stop_events
990
  gr.ChatInterface._setup_events = _setup_events
991
  gr.ChatInterface._display_input = _display_input
 
992
 
993
 
994
  @document()
@@ -1036,25 +591,6 @@ class CustomTabbedInterface(gr.Blocks):
1036
  interface.render()
1037
 
1038
 
1039
-
1040
- # def vllm_abort(self: Any):
1041
- # sh = self.llm_engine.scheduler
1042
- # for g in (sh.waiting + sh.running + sh.swapped):
1043
- # sh.abort_seq_group(g.request_id)
1044
-
1045
- # from vllm.sequence import SequenceStatus
1046
- # scheduler = self.llm_engine.scheduler
1047
- # for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
1048
- # for seq_group in state_queue:
1049
- # # if seq_group.request_id == request_id:
1050
- # # Remove the sequence group from the state queue.
1051
- # state_queue.remove(seq_group)
1052
- # for seq in seq_group.seqs:
1053
- # if seq.is_finished():
1054
- # continue
1055
- # scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
1056
-
1057
-
1058
  def vllm_abort(self):
1059
  sh = self.llm_engine.scheduler
1060
  for g in (sh.waiting + sh.running + sh.swapped):
@@ -1231,6 +767,14 @@ def chatml_format(message, history=None, system_prompt=None):
1231
  return chatml_chat_convo_format(conversations, True, default_system=system_prompt)
1232
 
1233
 
 
 
 
 
 
 
 
 
1234
  def chat_response_stream_multiturn(
1235
  message: str,
1236
  history: List[Tuple[str, str]],
@@ -1242,6 +786,9 @@ def chat_response_stream_multiturn(
1242
  system_prompt: Optional[str] = SYSTEM_PROMPT_1
1243
  ) -> str:
1244
  global LOG_FILE, LOG_PATH
 
 
 
1245
  from vllm import LLM, SamplingParams
1246
  """Build multi turn
1247
  <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
@@ -1274,16 +821,12 @@ def chat_response_stream_multiturn(
1274
 
1275
  message_safety = safety_check(message, history=history)
1276
  if message_safety is not None:
1277
- yield message_safety
1278
- return
1279
 
1280
  # history will be appended with message later on
1281
 
1282
- # full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
1283
- # message, history, sys_prompt=system_prompt
1284
- # )
1285
  full_prompt = chatml_format(message.strip(), history=history, system_prompt=system_prompt)
1286
- # print(full_prompt)
1287
 
1288
  if len(tokenizer.encode(full_prompt, add_special_tokens=False)) >= 4050:
1289
  raise gr.Error(f"Conversation or prompt is too long, please clear the chatbox or try shorter input.")
@@ -1334,6 +877,89 @@ def chat_response_stream_multiturn(
1334
  if message_safety is not None:
1335
  yield message_safety
1336
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1337
 
1338
 
1339
  def maybe_log_conv_file(current_time, history, message, response, **kwargs):
@@ -1715,6 +1341,48 @@ CHAT_EXAMPLES = [
1715
 
1716
  # performance items
1717
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1718
 
1719
  def launch_demo():
1720
  global demo, llm, DEBUG, LOG_FILE
@@ -1817,7 +1485,7 @@ def launch_demo():
1817
  gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
1818
  gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
1819
  gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
1820
- gr.Textbox(value="[STOP],[END],<s>,</s>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
1821
  gr.Number(value=0, label='current_time', visible=False),
1822
  ],
1823
  outputs=[
@@ -1829,11 +1497,13 @@ def launch_demo():
1829
  description=FILE_UPLOAD_DESCRIPTION,
1830
  allow_flagging=False,
1831
  examples=[
1832
- ["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "[STOP],[END],<s>,</s>"],
1833
- ["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "[STOP],[END],<s>,</s>,\\n"]
1834
  ],
1835
  cache_examples=False,
1836
  )
 
 
1837
 
1838
  demo_chat = gr.ChatInterface(
1839
  response_fn,
@@ -1869,8 +1539,8 @@ def launch_demo():
1869
  descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
1870
 
1871
  demo = CustomTabbedInterface(
1872
- interface_list=[demo_chat, demo_file_upload],
1873
- tab_names=["Chat Interface", "Batch Inference"],
1874
  title=f"{model_title}",
1875
  description=descriptions,
1876
  )
 
18
  import glob
19
  import json
20
  import time
21
+ from gradio.routes import Request
22
+ from gradio.utils import SyncToAsyncIterator, async_iteration
23
+ from gradio.helpers import special_args
24
+ import anyio
25
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
26
 
27
  from gradio_client.documentation import document, set_documentation_group
28
 
 
283
 
284
 
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  # ! ==================================================================
287
 
288
  set_documentation_group("component")
 
290
 
291
  RES_PRINTED = False
292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  @document()
295
  class ChatBot(gr.Chatbot):
 
487
  )
488
 
489
  # Reconfigure clear_btn to stop and clear text box
 
 
 
 
 
 
 
 
 
490
 
491
 
492
  def _display_input(
493
+ self, message: str, history: List[List[Union[str, None]]]
494
+ ) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
495
  if message is not None and message.strip() != "":
496
  history.append([message, None])
497
  return history, history
498
 
499
 
500
+ async def _stream_fn(
501
+ self,
502
+ message: str,
503
+ history_with_input,
504
+ request: Request,
505
+ *args,
506
+ ) -> AsyncGenerator:
507
+ history = history_with_input[:-1]
508
+ inputs, _, _ = special_args(
509
+ self.fn, inputs=[message, history, *args], request=request
510
+ )
511
+
512
+ if self.is_async:
513
+ generator = self.fn(*inputs)
514
+ else:
515
+ generator = await anyio.to_thread.run_sync(
516
+ self.fn, *inputs, limiter=self.limiter
517
+ )
518
+ generator = SyncToAsyncIterator(generator, self.limiter)
519
+ try:
520
+ first_response = await async_iteration(generator)
521
+ update = history + [[message, first_response]]
522
+ yield update, update
523
+ except StopIteration:
524
+ update = history + [[message, None]]
525
+ yield update, update
526
+ try:
527
+ async for response in generator:
528
+ update = history + [[message, response]]
529
+ yield update, update
530
+ except Exception as e:
531
+ # if "invalid" in str(e):
532
+ # yield history, history
533
+ # raise e
534
+ # else:
535
+ # raise e
536
+ yield history, history
537
+ raise e
538
+
539
+
540
+
541
+
542
  # replace
543
  gr.ChatInterface._setup_stop_events = _setup_stop_events
544
  gr.ChatInterface._setup_events = _setup_events
545
  gr.ChatInterface._display_input = _display_input
546
+ gr.ChatInterface._stream_fn = _stream_fn
547
 
548
 
549
  @document()
 
591
  interface.render()
592
 
593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
  def vllm_abort(self):
595
  sh = self.llm_engine.scheduler
596
  for g in (sh.waiting + sh.running + sh.swapped):
 
767
  return chatml_chat_convo_format(conversations, True, default_system=system_prompt)
768
 
769
 
770
+ def debug_chat_response_stream_multiturn(*args, **kwargs):
771
+ message = "This is a debugging message"
772
+ for i in range(len(message)):
773
+ time.sleep(0.05)
774
+ yield message[:i]
775
+
776
+
777
+
778
  def chat_response_stream_multiturn(
779
  message: str,
780
  history: List[Tuple[str, str]],
 
786
  system_prompt: Optional[str] = SYSTEM_PROMPT_1
787
  ) -> str:
788
  global LOG_FILE, LOG_PATH
789
+ if DEBUG:
790
+ yield from debug_chat_response_stream_multiturn()
791
+ return
792
  from vllm import LLM, SamplingParams
793
  """Build multi turn
794
  <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
 
821
 
822
  message_safety = safety_check(message, history=history)
823
  if message_safety is not None:
824
+ # yield message_safety
825
+ raise gr.Error(message_safety)
826
 
827
  # history will be appended with message later on
828
 
 
 
 
829
  full_prompt = chatml_format(message.strip(), history=history, system_prompt=system_prompt)
 
830
 
831
  if len(tokenizer.encode(full_prompt, add_special_tokens=False)) >= 4050:
832
  raise gr.Error(f"Conversation or prompt is too long, please clear the chatbox or try shorter input.")
 
877
  if message_safety is not None:
878
  yield message_safety
879
  return
880
+
881
+
882
+
883
+ def debug_generate_free_form_stream(message):
884
+ output = " This is a debugging message...."
885
+ for i in range(len(output)):
886
+ time.sleep(0.05)
887
+ yield message + output[:i]
888
+
889
+
890
+ def generate_free_form_stream(
891
+ message: str,
892
+ temperature: float,
893
+ max_tokens: int,
894
+ frequency_penalty: float,
895
+ presence_penalty: float,
896
+ current_time: Optional[float] = None,
897
+ stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
898
+ ) -> str:
899
+ global LOG_FILE, LOG_PATH
900
+ if DEBUG:
901
+ yield from debug_generate_free_form_stream(message)
902
+ return
903
+ from vllm import LLM, SamplingParams
904
+ """Build multi turn
905
+ """
906
+ global llm, RES_PRINTED
907
+ assert llm is not None
908
+ tokenizer = llm.get_tokenizer()
909
+ # force removing all
910
+ vllm_abort(llm)
911
+
912
+ temperature = float(temperature)
913
+ frequency_penalty = float(frequency_penalty)
914
+ max_tokens = int(max_tokens)
915
+
916
+ stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
917
+ stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>']))
918
+
919
+ sampling_params = SamplingParams(
920
+ temperature=temperature,
921
+ max_tokens=max_tokens,
922
+ frequency_penalty=frequency_penalty,
923
+ presence_penalty=presence_penalty,
924
+ stop=stop_strings,
925
+ # ignore_eos=True,
926
+ )
927
+
928
+ # full_prompt = message
929
+ if len(message) == 0:
930
+ raise gr.Error("The message cannot be empty!")
931
+
932
+ message_safety = safety_check(message)
933
+ if message_safety is not None:
934
+ raise gr.Error(message_safety)
935
+
936
+ if len(tokenizer.encode(message, add_special_tokens=False)) >= 4050:
937
+ raise gr.Error(f"Prompt is too long!")
938
+
939
+ cur_out = None
940
+ for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
941
+ if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
942
+ # optionally check safety, and respond
943
+ if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
944
+ message_safety = safety_check(cur_out, history=None)
945
+ if message_safety is not None:
946
+ raise gr.Error(message_safety)
947
+ yield message + cur_out
948
+ assert len(gen) == 1, f'{gen}'
949
+ item = next(iter(gen.values()))
950
+ cur_out = item.outputs[0].text
951
+ #cur_out = "Our system is under maintenance, will be back soon!"
952
+ if j >= max_tokens - 2:
953
+ gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
954
+
955
+ if cur_out is not None:
956
+ yield message + cur_out
957
+
958
+ message_safety = safety_check(message + cur_out, history=None)
959
+ if message_safety is not None:
960
+ raise gr.Error(message_safety)
961
+
962
+
963
 
964
 
965
  def maybe_log_conv_file(current_time, history, message, response, **kwargs):
 
1341
 
1342
  # performance items
1343
 
1344
+ def create_free_form_generation_demo():
1345
+ global short_model_path
1346
+ max_tokens = MAX_TOKENS
1347
+ temperature = TEMPERATURE
1348
+ frequence_penalty = FREQUENCE_PENALTY
1349
+ presence_penalty = PRESENCE_PENALTY
1350
+
1351
+ introduction = """
1352
+ ## Free-form:
1353
+ Put any context string (like few-shot prompts) and get the model to generate.
1354
+ """
1355
+
1356
+ with gr.Blocks() as demo_free_form:
1357
+ gr.Markdown(introduction)
1358
+
1359
+ with gr.Row():
1360
+ txt = gr.Textbox(
1361
+ scale=4,
1362
+ lines=16,
1363
+ show_label=False,
1364
+ placeholder="Enter any free form text and submit",
1365
+ container=False,
1366
+ )
1367
+ with gr.Row():
1368
+ free_submit_button = gr.Button('Submit')
1369
+ with gr.Row():
1370
+ temp = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
1371
+ length = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
1372
+ freq_pen = gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens')
1373
+ pres_pen = gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens')
1374
+ stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
1375
+
1376
+ free_submit_button.click(
1377
+ generate_free_form_stream,
1378
+ [txt, temp, length, freq_pen, pres_pen, stop_strings],
1379
+ txt
1380
+ )
1381
+ return demo_free_form
1382
+
1383
+
1384
+
1385
+
1386
 
1387
  def launch_demo():
1388
  global demo, llm, DEBUG, LOG_FILE
 
1485
  gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
1486
  gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
1487
  gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
1488
+ gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
1489
  gr.Number(value=0, label='current_time', visible=False),
1490
  ],
1491
  outputs=[
 
1497
  description=FILE_UPLOAD_DESCRIPTION,
1498
  allow_flagging=False,
1499
  examples=[
1500
+ ["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "<s>,</s>,<|im_start|>"],
1501
+ ["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "<s>,</s>,<|im_start|>,\\n"]
1502
  ],
1503
  cache_examples=False,
1504
  )
1505
+
1506
+ demo_free_form = create_free_form_generation_demo()
1507
 
1508
  demo_chat = gr.ChatInterface(
1509
  response_fn,
 
1539
  descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
1540
 
1541
  demo = CustomTabbedInterface(
1542
+ interface_list=[demo_chat, demo_file_upload, demo_free_form],
1543
+ tab_names=["Chat Interface", "Batch Inference", "Free-form"],
1544
  title=f"{model_title}",
1545
  description=descriptions,
1546
  )