pseudotensor committed on
Commit
0539589
·
1 Parent(s): eeb7ca1

Update with h2oGPT hash 3513278043665f503945eb05d56c1ec1152d1006

Browse files
Files changed (5) hide show
  1. generate.py +31 -15
  2. gpt_langchain.py +40 -8
  3. gradio_runner.py +8 -6
  4. requirements.txt +2 -1
  5. utils.py +0 -1
generate.py CHANGED
@@ -33,7 +33,6 @@ from typing import Union
33
 
34
  import fire
35
  import torch
36
- from peft import PeftModel
37
  from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
38
  from accelerate import init_empty_weights, infer_auto_device_map
39
 
@@ -710,6 +709,7 @@ def get_model(
710
  base_model,
711
  **model_kwargs
712
  )
 
713
  model = PeftModel.from_pretrained(
714
  model,
715
  lora_weights,
@@ -727,6 +727,7 @@ def get_model(
727
  base_model,
728
  **model_kwargs
729
  )
 
730
  model = PeftModel.from_pretrained(
731
  model,
732
  lora_weights,
@@ -827,24 +828,27 @@ no_default_param_names = [
827
  'iinput_nochat',
828
  ]
829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
830
  eval_func_param_names = ['instruction',
831
  'iinput',
832
  'context',
833
  'stream_output',
834
  'prompt_type',
835
- 'prompt_dict',
836
- 'temperature',
837
- 'top_p',
838
- 'top_k',
839
- 'num_beams',
840
- 'max_new_tokens',
841
- 'min_new_tokens',
842
- 'early_stopping',
843
- 'max_time',
844
- 'repetition_penalty',
845
- 'num_return_sequences',
846
- 'do_sample',
847
- 'chat',
848
  'instruction_nochat',
849
  'iinput_nochat',
850
  'langchain_mode',
@@ -900,6 +904,9 @@ def evaluate_from_str(
900
  # only used for submit_nochat_api
901
  user_kwargs['chat'] = False
902
  user_kwargs['stream_output'] = False
 
 
 
903
 
904
  assert set(list(default_kwargs.keys())) == set(eval_func_param_names)
905
  # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
@@ -1083,7 +1090,6 @@ def evaluate(
1083
  db=db1,
1084
  user_path=user_path,
1085
  detect_user_path_changes_every_query=detect_user_path_changes_every_query,
1086
- max_new_tokens=max_new_tokens,
1087
  cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
1088
  use_openai_embedding=use_openai_embedding,
1089
  use_openai_model=use_openai_model,
@@ -1096,10 +1102,20 @@ def evaluate(
1096
  document_choice=document_choice,
1097
  db_type=db_type,
1098
  top_k_docs=top_k_docs,
 
 
 
1099
  temperature=temperature,
1100
  repetition_penalty=repetition_penalty,
1101
  top_k=top_k,
1102
  top_p=top_p,
 
 
 
 
 
 
 
1103
  prompt_type=prompt_type,
1104
  prompt_dict=prompt_dict,
1105
  n_jobs=n_jobs,
 
33
 
34
  import fire
35
  import torch
 
36
  from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
37
  from accelerate import init_empty_weights, infer_auto_device_map
38
 
 
709
  base_model,
710
  **model_kwargs
711
  )
712
+ from peft import PeftModel # loads cuda, so avoid in global scope
713
  model = PeftModel.from_pretrained(
714
  model,
715
  lora_weights,
 
727
  base_model,
728
  **model_kwargs
729
  )
730
+ from peft import PeftModel # loads cuda, so avoid in global scope
731
  model = PeftModel.from_pretrained(
732
  model,
733
  lora_weights,
 
828
  'iinput_nochat',
829
  ]
830
 
831
+ gen_hyper = ['temperature',
832
+ 'top_p',
833
+ 'top_k',
834
+ 'num_beams',
835
+ 'max_new_tokens',
836
+ 'min_new_tokens',
837
+ 'early_stopping',
838
+ 'max_time',
839
+ 'repetition_penalty',
840
+ 'num_return_sequences',
841
+ 'do_sample',
842
+ ]
843
+
844
  eval_func_param_names = ['instruction',
845
  'iinput',
846
  'context',
847
  'stream_output',
848
  'prompt_type',
849
+ 'prompt_dict'] + \
850
+ gen_hyper + \
851
+ ['chat',
 
 
 
 
 
 
 
 
 
 
852
  'instruction_nochat',
853
  'iinput_nochat',
854
  'langchain_mode',
 
904
  # only used for submit_nochat_api
905
  user_kwargs['chat'] = False
906
  user_kwargs['stream_output'] = False
907
+ if 'langchain_mode' not in user_kwargs:
908
+ # if user doesn't specify, then assume disabled, not use default
909
+ user_kwargs['langchain_mode'] = 'Disabled'
910
 
911
  assert set(list(default_kwargs.keys())) == set(eval_func_param_names)
912
  # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
 
1090
  db=db1,
1091
  user_path=user_path,
1092
  detect_user_path_changes_every_query=detect_user_path_changes_every_query,
 
1093
  cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
1094
  use_openai_embedding=use_openai_embedding,
1095
  use_openai_model=use_openai_model,
 
1102
  document_choice=document_choice,
1103
  db_type=db_type,
1104
  top_k_docs=top_k_docs,
1105
+
1106
+ # gen_hyper:
1107
+ do_sample=do_sample,
1108
  temperature=temperature,
1109
  repetition_penalty=repetition_penalty,
1110
  top_k=top_k,
1111
  top_p=top_p,
1112
+ num_beams=num_beams,
1113
+ min_new_tokens=min_new_tokens,
1114
+ max_new_tokens=max_new_tokens,
1115
+ early_stopping=early_stopping,
1116
+ max_time=max_time,
1117
+ num_return_sequences=num_return_sequences,
1118
+
1119
  prompt_type=prompt_type,
1120
  prompt_dict=prompt_dict,
1121
  n_jobs=n_jobs,
gpt_langchain.py CHANGED
@@ -22,6 +22,7 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
22
  from tqdm import tqdm
23
 
24
  from enums import DocumentChoices
 
25
  from prompter import non_hf_types, PromptType
26
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
27
  get_device, ProgressParallel, remove, hash_file, clear_torch_cache
@@ -261,11 +262,17 @@ def get_answer_from_sources(chain, sources, question):
261
 
262
  def get_llm(use_openai_model=False, model_name=None, model=None,
263
  tokenizer=None, stream_output=False,
264
- max_new_tokens=256,
265
  temperature=0.1,
266
- repetition_penalty=1.0,
267
  top_k=40,
268
  top_p=0.7,
 
 
 
 
 
 
 
269
  prompt_type=None,
270
  prompt_dict=None,
271
  prompter=None,
@@ -312,10 +319,20 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
312
  load_in_8bit=load_8bit)
313
 
314
  max_max_tokens = tokenizer.model_max_length
315
- gen_kwargs = dict(max_new_tokens=max_new_tokens,
 
 
 
 
 
 
 
 
 
 
316
  return_full_text=True,
317
- early_stopping=False,
318
  handle_long_generation='hole')
 
319
 
320
  if stream_output:
321
  skip_prompt = False
@@ -1235,11 +1252,17 @@ def _run_qa_db(query=None,
1235
  show_rank=False,
1236
  load_db_if_exists=False,
1237
  db=None,
1238
- max_new_tokens=256,
1239
  temperature=0.1,
1240
- repetition_penalty=1.0,
1241
  top_k=40,
1242
  top_p=0.7,
 
 
 
 
 
 
 
1243
  langchain_mode=None,
1244
  document_choice=[DocumentChoices.All_Relevant.name],
1245
  n_jobs=-1,
@@ -1274,14 +1297,21 @@ def _run_qa_db(query=None,
1274
  assert prompt_dict is not None # should at least be {} or ''
1275
  else:
1276
  prompt_dict = ''
 
1277
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1278
  model=model, tokenizer=tokenizer,
1279
  stream_output=stream_output,
1280
- max_new_tokens=max_new_tokens,
1281
  temperature=temperature,
1282
- repetition_penalty=repetition_penalty,
1283
  top_k=top_k,
1284
  top_p=top_p,
 
 
 
 
 
 
 
1285
  prompt_type=prompt_type,
1286
  prompt_dict=prompt_dict,
1287
  prompter=prompter,
@@ -1609,6 +1639,7 @@ def get_some_dbs_from_hf(dest='.', db_zips=None):
1609
  assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected
1610
  assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
1611
 
 
1612
  def _create_local_weaviate_client():
1613
  WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080")
1614
  WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME')
@@ -1629,5 +1660,6 @@ def _create_local_weaviate_client():
1629
  print(f"Failed to create Weaviate client: {e}")
1630
  return None
1631
 
 
1632
  if __name__ == '__main__':
1633
  pass
 
22
  from tqdm import tqdm
23
 
24
  from enums import DocumentChoices
25
+ from generate import gen_hyper
26
  from prompter import non_hf_types, PromptType
27
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
28
  get_device, ProgressParallel, remove, hash_file, clear_torch_cache
 
262
 
263
  def get_llm(use_openai_model=False, model_name=None, model=None,
264
  tokenizer=None, stream_output=False,
265
+ do_sample=False,
266
  temperature=0.1,
 
267
  top_k=40,
268
  top_p=0.7,
269
+ num_beams=1,
270
+ max_new_tokens=256,
271
+ min_new_tokens=1,
272
+ early_stopping=False,
273
+ max_time=180,
274
+ repetition_penalty=1.0,
275
+ num_return_sequences=1,
276
  prompt_type=None,
277
  prompt_dict=None,
278
  prompter=None,
 
319
  load_in_8bit=load_8bit)
320
 
321
  max_max_tokens = tokenizer.model_max_length
322
+ gen_kwargs = dict(do_sample=do_sample,
323
+ temperature=temperature,
324
+ top_k=top_k,
325
+ top_p=top_p,
326
+ num_beams=num_beams,
327
+ max_new_tokens=max_new_tokens,
328
+ min_new_tokens=min_new_tokens,
329
+ early_stopping=early_stopping,
330
+ max_time=max_time,
331
+ repetition_penalty=repetition_penalty,
332
+ num_return_sequences=num_return_sequences,
333
  return_full_text=True,
 
334
  handle_long_generation='hole')
335
+ assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
336
 
337
  if stream_output:
338
  skip_prompt = False
 
1252
  show_rank=False,
1253
  load_db_if_exists=False,
1254
  db=None,
1255
+ do_sample=False,
1256
  temperature=0.1,
 
1257
  top_k=40,
1258
  top_p=0.7,
1259
+ num_beams=1,
1260
+ max_new_tokens=256,
1261
+ min_new_tokens=1,
1262
+ early_stopping=False,
1263
+ max_time=180,
1264
+ repetition_penalty=1.0,
1265
+ num_return_sequences=1,
1266
  langchain_mode=None,
1267
  document_choice=[DocumentChoices.All_Relevant.name],
1268
  n_jobs=-1,
 
1297
  assert prompt_dict is not None # should at least be {} or ''
1298
  else:
1299
  prompt_dict = ''
1300
+ assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
1301
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1302
  model=model, tokenizer=tokenizer,
1303
  stream_output=stream_output,
1304
+ do_sample=do_sample,
1305
  temperature=temperature,
 
1306
  top_k=top_k,
1307
  top_p=top_p,
1308
+ num_beams=num_beams,
1309
+ max_new_tokens=max_new_tokens,
1310
+ min_new_tokens=min_new_tokens,
1311
+ early_stopping=early_stopping,
1312
+ max_time=max_time,
1313
+ repetition_penalty=repetition_penalty,
1314
+ num_return_sequences=num_return_sequences,
1315
  prompt_type=prompt_type,
1316
  prompt_dict=prompt_dict,
1317
  prompter=prompter,
 
1639
  assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected
1640
  assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
1641
 
1642
+
1643
  def _create_local_weaviate_client():
1644
  WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080")
1645
  WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME')
 
1660
  print(f"Failed to create Weaviate client: {e}")
1661
  return None
1662
 
1663
+
1664
  if __name__ == '__main__':
1665
  pass
gradio_runner.py CHANGED
@@ -649,7 +649,7 @@ def go_gradio(**kwargs):
649
  inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
650
  chunk, chunk_size],
651
  outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
652
- api_name='add_to_shared' if allow_api else None) \
653
  .then(clear_file_list, outputs=fileup_output, queue=queue) \
654
  .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False)
655
 
@@ -664,7 +664,7 @@ def go_gradio(**kwargs):
664
  inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
665
  chunk, chunk_size],
666
  outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
667
- api_name='add_url_to_shared' if allow_api else None) \
668
  .then(clear_textbox, outputs=url_text, queue=queue) \
669
  .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False)
670
 
@@ -673,7 +673,7 @@ def go_gradio(**kwargs):
673
  inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
674
  chunk, chunk_size],
675
  outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
676
- api_name='add_text_to_shared' if allow_api else None) \
677
  .then(clear_textbox, outputs=user_text_text, queue=queue) \
678
  .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False)
679
 
@@ -695,7 +695,7 @@ def go_gradio(**kwargs):
695
  inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
696
  chunk, chunk_size],
697
  outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
698
- api_name='add_to_my' if allow_api else None) \
699
  .then(clear_file_list, outputs=fileup_output, queue=queue) \
700
  .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False)
701
  # .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue)
@@ -706,7 +706,7 @@ def go_gradio(**kwargs):
706
  inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
707
  chunk, chunk_size],
708
  outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
709
- api_name='add_url_to_my' if allow_api else None) \
710
  .then(clear_textbox, outputs=url_text, queue=queue) \
711
  .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False)
712
 
@@ -715,7 +715,7 @@ def go_gradio(**kwargs):
715
  inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
716
  chunk, chunk_size],
717
  outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
718
- api_name='add_txt_to_my' if allow_api else None) \
719
  .then(clear_textbox, outputs=user_text_text, queue=queue) \
720
  .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False)
721
 
@@ -1788,6 +1788,8 @@ def get_db(db1, langchain_mode, dbs=None):
1788
 
1789
  def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
1790
  db = get_db(db1, langchain_mode, dbs=dbs)
 
 
1791
  return get_source_files(db=db, exceptions=None)
1792
 
1793
 
 
649
  inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
650
  chunk, chunk_size],
651
  outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
652
+ api_name='add_to_shared' if allow_api and allow_upload_to_user_data else None) \
653
  .then(clear_file_list, outputs=fileup_output, queue=queue) \
654
  .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False)
655
 
 
664
  inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
665
  chunk, chunk_size],
666
  outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
667
+ api_name='add_url_to_shared' if allow_api and allow_upload_to_user_data else None) \
668
  .then(clear_textbox, outputs=url_text, queue=queue) \
669
  .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False)
670
 
 
673
  inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
674
  chunk, chunk_size],
675
  outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
676
+ api_name='add_text_to_shared' if allow_api and allow_upload_to_user_data else None) \
677
  .then(clear_textbox, outputs=user_text_text, queue=queue) \
678
  .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False)
679
 
 
695
  inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
696
  chunk, chunk_size],
697
  outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
698
+ api_name='add_to_my' if allow_api and allow_upload_to_my_data else None) \
699
  .then(clear_file_list, outputs=fileup_output, queue=queue) \
700
  .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False)
701
  # .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue)
 
706
  inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
707
  chunk, chunk_size],
708
  outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
709
+ api_name='add_url_to_my' if allow_api and allow_upload_to_my_data else None) \
710
  .then(clear_textbox, outputs=url_text, queue=queue) \
711
  .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False)
712
 
 
715
  inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
716
  chunk, chunk_size],
717
  outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
718
+ api_name='add_txt_to_my' if allow_api and allow_upload_to_my_data else None) \
719
  .then(clear_textbox, outputs=user_text_text, queue=queue) \
720
  .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False)
721
 
 
1788
 
1789
  def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
1790
  db = get_db(db1, langchain_mode, dbs=dbs)
1791
+ if langchain_mode in ['ChatLLM', 'LLM'] or db is None:
1792
+ return "Sources: N/A"
1793
  return get_source_files(db=db, exceptions=None)
1794
 
1795
 
requirements.txt CHANGED
@@ -56,7 +56,8 @@ einops==0.6.1
56
  instructorembedding==1.0.1
57
 
58
  # for gpt4all .env file, but avoid worrying about imports
59
- python-dotenv==1.0.0# optional for chat with PDF
 
60
  langchain==0.0.193
61
  pypdf==3.8.1
62
  tiktoken==0.3.3
 
56
  instructorembedding==1.0.1
57
 
58
  # for gpt4all .env file, but avoid worrying about imports
59
+ python-dotenv==1.0.0
60
+ # optional for chat with PDF
61
  langchain==0.0.193
62
  pypdf==3.8.1
63
  tiktoken==0.3.3
utils.py CHANGED
@@ -14,7 +14,6 @@ import time
14
  import traceback
15
  import zipfile
16
  from datetime import datetime
17
- from enum import Enum
18
 
19
  import filelock
20
  import requests, uuid
 
14
  import traceback
15
  import zipfile
16
  from datetime import datetime
 
17
 
18
  import filelock
19
  import requests, uuid