nxphi47 committed on
Commit
437fc15
1 Parent(s): 0a39e99

Update app.py

Files changed (1)
  1. app.py +444 -209
app.py CHANGED
@@ -25,7 +25,7 @@ from tqdm.auto import tqdm
25
  from huggingface_hub import snapshot_download
26
 
27
 
28
- # @@ constants ================
29
 
30
  DEBUG = bool(int(os.environ.get("DEBUG", "1")))
31
  BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1")))
@@ -34,59 +34,53 @@ DTYPE = os.environ.get("DTYPE", "bfloat16")
34
 
35
  # ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
36
  DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
37
  # ! uploaded model path, will be downloaded to MODEL_PATH
38
  HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
 
39
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
40
  MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
41
 
42
-
43
 
44
  # gradio config
45
  PORT = int(os.environ.get("PORT", "7860"))
 
46
  STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
47
  MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
48
  TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
49
  FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.4"))
 
50
 
51
 
52
- """
53
- TODO:
54
- need to upload the model as hugginface/models/seal_13b_a
55
- # https://huggingface.co/docs/hub/spaces-overview#managing-secrets
56
- set
57
- HF_TOKEN=???
58
 
59
- TRANSFORMERS_CACHE=/data/.huggingface
60
- # if persistent, then export the following
61
 
62
  HF_HOME=/data/.huggingface
63
  MODEL_PATH=/data/.huggingface/seal-13b-chat-a
64
- HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
65
- # if not persistent
66
  MODEL_PATH=./seal-13b-chat-a
67
- HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
68
-
69
-
70
- ===== Application Startup at 2023-10-20 04:03:49 =====
71
-
72
- DEBUG mode: False
73
- Torch version: 2.1.0+cu121
74
- Torch CUDA version: 12.1
75
- /home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/cuda/__init__.py:138: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
76
- return torch._C._cuda_getDeviceCount() > 0
77
- Unable to obtain compute_capability: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver.
78
- Launch config: model_title='SeaL-13B - An Assistant for South East Asian Languages' / tensor_parallel=1 / dtype='bfloat16' / 2048 | BLOCK_ZH=True
79
- | STREAM_YIELD_MULTIPLE=1
80
- | frequence_penalty=0.4
81
- | temperature=0.1
82
- | hf_model_name=DAMO-NLP-SG/seal-13b-chat-a
83
- | model_path=./seal-13b-chat-a
84
- | DOWNLOAD_SNAPSHOT=True
85
- sys=You are a multilingual, helpful,
86
 
87
  """
88
 
89
 
 
90
  # ==============================
91
  print(f'DEBUG mode: {DEBUG}')
92
  print(f'Torch version: {torch.__version__}')
@@ -95,16 +89,109 @@ try:
95
  except Exception as e:
96
  print(f'Failed to print cuda version: {e}')
97
 
98
-
99
 
100
 
101
  # @@ constants ================
102
 
103
 
104
 
105
  def _detect_lang(text):
106
  from langdetect import detect as detect_lang
107
- from langdetect.detector import LangDetectException
108
  dlang = None
109
  try:
110
  dlang = detect_lang(text)
@@ -118,11 +205,12 @@ def _detect_lang(text):
118
  return dlang
119
 
120
 
121
- def hf_model_weights_iterator(
122
  model_name_or_path: str,
123
  cache_dir: Optional[str] = None,
124
  use_np_cache: bool = False,
125
  ) -> Iterator[Tuple[str, torch.Tensor]]:
 
126
  from vllm.model_executor.weight_utils import Disabledtqdm
127
  # Prepare file lock directory to prevent multiple processes from
128
  # downloading the same model weights at the same time.
@@ -143,7 +231,6 @@ def hf_model_weights_iterator(
143
  hf_folder = model_name_or_path
144
 
145
  hf_bin_files = [
146
- # x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
147
  x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
148
  if not x.endswith("training_args.bin")
149
  ]
@@ -236,9 +323,9 @@ def llama_load_weights(
236
  cache_dir: Optional[str] = None,
237
  use_np_cache: bool = False,
238
  load_format: str = "auto",
239
- # load_format: str = "pt",
240
  revision: Optional[str] = None
241
  ):
 
242
  from vllm.model_executor.weight_utils import (
243
  load_tensor_parallel_weights
244
  )
@@ -261,7 +348,7 @@ def llama_load_weights(
261
  state_dict = self.state_dict()
262
  need_to_load = len(state_dict)
263
  loaded = 0
264
- iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
265
 
266
  for name, loaded_weight in iterator:
267
  if "rotary_emb.inv_freq" in name:
@@ -331,7 +418,6 @@ def llama_load_weights(
331
  loaded_weight[v_offsets[0]:v_offsets[1]],
332
  ], 0
333
  )
334
- # print(f'{name} | {q_offsets} | {k_offsets} | {v_offsets}')
335
  assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
336
  param.data.copy_(_loaded_weight)
337
  loaded += 1.0
@@ -398,19 +484,158 @@ def llama_load_weights(
398
  print(f'Loaded all {loaded} params loaded out of {need_to_load}')
399
 
400
 
401
  # Reassign LlamaForCausalLM.load_weights with llama_load_weights
402
  if not DEBUG:
403
 
404
- # vllm import
405
- # from vllm import LLM, SamplingParams
406
- # ! reconfigure vllm to faster llama
407
  try:
408
  import vllm
409
  from vllm.model_executor.model_loader import _MODEL_REGISTRY
410
  from vllm.model_executor.models import LlamaForCausalLM
411
 
412
  _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
413
- LlamaForCausalLM.load_weights = llama_load_weights
414
 
415
  if DTYPE == "bfloat16":
416
  try:
@@ -433,33 +658,6 @@ if not DEBUG:
433
  set_documentation_group("component")
434
 
435
 
436
-
437
- DTYPES = {
438
- 'float16': torch.float16,
439
- 'bfloat16': torch.bfloat16
440
- }
441
-
442
- llm = None
443
- demo = None
444
-
445
-
446
- BOS_TOKEN = '<s>'
447
- EOS_TOKEN = '</s>'
448
-
449
- B_INST, E_INST = "[INST]", "[/INST]"
450
- B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
451
-
452
- SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
453
- answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
454
- that your responses are socially unbiased and positive in nature.
455
-
456
- If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
457
- correct. If you don't know the answer to a question, please don't share false information.
458
-
459
- As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
460
- Your response should adapt to the norms and customs of the respective language and culture.
461
- """
462
-
463
  RES_PRINTED = False
464
 
465
  def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
@@ -576,8 +774,117 @@ def _setup_stop_events(
576
  api_name=False,
577
  queue=False,
578
  )
 
579
 
580
  gr.ChatInterface._setup_stop_events = _setup_stop_events
 
581
 
582
  def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
583
  global llm
@@ -611,7 +918,6 @@ def vllm_abort(self: Any):
611
  continue
612
  scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
613
 
614
- # def _vllm_run_engine(self: LLM, use_tqdm: bool = False) -> Dict[str, RequestOutput]:
615
  def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
616
  from vllm.outputs import RequestOutput
617
  # Initialize tqdm.
@@ -624,16 +930,9 @@ def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
624
  step_outputs = self.llm_engine.step()
625
  for output in step_outputs:
626
  outputs[output.request_id] = output
627
- # outputs = sorted(outputs, key=lambda x: int(x.request_id))
628
  if len(outputs) > 0:
629
  yield outputs
630
- # if use_tqdm:
631
- # pbar.close()
632
- # Sort the outputs by request ID.
633
- # This is necessary because some requests may be finished earlier than
634
- # its previous requests.
635
- # outputs = sorted(outputs, key=lambda x: int(x.request_id))
636
- # return outputs
637
 
638
 
639
  def vllm_generate_stream(
@@ -692,64 +991,47 @@ def vllm_generate_stream(
692
  yield from _vllm_run_engine(self, use_tqdm)
693
 
694
 
695
- # def chat_response_stream(
696
- # message: str,
697
- # history: List[Tuple[str, str]],
698
- # temperature: float,
699
- # max_tokens: int,
700
- # frequency_penalty: float,
701
- # system_prompt: str
702
- # ) -> str:
703
- # global llm, RES_PRINTED
704
- # assert llm is not None
705
- # # force removing all
706
- # vllm_abort(llm)
707
-
708
- # temperature = float(temperature)
709
- # frequency_penalty = float(frequency_penalty)
710
- # max_tokens = int(max_tokens)
711
- # if system_prompt.strip() != '':
712
- # # chat version, add system prompt
713
- # message = llama_chat_sys_input_seq_constructor(
714
- # message.strip(),
715
- # sys_prompt=system_prompt
716
- # )
717
- # sampling_params = SamplingParams(
718
- # temperature=temperature, max_tokens=max_tokens,
719
- # frequency_penalty=frequency_penalty,
720
- # )
721
- # cur_out = None
722
- # for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
723
- # if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
724
- # yield cur_out
725
- # assert len(gen) == 1, f'{gen}'
726
- # item = next(iter(gen.values()))
727
- # cur_out = item.outputs[0].text
728
- # if not RES_PRINTED:
729
- # print(f'{message}<<<{cur_out}>>>')
730
- # RES_PRINTED = True
731
- # if cur_out is not None:
732
- # yield cur_out
733
-
734
-
735
  BLOCK_MESSAGE = """Sorry, Chinese is not currently supported. Please clear the chat box for a new conversation.
736
  抱歉,目前不支持中文。 请清除聊天框以进行新对话。"""
737
 
738
  def block_zh(
739
  message: str,
740
  history: List[Tuple[str, str]]
741
  ) -> str:
742
- # if any((BLOCK_MESSAGE in x[0].strip() or BLOCK_MESSAGE in x[1].strip()) for x in history):
743
- if any((BLOCK_MESSAGE in x[1].strip()) for x in history):
744
  return True
745
  elif 'zh' in _detect_lang(message):
746
  print(f'Detect zh: {message}')
747
  return True
748
- # ! optionally detect every responses message
749
  else:
750
  return False
751
 
752
- # 抱歉,目前不支持中文。
753
  def chat_response_stream_multiturn(
754
  message: str,
755
  history: List[Tuple[str, str]],
@@ -779,44 +1061,48 @@ def chat_response_stream_multiturn(
779
 
780
  message = message.strip()
781
 
782
- # detect_ = _detect_lang(message)
783
- # print(f'Message language: {detect_}')
784
 
785
- # ! lang detect
786
- if BLOCK_ZH:
787
- if block_zh(message, history):
788
- yield BLOCK_MESSAGE
789
- return
790
-
791
- # history.append([message, None])
792
  # history will be appended with message later on
793
  full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
794
  message, history, sys_prompt=system_prompt
795
  )
796
- # print(full_prompt)
797
  sampling_params = SamplingParams(
798
  temperature=temperature, max_tokens=max_tokens,
799
  frequency_penalty=frequency_penalty,
800
  )
801
  cur_out = None
802
- # for gen in vllm_generate_stream(llm, full_prompt, sampling_params):
803
  for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
804
  if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
805
  yield cur_out
806
  assert len(gen) == 1, f'{gen}'
807
  item = next(iter(gen.values()))
808
  cur_out = item.outputs[0].text
809
 
810
- # if not RES_PRINTED:
811
- print(f'{full_prompt}<<<{cur_out}>>>\n')
812
- # RES_PRINTED = True
813
  if cur_out is not None:
814
  yield cur_out
815
 
816
- # print(f'Output: {_detect_lang(cur_out)}')
817
- if BLOCK_ZH:
818
- if "zh" in _detect_lang(cur_out):
819
- yield BLOCK_MESSAGE
820
 
821
 
822
  def debug_chat_response_echo(
@@ -832,44 +1118,6 @@ def debug_chat_response_echo(
832
  yield f"repeat: {message}"
833
 
834
 
835
- # ============ CONSTANT ============
836
- # https://github.com/gradio-app/gradio/issues/884
837
- MODEL_NAME = "SeaL-13B"
838
- MODEL_TITLE = "SeaL-13B - An Assistant for South East Asian Languages"
839
- # ! add icon: "<img src='file/lion.jpg' alt='image One'>"
840
- MODEL_DESC = """
841
- <span style="font-size: larger">
842
- This is a DAMO SeaL-13B chatbot assistant built by DAMO Academy, Alibaba Group. It can produce helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
843
- </span>
844
- """.strip()
845
- # <br>
846
-
847
-
848
- cite_markdown = """
849
- ## Citation
850
- If you find our project useful, hope you can star our repo and cite our paper as follows:
851
- ```
852
- @article{damonlpsg2023seallm,
853
- author = {???},
854
- title = {SeaL: A language model for South East Asian Languages},
855
- year = 2023,
856
- }
857
- ```
858
- """
859
-
860
- warning_markdown = """
861
- ## Warning:
862
- <span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
863
- <span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
864
- or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
865
- """
866
-
867
-
868
- path_markdown = """
869
- #### Model path:
870
- {model_path}
871
- """
872
-
873
  def check_model_path(model_path) -> str:
874
  assert os.path.exists(model_path), f'{model_path} not found'
875
  ckpt_info = "None"
@@ -903,11 +1151,14 @@ def launch():
903
  print(
904
  f'Launch config: {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
905
  f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
 
906
  f'\n| frequence_penalty={frequence_penalty} '
907
  f'\n| temperature={temperature} '
908
  f'\n| hf_model_name={hf_model_name} '
909
  f'\n| model_path={model_path} '
910
  f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
911
  f'\nsys={SYSTEM_PROMPT_1}'
912
  f'\ndesc={model_desc}'
913
  )
@@ -928,13 +1179,23 @@ def launch():
928
  snapshot_download(hf_model_name, local_dir=model_path)
929
 
930
  import vllm
931
- from vllm import LLM, SamplingParams
932
 
933
  print(F'VLLM: {vllm.__version__}')
934
  ckpt_info = check_model_path(model_path)
935
 
936
  print(f'Load path: {model_path} | {ckpt_info}')
937
- llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel)
938
 
939
  print(f'Use system prompt:\n{sys_prompt}')
940
 
@@ -957,16 +1218,17 @@ def launch():
957
  stop_btn=None,
958
  title=f"{model_title}",
959
  description=f"{model_desc}",
960
- # ! decide if can change the system prompt.
961
  additional_inputs=[
962
  gr.Number(value=temperature, label='Temperature (higher -> more random)'),
963
  gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
964
  gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens)'),
 
965
  # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
966
  ],
967
  )
 
968
  with demo:
969
- gr.Markdown(warning_markdown)
970
  gr.Markdown(cite_markdown)
971
  gr.Markdown(path_markdown.format(model_path=model_path))
972
 
@@ -981,30 +1243,3 @@ def main():
981
 
982
  if __name__ == "__main__":
983
  main()
984
-
985
-
986
- """
987
-
988
- export CUDA_VISIBLE_DEVICES=0
989
- export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW8k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.FSePlCq13M.FSePlCq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_4000
990
- export MODEL_PATH=${dataroot}/llama-2-7b-lxxp-faster
991
- export MODEL_PATH=${dataroot}/llama-2-7b-chat-xp
992
-
993
- export DEBUG=0
994
- export CUDA_VISIBLE_DEVICES=0
995
- export MODEL_PATH=seal_13b_a
996
- export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW12k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.SeaV2Cq13M.SeaV2Cq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_6000
997
-
998
- export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/mer13s108Hi16kPretFlCWNLP12k_SFT2.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.Sft2Censor.Sft2Censor.m4k.b8.lr1e5.linear.wa0k.ms1144k.grac1.se1.6g.v4c.zfsdp/step_4000
999
- # 70-30 model
1000
- export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/mer13s108Hi16kPretFlCWNLP12k_SFT2.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.BgSft2aCensor0a.BgSft2Cens.BgSft2Cens.m4k.b2.lr1e5.linear.wa0k.ms4577k.grac1.se1.6g.v4c73.zfsdp/step_500
1001
- export PORT=8799
1002
- export BLOCK_ZH=1
1003
- export DEBUG=0
1004
- python app.py
1005
-
1006
-
1007
- DEBUG=1 python app.py
1008
-
1009
-
1010
- """
app.py (updated file)
25
  from huggingface_hub import snapshot_download
26
 
27
 
28
+ # @@ environments ================
29
 
30
  DEBUG = bool(int(os.environ.get("DEBUG", "1")))
31
  BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1")))
 
34
 
35
  # ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
36
  DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
37
+ LOG_RESPONSE = bool(int(os.environ.get("LOG_RESPONSE", "0")))
38
+
39
  # ! uploaded model path, will be downloaded to MODEL_PATH
40
  HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
41
+ # ! if model is private, need HF_TOKEN to access the model
42
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
43
+ # ! path where the model is downloaded, either on ./ or persistent disc
44
  MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
45
 
46
+ # ! list of keywords to disable, as a security measure to comply with local regulations
47
+ KEYWORDS = os.environ.get("KEYWORDS", "").strip()
48
+ KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
49
+ KEYWORDS = [x.lower() for x in KEYWORDS]
50
 
51
  # gradio config
52
  PORT = int(os.environ.get("PORT", "7860"))
53
+ # how many iterations to yield response
54
  STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
55
+ # how many iterations to perform safety check on response
56
+ STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
57
+
58
+ # self explanatory
59
  MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
60
  TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
61
  FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.4"))
62
+ gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9"))
63
 
64
+ # whether to enable quantization, currently not in use
65
+ QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
66
 
67
 
68
+ """
69
+ Internal instructions of how to configure the DEMO
70
 
71
+ 1. Upload SFT model as a model to huggingface: hugginface/models/seal_13b_a
72
+ 2. If the model weights are private, set HF_TOKEN=<your private hf token> in https://huggingface.co/spaces/????/?????/settings
73
+ 3. space config env: `HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a` or the underlying model
74
+ 4. If persistent storage is enabled, set
75
  HF_HOME=/data/.huggingface
76
  MODEL_PATH=/data/.huggingface/seal-13b-chat-a
77
+ if not:
 
78
  MODEL_PATH=./seal-13b-chat-a
79
 
80
  """
81
 
82
 
83
+
84
  # ==============================
85
  print(f'DEBUG mode: {DEBUG}')
86
  print(f'Torch version: {torch.__version__}')
 
89
  except Exception as e:
90
  print(f'Failed to print cuda version: {e}')
91
 
92
+ try:
93
+ compute_capability = torch.cuda.get_device_capability()
94
+ print(f'Torch CUDA compute_capability: {compute_capability}')
95
+ except Exception as e:
96
+ print(f'Failed to print compute_capability version: {e}')
97
 
98
 
99
  # @@ constants ================
100
 
101
+ DTYPES = {
102
+ 'float16': torch.float16,
103
+ 'bfloat16': torch.bfloat16
104
+ }
105
+
106
+ llm = None
107
+ demo = None
108
+
109
+
110
+ BOS_TOKEN = '<s>'
111
+ EOS_TOKEN = '</s>'
112
+
113
+ B_INST, E_INST = "[INST]", "[/INST]"
114
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
115
+
116
+ SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
117
+ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
118
+ that your responses are socially unbiased and positive in nature.
119
+
120
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
121
+ correct. If you don't know the answer to a question, please don't share false information.
122
+
123
+ As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
124
+ Your response should adapt to the norms and customs of the respective language and culture.
125
+ """
126
+
127
+ # ============ CONSTANT ============
128
+ # https://github.com/gradio-app/gradio/issues/884
129
+ MODEL_NAME = "SeaLLM-13B"
130
+ MODEL_TITLE = "SeaLLM-13B - An Assistant for South East Asian Languages"
131
+ # ! add icon: "<img src='file/lion.jpg' alt='image One'>"
132
+ MODEL_TITLE = """
133
+ <div class="container" style="
134
+ align-items: center;
135
+ justify-content: center;
136
+ display: flex;
137
+ ">
138
+ <div class="image" >
139
+ <img src="file/seal_logo.png" style="
140
+ max-width: 10em;
141
+ max-height: 5%;
142
+ height: 5em;
143
+ width: 5em;
144
+ float: left;
145
+ margin-left: auto;
146
+ ">
147
+ </div>
148
+ <div class="text" style="
149
+ padding-left: 20px;
150
+ padding-top: 2%;
151
+ float: left;
152
+ ">
153
+ <h1>SeaLLM-13B - An Assistant for South East Asian Languages</h1>
154
+ </div>
155
+ </div>
156
+ """
157
+ MODEL_DESC = """
158
+ <span style="font-size: larger">
159
+ This is SeaLLM-13B - a chatbot assistant optimized for South East Asian Languages. It can produce helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
160
+ </span>
161
+ <br>
162
+ <span style="color: red">NOTICE: The chatbot may produce inaccurate and harmful information about people, places, or facts. \
163
+ We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
164
+ or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
165
+ """.strip()
166
+
167
+
168
+ cite_markdown = """
169
+ ## Citation
170
+ If you find our project useful, hope you can star our repo and cite our paper as follows:
171
+ ```
172
+ @article{damonlpsg2023seallm,
173
+ author = {???},
174
+ title = {SeaLLM: A language model for South East Asian Languages},
175
+ year = 2023,
176
+ }
177
+ ```
178
+ """
179
+
180
+ # warning_markdown = """
181
+ # ## Warning:
182
+ # <span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
183
+ # <span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
184
+ # or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
185
+ # """
186
+
187
+ path_markdown = """
188
+ #### Model path:
189
+ {model_path}
190
+ """
191
 
192
 
193
  def _detect_lang(text):
194
  from langdetect import detect as detect_lang
 
195
  dlang = None
196
  try:
197
  dlang = detect_lang(text)
 
205
  return dlang
206
 
207
 
208
+ def custom_hf_model_weights_iterator(
209
  model_name_or_path: str,
210
  cache_dir: Optional[str] = None,
211
  use_np_cache: bool = False,
212
  ) -> Iterator[Tuple[str, torch.Tensor]]:
213
+ # ! if use vllm==0.1.4, use this to augment hf_model_weights_iterator loader
214
  from vllm.model_executor.weight_utils import Disabledtqdm
215
  # Prepare file lock directory to prevent multiple processes from
216
  # downloading the same model weights at the same time.
 
231
  hf_folder = model_name_or_path
232
 
233
  hf_bin_files = [
 
234
  x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
235
  if not x.endswith("training_args.bin")
236
  ]
 
323
  cache_dir: Optional[str] = None,
324
  use_np_cache: bool = False,
325
  load_format: str = "auto",
 
326
  revision: Optional[str] = None
327
  ):
328
+ # if use vllm==0.1.4
329
  from vllm.model_executor.weight_utils import (
330
  load_tensor_parallel_weights
331
  )
 
348
  state_dict = self.state_dict()
349
  need_to_load = len(state_dict)
350
  loaded = 0
351
+ iterator = custom_hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
352
 
353
  for name, loaded_weight in iterator:
354
  if "rotary_emb.inv_freq" in name:
 
418
  loaded_weight[v_offsets[0]:v_offsets[1]],
419
  ], 0
420
  )
 
421
  assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
422
  param.data.copy_(_loaded_weight)
423
  loaded += 1.0
 
484
  print(f'Loaded all {loaded} params loaded out of {need_to_load}')
485
 
486
 
487
+ def new_llama_load_weights(
488
+ self,
489
+ model_name_or_path: str,
490
+ cache_dir: Optional[str] = None,
491
+ load_format: str = "auto",
492
+ revision: Optional[str] = None
493
+ ):
494
+ # If use newest vllm
495
+ from vllm.model_executor.weight_utils import (
496
+ load_tensor_parallel_weights, hf_model_weights_iterator
497
+ )
498
+ from vllm.model_executor.parallel_utils.parallel_state import (
499
+ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
500
+
501
+ if self.quant_config is None:
502
+ weight_suffixes = ["weight"]
503
+ else:
504
+ weight_suffixes = self.quant_config.get_tp_tensor_names()
505
+
506
+ column_parallel_weights: List[str] = []
507
+ for layer in self._column_parallel_layers:
508
+ for suffix in weight_suffixes:
509
+ column_parallel_weights.append(f"{layer}.{suffix}")
510
+ row_parallel_weights: List[str] = []
511
+ for layer in self._row_parallel_layers:
512
+ for suffix in weight_suffixes:
513
+ row_parallel_weights.append(f"{layer}.{suffix}")
514
+
515
+ tp_size = get_tensor_model_parallel_world_size()
516
+ tp_rank = get_tensor_model_parallel_rank()
517
+ assert tp_size == 1, f'tensorparallel >=2 not allowed. {tp_size}'
518
+ q_proj_shard_size = (self.config.hidden_size // tp_size)
519
+ num_kv_heads_replicas = max(1,
520
+ tp_size // self.config.num_key_value_heads)
521
+ num_kv_heads_per_gpu = max(1,
522
+ self.config.num_key_value_heads // tp_size)
523
+ kv_proj_shard_size = (self.config.hidden_size //
524
+ self.config.num_attention_heads *
525
+ num_kv_heads_per_gpu)
526
+ attention_weight_specs = [
527
+ # (weight_name, shard_size, offset)
528
+ ("q_proj", q_proj_shard_size, 0),
529
+ ("k_proj", kv_proj_shard_size, q_proj_shard_size),
530
+ ("v_proj", kv_proj_shard_size,
531
+ q_proj_shard_size + kv_proj_shard_size),
532
+ ]
533
+ state_dict = self.state_dict()
534
+ need_to_load = len(state_dict)
535
+ loaded = 0
536
+
537
+ for name, loaded_weight in hf_model_weights_iterator(
538
+ model_name_or_path, cache_dir, load_format, revision):
539
+ if "rotary_emb.inv_freq" in name:
540
+ continue
541
+
542
+ is_packed = False
543
+ is_transposed = False
544
+ if self.quant_config is not None:
545
+ is_packed = self.quant_config.is_packed(name)
546
+ is_transposed = self.quant_config.is_transposed(name)
547
+ if is_transposed:
548
+ loaded_weight = convert_pyslice_to_tensor(loaded_weight)
549
+ loaded_weight = loaded_weight.T
550
+
551
+ is_attention_weight = False
552
+ for weight_name, shard_size, offset in attention_weight_specs:
553
+ if weight_name not in name or "qkv_proj" in name:
554
+ continue
555
+ param = state_dict[name.replace(weight_name, "qkv_proj")]
556
+ if is_transposed:
557
+ param = param.T
558
+
559
+ if is_packed:
560
+ shard_size //= self.quant_config.pack_factor
561
+ offset //= self.quant_config.pack_factor
562
+
563
+ if weight_name in ["k_proj", "v_proj"]:
564
+ shard_id = tp_rank // num_kv_heads_replicas
565
+ else:
566
+ shard_id = tp_rank
567
+ loaded_weight = loaded_weight[shard_size *
568
+ shard_id:shard_size *
569
+ (shard_id + 1)]
570
+ param_slice = param.data[offset:offset + shard_size]
571
+ assert param_slice.shape == loaded_weight.shape
572
+
573
+ param_slice.copy_(loaded_weight)
574
+ loaded += 1.0 / 3
575
+ is_attention_weight = True
576
+ break
577
+ if is_attention_weight:
578
+ continue
579
+
580
+ # TODO: need to figure out to do sharding with qkv_proj fused
581
+
582
+ is_gate_up_weight = False
583
+ for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
584
+ if weight_name not in name or "gate_up_proj" in name:
585
+ continue
586
+ param = state_dict[name.replace(weight_name, "gate_up_proj")]
587
+ if is_transposed:
588
+ param = param.T
589
+
590
+ shard_size = param.shape[0] // 2
591
+ loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
592
+ (tp_rank + 1)]
593
+ param_slice = param.data[shard_size * stride_id:shard_size *
594
+ (stride_id + 1)]
595
+ assert param_slice.shape == loaded_weight.shape
596
+ param_slice.copy_(loaded_weight)
597
+ loaded += 1.0 / 2
598
+ is_gate_up_weight = True
599
+ break
600
+ if is_gate_up_weight:
601
+ continue
602
+
603
+ # TODO: need to figure out to do sharding with gate_up_proj fused
604
+
605
+ param = state_dict[name]
606
+ if is_transposed:
607
+ param = param.T
608
+
609
+ if "embed_tokens" in name or "lm_head" in name:
610
+ load_padded_tensor_parallel_vocab(param, loaded_weight,
611
+ tp_rank)
612
+ loaded += 1
613
+ continue
614
+
615
+ load_tensor_parallel_weights(param, loaded_weight, name,
616
+ column_parallel_weights,
617
+ row_parallel_weights, tp_rank)
618
+ loaded += 1
619
+
620
+ if np.abs(loaded - need_to_load) < 0.01:
621
+ print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
622
+ else:
623
+ print(f'Loaded all {loaded} params loaded out of {need_to_load}')
624
+
625
+
626
  # Reassign LlamaForCausalLM.load_weights with llama_load_weights
627
  if not DEBUG:
628
 
629
  try:
630
  import vllm
631
  from vllm.model_executor.model_loader import _MODEL_REGISTRY
632
  from vllm.model_executor.models import LlamaForCausalLM
633
 
634
  _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
635
+ if vllm.__version__ == "0.1.4":
636
+ LlamaForCausalLM.load_weights = llama_load_weights
637
+ else:
638
+ LlamaForCausalLM.load_weights = new_llama_load_weights
639
 
640
  if DTYPE == "bfloat16":
641
  try:
 
658
  set_documentation_group("component")
659
 
660
 
661
  RES_PRINTED = False
662
 
663
  def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
 
774
  api_name=False,
775
  queue=False,
776
  )
777
+ # upon clear, cancel the submit event as well
778
+ if self.clear_btn:
779
+ self.clear_btn.click(
780
+ lambda: ([], [], None, Button.update(interactive=True)),
781
+ None,
782
+ [self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
783
+ queue=False,
784
+ api_name=False,
785
+ cancels=event_to_cancel,
786
+ )
787
+
788
+ # TODO: reconfigure clear button as stop and clear button
789
+ def _setup_events(self) -> None:
790
+ has_on = False
791
+ try:
792
+ from gradio.events import Dependency, EventListenerMethod, on
793
+ has_on = True
794
+ except ImportError as ie:
795
+ has_on = False
796
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
797
+
798
+
799
+ if has_on:
800
+ # new version
801
+ submit_triggers = (
802
+ [self.textbox.submit, self.submit_btn.click]
803
+ if self.submit_btn
804
+ else [self.textbox.submit]
805
+ )
806
+ submit_event = (
807
+ on(
808
+ submit_triggers,
809
+ self._clear_and_save_textbox,
810
+ [self.textbox],
811
+ [self.textbox, self.saved_input],
812
+ api_name=False,
813
+ queue=False,
814
+ )
815
+ .then(
816
+ self._display_input,
817
+ [self.saved_input, self.chatbot_state],
818
+ [self.chatbot, self.chatbot_state],
819
+ api_name=False,
820
+ queue=False,
821
+ )
822
+ .then(
823
+ submit_fn,
824
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
825
+ [self.chatbot, self.chatbot_state],
826
+ api_name=False,
827
+ )
828
+ )
829
+ self._setup_stop_events(submit_triggers, submit_event)
830
+ else:
831
+ raise ValueError(f'Please install a gradio version newer than 3.44.0')
832
+
833
+ if self.retry_btn:
834
+ retry_event = (
835
+ self.retry_btn.click(
836
+ self._delete_prev_fn,
837
+ [self.chatbot_state],
838
+ [self.chatbot, self.saved_input, self.chatbot_state],
839
+ api_name=False,
840
+ queue=False,
841
+ )
842
+ .then(
843
+ self._display_input,
844
+ [self.saved_input, self.chatbot_state],
845
+ [self.chatbot, self.chatbot_state],
846
+ api_name=False,
847
+ queue=False,
848
+ )
849
+ .then(
850
+ submit_fn,
851
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
852
+ [self.chatbot, self.chatbot_state],
853
+ api_name=False,
854
+ )
855
+ )
856
+ self._setup_stop_events([self.retry_btn.click], retry_event)
857
+
858
+ if self.undo_btn:
859
+ self.undo_btn.click(
860
+ self._delete_prev_fn,
861
+ [self.chatbot_state],
862
+ [self.chatbot, self.saved_input, self.chatbot_state],
863
+ api_name=False,
864
+ queue=False,
865
+ ).then(
866
+ lambda x: x,
867
+ [self.saved_input],
868
+ [self.textbox],
869
+ api_name=False,
870
+ queue=False,
871
+ )
872
 
873
+ # Reconfigure clear_btn to stop and clear text box
874
+ # if self.clear_btn:
875
+ # self.clear_btn.click(
876
+ # lambda: ([], [], None),
877
+ # None,
878
+ # [self.chatbot, self.chatbot_state, self.saved_input],
879
+ # queue=False,
880
+ # api_name=False,
881
+ # cancels=submit_event,
882
+ # )
883
+
884
+
885
+ # replace
886
  gr.ChatInterface._setup_stop_events = _setup_stop_events
887
+ gr.ChatInterface._setup_events = _setup_events
888
 
889
  def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
890
  global llm
 
918
  continue
919
  scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
920
 
 
921
  def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
922
  from vllm.outputs import RequestOutput
923
  # Initialize tqdm.
 
930
  step_outputs = self.llm_engine.step()
931
  for output in step_outputs:
932
  outputs[output.request_id] = output
 
933
  if len(outputs) > 0:
934
  yield outputs
935
+
936
 
937
 
938
  def vllm_generate_stream(
 
991
  yield from _vllm_run_engine(self, use_tqdm)
992
 
993
 
994
  BLOCK_MESSAGE = """Sorry, Chinese is not currently supported. Please clear the chat box for a new conversation.
995
  抱歉,目前不支持中文。 请清除聊天框以进行新对话。"""
996
 
997
+ KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated questions, I'll be glad to help."
998
+
999
  def block_zh(
1000
  message: str,
1001
  history: List[Tuple[str, str]]
1002
  ) -> str:
1003
+ if history is not None and any((BLOCK_MESSAGE in x[1].strip()) for x in history):
 
1004
  return True
1005
  elif 'zh' in _detect_lang(message):
1006
  print(f'Detect zh: {message}')
1007
  return True
 
1008
  else:
1009
  return False
1010
 
1011
+
1012
+ def log_responses(history, message, response):
1013
+ pass
1014
+
1015
+
1016
+ def safety_check(text, history=None, ) -> Optional[str]:
1017
+ """
1018
+ Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
1019
+ This provides an additional security measure to enhance safety and compliance with local regulations.
1020
+ """
1021
+ if BLOCK_ZH:
1022
+ if history is not None:
1023
+ if block_zh(text, history):
1024
+ return BLOCK_MESSAGE
1025
+ else:
1026
+ if "zh" in _detect_lang(text):
1027
+ return BLOCK_MESSAGE
1028
+
1029
+ if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
1030
+ return KEYWORD_BLOCK_MESSAGE
1031
+
1032
+ return None
1033
+
1034
+
1035
  def chat_response_stream_multiturn(
1036
  message: str,
1037
  history: List[Tuple[str, str]],
 
1061
 
1062
  message = message.strip()
1063
 
1064
+ message_safety = safety_check(message, history=history)
1065
+ if message_safety is not None:
1066
+ yield message_safety
1067
+ return
1068
 
1069
  # history will be appended with message later on
1070
  full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
1071
  message, history, sys_prompt=system_prompt
1072
  )
1073
+
1074
  sampling_params = SamplingParams(
1075
  temperature=temperature, max_tokens=max_tokens,
1076
  frequency_penalty=frequency_penalty,
1077
  )
1078
  cur_out = None
1079
+
1080
  for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
1081
  if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
1082
+ # optionally check safety, and respond
1083
+ if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
1084
+ message_safety = safety_check(cur_out, history=None)
1085
+ if message_safety is not None:
1086
+ yield message_safety
1087
+ return
1088
+
1089
  yield cur_out
1090
  assert len(gen) == 1, f'{gen}'
1091
  item = next(iter(gen.values()))
1092
  cur_out = item.outputs[0].text
1093
 
1094
+ print(f'{full_prompt}<<<{cur_out}>>>\n\n')
1095
  if cur_out is not None:
1096
  yield cur_out
1097
 
1098
+ message_safety = safety_check(cur_out, history=None)
1099
+ if message_safety is not None:
1100
+ yield message_safety
1101
+ return
1102
+
1103
+ if LOG_RESPONSE:
1104
+ log_responses(history, message, cur_out)
1105
+
1106
 
1107
 
1108
  def debug_chat_response_echo(
 
1118
  yield f"repeat: {message}"
1119
 
1120
 
1121
  def check_model_path(model_path) -> str:
1122
  assert os.path.exists(model_path), f'{model_path} not found'
1123
  ckpt_info = "None"
 
1151
  print(
1152
  f'Launch config: {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
1153
  f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
1154
+ f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} '
1155
  f'\n| frequence_penalty={frequence_penalty} '
1156
  f'\n| temperature={temperature} '
1157
  f'\n| hf_model_name={hf_model_name} '
1158
  f'\n| model_path={model_path} '
1159
  f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
1160
+ f'\n| gpu_memory_utilization={gpu_memory_utilization} '
1161
+ f'\n| KEYWORDS={KEYWORDS} '
1162
  f'\nsys={SYSTEM_PROMPT_1}'
1163
  f'\ndesc={model_desc}'
1164
  )
 
1179
  snapshot_download(hf_model_name, local_dir=model_path)
1180
 
1181
  import vllm
1182
+ from vllm import LLM
1183
 
1184
  print(F'VLLM: {vllm.__version__}')
1185
  ckpt_info = check_model_path(model_path)
1186
 
1187
  print(f'Load path: {model_path} | {ckpt_info}')
1188
+
1189
+ if QUANTIZATION == 'awq':
1190
+ print(F'Load model in int4 quantization')
1191
+ llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization, quantization="awq")
1192
+ else:
1193
+ llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization)
1194
+
1195
+ try:
1196
+ print(llm.llm_engine.workers[0].model)
1197
+ except Exception as e:
1198
+ print(f'Cannot print model worker: {e}')
1199
 
1200
  print(f'Use system prompt:\n{sys_prompt}')
1201
 
 
1218
  stop_btn=None,
1219
  title=f"{model_title}",
1220
  description=f"{model_desc}",
 
1221
  additional_inputs=[
1222
  gr.Number(value=temperature, label='Temperature (higher -> more random)'),
1223
  gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
1224
  gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens)'),
1225
+ # ! Remove the system prompt textbox to avoid jailbreaking
1226
  # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
1227
  ],
1228
  )
1229
+ demo.title = MODEL_NAME
1230
  with demo:
1231
+ # gr.Markdown(warning_markdown)
1232
  gr.Markdown(cite_markdown)
1233
  gr.Markdown(path_markdown.format(model_path=model_path))
1234
 
 
1243
 
1244
  if __name__ == "__main__":
1245
  main()