arnocandel commited on
Commit
24b4b28
1 Parent(s): b43c18e

Update with h2oGPT hash e35e6ce906c57495ee80b1e3b8507ad374f6a50d

Browse files
Files changed (4) hide show
  1. finetune.py +20 -5
  2. generate.py +51 -6
  3. gradio_runner.py +3 -2
  4. requirements.txt +3 -3
finetune.py CHANGED
@@ -30,6 +30,7 @@ class PromptType(Enum):
30
  human_bot_orig = 9
31
  prompt_answer = 10
32
  open_assistant = 11
 
33
 
34
 
35
  prompt_type_to_model_name = {
@@ -56,6 +57,8 @@ prompt_type_to_model_name = {
56
  'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
57
  'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
58
  'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
 
 
59
  ],
60
  'instruct': [],
61
  'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -63,15 +66,18 @@ prompt_type_to_model_name = {
63
  'human_bot': [
64
  'h2oai/h2ogpt-oasst1-512-12b',
65
  'h2oai/h2ogpt-oasst1-512-20b',
 
 
66
  'h2oai/h2ogpt-oig-oasst1-512-6.9b',
67
  'h2oai/h2ogpt-research-oasst1-512-30b', # private
68
  ],
69
  'dai_faq': [],
70
  'summarize': [],
71
  'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
72
- 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
73
  'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
74
  "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
 
75
  }
76
 
77
  inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
@@ -222,8 +228,6 @@ def train(
222
  NOTE: for current pytorch 2.0, flash attention requires installing cuda 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local and then when running, to avoid installing driver, docs, samples, just install toolkit. Then when pip installing flash attention do:
223
 
224
  CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
225
- from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
226
- replace_llama_attn_with_flash_attn()
227
  assert (
228
  base_model
229
  ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
@@ -590,8 +594,8 @@ def train(
590
  tokenizer=tokenizer,
591
  train_dataset=train_data,
592
  eval_dataset=valid_data,
593
- # NOTE: CausalLM is not supporting Seq2SeqTrainingArguments arguments, but not incompatible
594
- args=transformers.Seq2SeqTrainingArguments(
595
  per_device_train_batch_size=micro_batch_size,
596
  per_device_eval_batch_size=1,
597
  eval_accumulation_steps=10,
@@ -901,6 +905,17 @@ Current Time: {}
901
  eos = "</s>"
902
  terminate_response = [start, PreResponse, pend, eos]
903
  chat_sep = eos
 
 
 
 
 
 
 
 
 
 
 
904
  else:
905
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
906
 
 
30
  human_bot_orig = 9
31
  prompt_answer = 10
32
  open_assistant = 11
33
+ wizard_lm = 12
34
 
35
 
36
  prompt_type_to_model_name = {
 
57
  'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
58
  'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
59
  'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
60
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
61
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
62
  ],
63
  'instruct': [],
64
  'instruct_with_end': ['databricks/dolly-v2-12b'],
 
66
  'human_bot': [
67
  'h2oai/h2ogpt-oasst1-512-12b',
68
  'h2oai/h2ogpt-oasst1-512-20b',
69
+ 'h2oai/h2ogpt-oig-oasst1-512-20b',
70
+ 'h2oai/h2ogpt-oig-oasst1-512-12b',
71
  'h2oai/h2ogpt-oig-oasst1-512-6.9b',
72
  'h2oai/h2ogpt-research-oasst1-512-30b', # private
73
  ],
74
  'dai_faq': [],
75
  'summarize': [],
76
  'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
77
+ 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
78
  'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
79
  "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
80
+ "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
81
  }
82
 
83
  inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
 
228
  NOTE: for current pytorch 2.0, flash attention requires installing cuda 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local and then when running, to avoid installing driver, docs, samples, just install toolkit. Then when pip installing flash attention do:
229
 
230
  CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
 
 
231
  assert (
232
  base_model
233
  ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
 
594
  tokenizer=tokenizer,
595
  train_dataset=train_data,
596
  eval_dataset=valid_data,
597
+ # FIXME: might need Seq2SeqTrainingArguments for some models
598
+ args=transformers.TrainingArguments(
599
  per_device_train_batch_size=micro_batch_size,
600
  per_device_eval_batch_size=1,
601
  eval_accumulation_steps=10,
 
905
  eos = "</s>"
906
  terminate_response = [start, PreResponse, pend, eos]
907
  chat_sep = eos
908
+ elif prompt_type in [12, "12", "wizard_lm"]:
909
+ # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
910
+ preprompt = ''
911
+ start = ''
912
+ promptB = promptA = '%s%s' % (preprompt, start)
913
+ PreInstruct = ""
914
+ PreInput = None
915
+ PreResponse = "\n\n### Response"
916
+ eos = "</s>"
917
+ terminate_response = [PreResponse, eos]
918
+ chat_sep = eos
919
  else:
920
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
921
 
generate.py CHANGED
@@ -84,6 +84,7 @@ def main(
84
  api_open: bool = False,
85
  allow_api: bool = True,
86
  input_lines: int = 1,
 
87
 
88
  sanitize_user_prompt: bool = True,
89
  sanitize_bot_response: bool = True,
@@ -145,6 +146,8 @@ def main(
145
  :param api_open: If False, don't let API calls skip gradio queue
146
  :param allow_api: whether to allow API calls at all to gradio server
147
  :param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
 
 
148
  :param sanitize_user_prompt: whether to remove profanity from user input
149
  :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output
150
  :param extra_model_options: extra models to show in list in gradio
@@ -211,7 +214,7 @@ def main(
211
  if psutil.virtual_memory().available < 94*1024**3:
212
  # 12B uses ~94GB
213
  # 6.9B uses ~47GB
214
- base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
215
 
216
  # get defaults
217
  model_lower = base_model.lower()
@@ -881,13 +884,17 @@ def evaluate(
881
  else:
882
  gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
883
 
 
 
 
884
  decoder = functools.partial(tokenizer.decode,
885
- skip_special_tokens=True,
886
- clean_up_tokenization_spaces=True,
887
  )
 
 
 
888
  decoder_raw = functools.partial(tokenizer.decode,
889
- skip_special_tokens=False,
890
- clean_up_tokenization_spaces=True,
891
  )
892
 
893
  with torch.no_grad():
@@ -915,14 +922,16 @@ def evaluate(
915
  # some models specify special tokens that are part of normal prompt, so can't skip them
916
  inputs_decoded = prompt = inputs_decoded_raw
917
  decoder = decoder_raw
 
918
  elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ', '') == prompt.replace('\n', ' ').replace(' ', ''):
919
  inputs_decoded = prompt = inputs_decoded_raw
920
  decoder = decoder_raw
 
921
  else:
922
  print("WARNING: Special characters in prompt", flush=True)
923
  if stream_output:
924
  skip_prompt = False
925
- streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False)
926
  gen_kwargs.update(dict(streamer=streamer))
927
  target_func = generate_with_exceptions
928
  target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
@@ -1312,3 +1321,39 @@ if __name__ == "__main__":
1312
  python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
1313
  """
1314
  fire.Fire(main)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  api_open: bool = False,
85
  allow_api: bool = True,
86
  input_lines: int = 1,
87
+ auth: typing.List[typing.Tuple[str, str]] = None,
88
 
89
  sanitize_user_prompt: bool = True,
90
  sanitize_bot_response: bool = True,
 
146
  :param api_open: If False, don't let API calls skip gradio queue
147
  :param allow_api: whether to allow API calls at all to gradio server
148
  :param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
149
+ :param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...]
150
+ e.g. --auth=[('jon','password')] with no spaces
151
  :param sanitize_user_prompt: whether to remove profanity from user input
152
  :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output
153
  :param extra_model_options: extra models to show in list in gradio
 
214
  if psutil.virtual_memory().available < 94*1024**3:
215
  # 12B uses ~94GB
216
  # 6.9B uses ~47GB
217
+ base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b' if not base_model else base_model
218
 
219
  # get defaults
220
  model_lower = base_model.lower()
 
884
  else:
885
  gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
886
 
887
+ decoder_kwargs = dict(skip_special_tokens=True,
888
+ clean_up_tokenization_spaces=True)
889
+
890
  decoder = functools.partial(tokenizer.decode,
891
+ **decoder_kwargs
 
892
  )
893
+ decoder_raw_kwargs = dict(skip_special_tokens=False,
894
+ clean_up_tokenization_spaces=True)
895
+
896
  decoder_raw = functools.partial(tokenizer.decode,
897
+ **decoder_raw_kwargs
 
898
  )
899
 
900
  with torch.no_grad():
 
922
  # some models specify special tokens that are part of normal prompt, so can't skip them
923
  inputs_decoded = prompt = inputs_decoded_raw
924
  decoder = decoder_raw
925
+ decoder_kwargs = decoder_raw_kwargs
926
  elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ', '') == prompt.replace('\n', ' ').replace(' ', ''):
927
  inputs_decoded = prompt = inputs_decoded_raw
928
  decoder = decoder_raw
929
+ decoder_kwargs = decoder_raw_kwargs
930
  else:
931
  print("WARNING: Special characters in prompt", flush=True)
932
  if stream_output:
933
  skip_prompt = False
934
+ streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
935
  gen_kwargs.update(dict(streamer=streamer))
936
  target_func = generate_with_exceptions
937
  target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
 
1321
  python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
1322
  """
1323
  fire.Fire(main)
1324
+
1325
+
1326
+ import pytest
1327
+
1328
+ @pytest.mark.parametrize(
1329
+ "base_model",
1330
+ [
1331
+ "h2oai/h2ogpt-oig-oasst1-512-6.9b",
1332
+ "h2oai/h2ogpt-oig-oasst1-512-12b",
1333
+ "h2oai/h2ogpt-oig-oasst1-512-20b",
1334
+ "h2oai/h2ogpt-oasst1-512-12b",
1335
+ "h2oai/h2ogpt-oasst1-512-20b",
1336
+ "h2oai/h2ogpt-gm-oasst1-en-1024-20b",
1337
+ "databricks/dolly-v2-12b",
1338
+ "h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2",
1339
+ "ehartford/WizardLM-7B-Uncensored",
1340
+ "ehartford/WizardLM-13B-Uncensored",
1341
+ "AlekseyKorshuk/vicuna-7b",
1342
+ "TheBloke/stable-vicuna-13B-HF",
1343
+ "decapoda-research/llama-7b-hf",
1344
+ "decapoda-research/llama-13b-hf",
1345
+ "decapoda-research/llama-30b-hf",
1346
+ "junelee/wizard-vicuna-13b",
1347
+ ]
1348
+ )
1349
+ def test_score_eval(base_model):
1350
+ main(
1351
+ base_model=base_model,
1352
+ chat=False,
1353
+ stream_output=False,
1354
+ gradio=False,
1355
+ eval_sharegpt_prompts_only=500,
1356
+ eval_sharegpt_as_output=False,
1357
+ num_beams=2,
1358
+ infer_devices=False,
1359
+ )
gradio_runner.py CHANGED
@@ -50,7 +50,7 @@ def go_gradio(**kwargs):
50
  """
51
  else:
52
  description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio)<br>"
53
- description += "If this host is busy, try [gpt.h2o.ai 20B](https://gpt.h2o.ai) and [30B](http://gpu.hopto.org) and [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) and [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
54
  description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
55
 
56
  if kwargs['verbose']:
@@ -921,7 +921,8 @@ def go_gradio(**kwargs):
921
  scheduler.start()
922
 
923
  demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
924
- favicon_path=favicon_path, prevent_thread_lock=True) # , enable_queue=True)
 
925
  print("Started GUI", flush=True)
926
  if kwargs['block_gradio_exit']:
927
  demo.block_thread()
 
50
  """
51
  else:
52
  description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio)<br>"
53
+ description += "If this host is busy, try [12B](https://gpt.h2o.ai), [30B](http://gpt2.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
54
  description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
55
 
56
  if kwargs['verbose']:
 
921
  scheduler.start()
922
 
923
  demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
924
+ favicon_path=favicon_path, prevent_thread_lock=True,
925
+ auth=kwargs['auth'])
926
  print("Started GUI", flush=True)
927
  if kwargs['block_gradio_exit']:
928
  demo.block_thread()
requirements.txt CHANGED
@@ -1,13 +1,13 @@
1
  # for generate (gradio server) and finetune
2
- datasets==2.11.0
3
  sentencepiece==0.1.97
4
  accelerate==0.18.0
5
  gradio==3.27.0
6
- huggingface_hub==0.13.4
7
  appdirs==1.4.4
8
  fire==0.5.0
9
  docutils==0.19
10
- torch==2.0.0
11
  evaluate==0.4.0
12
  rouge_score==0.1.2
13
  sacrebleu==2.3.1
 
1
  # for generate (gradio server) and finetune
2
+ datasets==2.12.0
3
  sentencepiece==0.1.97
4
  accelerate==0.18.0
5
  gradio==3.27.0
6
+ huggingface_hub==0.14.1
7
  appdirs==1.4.4
8
  fire==0.5.0
9
  docutils==0.19
10
+ torch==2.0.1
11
  evaluate==0.4.0
12
  rouge_score==0.1.2
13
  sacrebleu==2.3.1