pseudotensor commited on
Commit
0f993c6
·
1 Parent(s): 08ab504

Update with h2oGPT hash 03227623260f552fd7e2b8c51409308bc7242933

Browse files
Files changed (13) hide show
  1. client_test.py +42 -19
  2. create_data.py +60 -69
  3. finetune.py +7 -11
  4. generate.py +236 -242
  5. gpt4all_llm.py +162 -26
  6. gpt_langchain.py +561 -183
  7. gradio_runner.py +252 -110
  8. gradio_themes.py +41 -2
  9. h2oai_pipeline.py +96 -22
  10. prompter.py +119 -22
  11. requirements.txt +12 -11
  12. stopping.py +6 -4
  13. utils.py +83 -8
client_test.py CHANGED
@@ -23,7 +23,7 @@ HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py
23
  Result:
24
 
25
  Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
26
- {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.'}
27
 
28
 
29
  For demo:
@@ -33,9 +33,15 @@ HOST="https://gpt.h2o.ai" python client_test.py
33
  Result:
34
 
35
  Loaded as API: https://gpt.h2o.ai ✔
36
- {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.'}
 
 
 
 
 
37
 
38
  """
 
39
  import time
40
  import os
41
  import markdown # pip install markdown
@@ -56,7 +62,7 @@ def get_client(serialize=True):
56
  return client
57
 
58
 
59
- def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_tokens=50):
60
  from collections import OrderedDict
61
  kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
62
  iinput='', # only for chat=True
@@ -79,12 +85,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
79
  chat=chat,
80
  instruction_nochat=prompt if not chat else '',
81
  iinput_nochat='', # only for chat=False
82
- langchain_mode='Disabled',
 
83
  document_choice=['All'],
84
  )
85
  if chat:
86
  # add chatbot output on end. Assumes serialize=False
87
- kwargs.update(dict(chatbot=[['', None]]))
88
 
89
  return kwargs, list(kwargs.values())
90
 
@@ -103,22 +110,29 @@ def run_client_nochat(prompt, prompt_type, max_new_tokens):
103
  *tuple(args),
104
  api_name=api_name,
105
  )
 
106
  res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
107
- response=md_to_text(res))
 
108
  print(res_dict)
109
  return res_dict
110
 
111
 
112
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
113
  def test_client_chat():
114
- return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50)
115
-
116
 
117
- def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens):
118
- kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output, max_new_tokens=max_new_tokens)
119
 
 
120
  client = get_client(serialize=False)
121
 
 
 
 
 
 
 
122
  res = client.predict(*tuple(args), api_name='/instruction')
123
  args[-1] += [res[-1]]
124
 
@@ -127,8 +141,8 @@ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens):
127
  if not kwargs['stream_output']:
128
  res = client.predict(*tuple(args), api_name='/instruction_bot')
129
  res_dict['response'] = res[0][-1][1]
130
- print(md_to_text(res_dict['response']))
131
- return res_dict
132
  else:
133
  job = client.submit(*tuple(args), api_name='/instruction_bot')
134
  res1 = ''
@@ -137,15 +151,24 @@ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens):
137
  if outputs_list:
138
  res = job.communicator.job.outputs[-1]
139
  res1 = res[0][-1][-1]
140
- res1 = md_to_text(res1)
141
  print(res1)
142
  time.sleep(0.1)
143
- print(job.outputs())
144
- res_dict['response'] = res1
145
- return res_dict
146
-
147
-
148
- def md_to_text(md):
 
 
 
 
 
 
 
 
 
149
  assert md is not None, "Markdown is None"
150
  html = markdown.markdown(md)
151
  soup = BeautifulSoup(html, features='html.parser')
 
23
  Result:
24
 
25
  Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
26
+ {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''}
27
 
28
 
29
  For demo:
 
33
  Result:
34
 
35
  Loaded as API: https://gpt.h2o.ai ✔
36
+ {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''}
37
+
38
+ NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict:
39
+
40
+ {'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''}
41
+
42
 
43
  """
44
+ import ast
45
  import time
46
  import os
47
  import markdown # pip install markdown
 
62
  return client
63
 
64
 
65
+ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_tokens=50, langchain_mode='Disabled'):
66
  from collections import OrderedDict
67
  kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
68
  iinput='', # only for chat=True
 
85
  chat=chat,
86
  instruction_nochat=prompt if not chat else '',
87
  iinput_nochat='', # only for chat=False
88
+ langchain_mode=langchain_mode,
89
+ top_k_docs=4,
90
  document_choice=['All'],
91
  )
92
  if chat:
93
  # add chatbot output on end. Assumes serialize=False
94
+ kwargs.update(dict(chatbot=[]))
95
 
96
  return kwargs, list(kwargs.values())
97
 
 
110
  *tuple(args),
111
  api_name=api_name,
112
  )
113
+ print("Raw client result: %s" % res, flush=True)
114
  res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
115
+ response=md_to_text(ast.literal_eval(res)['response']),
116
+ sources=ast.literal_eval(res)['sources'])
117
  print(res_dict)
118
  return res_dict
119
 
120
 
121
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
122
  def test_client_chat():
123
+ return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50,
124
+ langchain_mode='Disabled')
125
 
 
 
126
 
127
+ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
128
  client = get_client(serialize=False)
129
 
130
+ kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
131
+ max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
132
+ return run_client(client, prompt, args, kwargs)
133
+
134
+
135
+ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
136
  res = client.predict(*tuple(args), api_name='/instruction')
137
  args[-1] += [res[-1]]
138
 
 
141
  if not kwargs['stream_output']:
142
  res = client.predict(*tuple(args), api_name='/instruction_bot')
143
  res_dict['response'] = res[0][-1][1]
144
+ print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
145
+ return res_dict, client
146
  else:
147
  job = client.submit(*tuple(args), api_name='/instruction_bot')
148
  res1 = ''
 
151
  if outputs_list:
152
  res = job.communicator.job.outputs[-1]
153
  res1 = res[0][-1][-1]
154
+ res1 = md_to_text(res1, do_md_to_text=do_md_to_text)
155
  print(res1)
156
  time.sleep(0.1)
157
+ full_outputs = job.outputs()
158
+ if verbose:
159
+ print('job.outputs: %s' % str(full_outputs))
160
+ # ensure get ending to avoid race
161
+ # -1 means last response if streaming
162
+ # 0 means get text_output, ignore exception_text
163
+ # 0 means get list within text_output that looks like [[prompt], [answer]]
164
+ # 1 means get bot answer, so will have last bot answer
165
+ res_dict['response'] = md_to_text(full_outputs[-1][0][0][1], do_md_to_text=do_md_to_text)
166
+ return res_dict, client
167
+
168
+
169
+ def md_to_text(md, do_md_to_text=True):
170
+ if not do_md_to_text:
171
+ return md
172
  assert md is not None, "Markdown is None"
173
  html = markdown.markdown(md)
174
  soup = BeautifulSoup(html, features='html.parser')
create_data.py CHANGED
@@ -23,7 +23,7 @@ import pandas as pd
23
  import numpy as np
24
  from tqdm import tqdm
25
 
26
- from utils import flatten_list
27
 
28
 
29
  def parse_rst_file(filepath):
@@ -184,7 +184,7 @@ def setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False):
184
  return dst
185
 
186
 
187
- def rst_to_outputs(files, min_len=30, max_len=2048//2 - 30):
188
  # account for sequence length (context window) including prompt and input and output
189
 
190
  # os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst')
@@ -274,22 +274,6 @@ def test_scrape_dai_docs_all_pandoc():
274
  f.write(json.dumps(save_thing, indent=2))
275
 
276
 
277
- def remove(path: str):
278
- try:
279
- if path is not None and os.path.exists(path):
280
- if os.path.isdir(path):
281
- shutil_rmtree(path, ignore_errors=True)
282
- else:
283
- with contextlib.suppress(FileNotFoundError):
284
- os.remove(path)
285
- except:
286
- pass
287
-
288
-
289
- def shutil_rmtree(*args, **kwargs):
290
- return shutil.rmtree(*args, **kwargs)
291
-
292
-
293
  def test_config_to_json():
294
  """
295
  Needs to run from Driverless AI source directory.
@@ -310,15 +294,18 @@ def test_config_to_json():
310
  [
311
  {
312
  'prompt_type': 'plain',
313
- 'instruction': f"<human>: What does {k} do?\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace("\n", ""),
 
314
  },
315
  {
316
  'prompt_type': 'plain',
317
- 'instruction': f"<human>: Explain {k}.\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace("\n", ""),
 
318
  },
319
  {
320
  'prompt_type': 'plain',
321
- 'instruction': f"<human>: How can I do this: {title}.\n<bot>: Set the {k.replace('_', ' ')} config.toml\n<human>:".replace("\n", ""),
 
322
  } if title and comment else None,
323
  {
324
  'prompt_type': 'human_bot',
@@ -420,7 +407,8 @@ def test_prep_instruct_vicuna():
420
  from datasets import load_dataset
421
  filename = 'ShareGPT_unfiltered_cleaned_split.json'
422
  if not os.path.exists(filename):
423
- os.system('wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
 
424
  data = load_dataset("json", data_files={"train": filename})["train"]
425
  training_rows = []
426
  for i in range(data.num_rows):
@@ -440,6 +428,7 @@ def test_prep_instruct_vicuna():
440
  with open(filename + ".generate_human_bot.train_plain.json", "wt") as f:
441
  f.write(json.dumps(training_rows, indent=2))
442
 
 
443
  POSTFIX = ".generate_human_bot.train_plain.json"
444
 
445
  # https://bair.berkeley.edu/blog/2023/04/03/koala/
@@ -497,10 +486,10 @@ useful_oig_files = ['unified_rallio_safety_and_prosocial.jsonl.parquet',
497
  'unified_mathqa_flanv2_kojma_cot.jsonl.parquet',
498
  'unified_merged_code_xp3.jsonl.parquet',
499
  'unified_multi_news.jsonl.parquet',
500
- #'unified_multi_sum.jsonl.parquet'
501
  'unified_ni.jsonl.gz.parquet',
502
  'unified_openai_summarize_tldr.jsonl.parquet',
503
- #'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
504
  'unified_plot_screenplay_books_dialog.jsonl.parquet',
505
  'unified_soda_dialog.jsonl.parquet',
506
  'unified_unnatural_instructions.jsonl.parquet',
@@ -546,8 +535,8 @@ def test_merge_shuffle_small_sample_oig_data():
546
 
547
  def test_join_jsons():
548
  files = ['config.json'] * 1 + \
549
- ['dai_docs.train_cleaned.json'] * 2 + \
550
- ['dai_faq.json'] * 3
551
  print(files)
552
  lst = []
553
  [lst.extend(json.load(open(fil, 'rt'))) for fil in files]
@@ -570,11 +559,10 @@ def test_make_rlhf_good_data(filename):
570
  f.write(json.dumps(new_rows, indent=2))
571
 
572
 
573
-
574
  def test_show_prompts():
575
  files = ['config.json'] * 1 + \
576
- ['dai_docs.train_cleaned.json'] * 1 + \
577
- ['dai_faq.json'] * 1
578
  file_points = [json.load(open(fil, 'rt')) for fil in files]
579
  from prompter import generate_prompt
580
  for data_points in file_points:
@@ -600,7 +588,7 @@ def test_get_open_datasets():
600
  'license:openrail++',
601
  'license:openrail',
602
  'license:bigscience-bloom-rail-1.0',
603
- #'license:agpl-3.0',
604
  'license:other',
605
  'license:unknown',
606
  # 'license:mpl-2.0', # ok, but would have to include original copyright, license, source, copies in distribution
@@ -610,13 +598,13 @@ def test_get_open_datasets():
610
  'license:cc-by-3.0',
611
  'license:cc-by-2.0',
612
  'license:cc-by-2.5',
613
- #'license:cc-by-sa-4.0', # would require same license
614
  'license:odbl',
615
  'license:pddl',
616
  'license:ms-pl',
617
  'license:zlib',
618
  ]
619
- # bad license: cc-by-nc-4.0
620
 
621
  from huggingface_hub import list_datasets
622
  datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags])
@@ -656,12 +644,12 @@ def test_get_open_datasets():
656
  'language:' not in str(x.tags) or
657
  'language:en' in str(x.tags)]
658
  small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if
659
- 'n<1K' in str(x.tags) or
660
- '1K<n<10K' in str(x.tags) or
661
- '1K0<n<100K' in str(x.tags) or
662
- '100K<n<1M' in str(x.tags) or
663
- 'size_category' not in str(x.tags)
664
- ]
665
  # 'aeslc' : email_body, subject -> summarization?
666
  # load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas()
667
  ids = [x.id for x in small_open_english_tasked_datasets]
@@ -689,7 +677,8 @@ def test_get_open_datasets():
689
  'humarin/chatgpt-paraphrases', # Paraphrase using ChatGPT
690
  'Jeska/vaccinchat', # not useful
691
  'alespalla/chatbot_instruction_prompts', # mixes alpaca
692
- 'allenai/prosocial-dialog', # already exlucded, but wrongly in other datasets that say more permissive license
 
693
  'AlekseyKorshuk/persona-chat', # low quality
694
  'bavard/personachat_truecased', # low quality
695
  'adamlin/daily_dialog', # medium quality conversations
@@ -724,7 +713,8 @@ def test_get_open_datasets():
724
  # some ids clearly speech related
725
  small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id]
726
  # HF testing
727
- small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'hf-internal-testing' not in x.id]
 
728
  small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
729
  'chinese' not in x.id]
730
 
@@ -738,7 +728,6 @@ def test_get_open_datasets():
738
  # grep "pip install" getdata9.log
739
  # NOTE: Some datasets have default config, but others are there. Don't know how to access them.
740
 
741
-
742
  """
743
  https://huggingface.co/datasets/wikihow/blob/main/wikihow.py
744
  https://github.com/mahnazkoupaee/WikiHow-Dataset
@@ -773,7 +762,7 @@ def test_get_open_datasets():
773
  def do_one(data_id, num_downloads):
774
  from datasets import load_dataset
775
  out_file = "data_%s.parquet" % str(data_id.replace('/', '_'))
776
- if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024**3:
777
  return
778
  try:
779
  print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True)
@@ -881,23 +870,21 @@ useful = ['Dahoas/instruct-human-assistant-prompt',
881
  'lmqg/qg_squad', # context QA
882
  'lmqg/qg_squadshifts', # context QA
883
  'lmqg/qg_subjqa', # context QA
884
- 'pszemraj/HC3-textgen-qa', # QA medium, has human responses -- humans tend to provide links instead of trying to answer
 
885
  'pythonist/newdata', # long context, QA, brief A
886
  'ropes', # long background, situation, question, A
887
  'wikitablequestions', # table -> QA
888
  'bigscience/p3', # context QA but short answers
889
  ]
890
 
891
-
892
-
893
  code_useful = ['0n1xus/codexglue',
894
  'openai_humaneval',
895
  'koutch/staqc',
896
  ]
897
 
898
-
899
  maybe_useful = ['AlekseyKorshuk/comedy-scripts',
900
- 'openbookqa', # hard to parse, low reasoning
901
  'qed', # reasonable QA, but low reasoning
902
  'selqa', # candidate answers
903
  'HuggingFaceH4/instruction-pilot-outputs-filtered',
@@ -905,7 +892,6 @@ maybe_useful = ['AlekseyKorshuk/comedy-scripts',
905
  'npc-engine/light-batch-summarize-dialogue', # dialog summarize, kinda low specific quality
906
  ]
907
 
908
-
909
  summary_useful = ['austin/rheum_abstracts',
910
  'CarperAI/openai_summarize_comparisons', # summarize chosen/rejected
911
  'CarperAI/openai_summarize_tldr', # summarize QA
@@ -928,14 +914,12 @@ summary_useful = ['austin/rheum_abstracts',
928
  'stacked-summaries/stacked-xsum-1024',
929
  ]
930
 
931
-
932
  math_useful = [
933
- 'competition_math'
934
- ]
935
-
936
 
937
  skipped = ['c4', # maybe useful, used for flan, but skipped due to size
938
- ]
939
 
940
  """
941
  To get training data from oig:
@@ -958,14 +942,14 @@ def test_assemble_and_detox():
958
  text_list = df[['text']].values.ravel().tolist()
959
  new_text = []
960
  max_len = 2048 # uber cutoff
961
- MAX_LEN = 2048//2 - 30 # max len per question/answer
962
  for text in tqdm(text_list):
963
  human_starts = [m.start() for m in re.finditer('<human>: ', text)]
964
  if len(human_starts) == 1:
965
  human_starts = [0, len(text)] # always go into for loop below
966
  blurb = ''
967
  for i in range(len(human_starts) - 1):
968
- interaction = text[human_starts[i]: human_starts[i+1]][:max_len]
969
  blurb += interaction
970
  if len(blurb) >= MAX_LEN:
971
  blurb = get_sentences(blurb, length=MAX_LEN)[0]
@@ -1002,17 +986,17 @@ def test_basic_cleaning():
1002
  from profanity_check import predict
1003
  df_list = []
1004
  for data in useful_oig_files:
1005
- #for data in useful_oig_files[:5]:
1006
- #for data in ['unified_openai_summarize_tldr.jsonl.parquet']:
1007
  print("Processing %s" % data, flush=True)
1008
  df = pd.read_parquet(data)
1009
  df = df.reset_index(drop=True)
1010
  # NOTE: Not correct if multiple human-bot interactions, but those dialogs even more desired
1011
- #avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
1012
- df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot))/2.0)
1013
  df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot))
1014
- #df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
1015
- #low_quality_patterns = ['Write the rest of this wikipedia article']
1016
  res = predict(df['text'])
1017
  df['bad_words'] = res
1018
  df = df.reset_index(drop=True)
@@ -1215,7 +1199,7 @@ def count_human_bot_lengths(df, human=None, bot=None):
1215
  assert len(text)
1216
  list_what = []
1217
  for ii in range(len(starts) - 1):
1218
- interaction = text[starts[ii]: starts[ii+1]]
1219
  if other in interaction:
1220
  interaction = interaction[:interaction.find(other)]
1221
  interaction.strip()
@@ -1416,9 +1400,13 @@ def test_add_open_assistant(fixup_personality, only_personality, deberta_grading
1416
  conv2['message_id'] = None
1417
  conversations = [c for c in conversations if c['message_id']]
1418
  if only_personality:
1419
- all_rows.extend([dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if 'h2oGPT' in c['text']])
 
 
1420
  else:
1421
- all_rows.extend([dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if "What is H2O.ai" not in c['text']])
 
 
1422
  unhelpful = get_unhelpful_list()
1423
  all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)]
1424
  personality = create_personality_data()
@@ -1484,6 +1472,7 @@ def test_finalize_to_json():
1484
  n_jobs=-1,
1485
  )
1486
  return df[(df['profanity'] == 0)].reset_index(drop=True)
 
1487
  print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1488
  df = final_clean(df)
1489
  print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
@@ -1721,7 +1710,7 @@ def test_check_unhelpful():
1721
  # file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json'
1722
 
1723
  unhelpful = get_unhelpful_list()
1724
- #data = json.load(open(file, 'rt'))
1725
  df = pd.read_json(file)
1726
 
1727
  use_reward_score_threshold = False
@@ -1733,7 +1722,7 @@ def test_check_unhelpful():
1733
  from nltk.translate.bleu_score import sentence_bleu
1734
 
1735
  def get_bleu(actual, expected_list):
1736
- #return bleu.sentence_score(actual, expected_list).score
1737
  return sentence_bleu(expected_list, actual)
1738
 
1739
  threshold = 0.0
@@ -1770,12 +1759,13 @@ def test_check_unhelpful():
1770
  # pip install sentence_transformers-2.2.2
1771
  from sentence_transformers import SentenceTransformer
1772
  # sent_model = 'bert-base-nli-mean-tokens'
1773
- #sent_model = 'nli-distilroberta-base-v2'
1774
  sent_model = 'all-MiniLM-L6-v2'
1775
  model = SentenceTransformer(sent_model)
1776
  sentence_embeddings = model.encode(unhelpful)
1777
  from sklearn.metrics.pairwise import cosine_similarity
1778
- bots = [x for x in tqdm(bots) if np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold]
 
1779
 
1780
  bads_bots = {}
1781
  string_all = str(bots)
@@ -1787,7 +1777,8 @@ def test_check_unhelpful():
1787
  pp.pprint(bads_bots)
1788
 
1789
  total_bads_bots = sum(list(bads_bots.values()))
1790
- print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True)
 
1791
 
1792
  # assert len(bads) == 0, bads
1793
  assert len(bads_bots) == 0, bads_bots
 
23
  import numpy as np
24
  from tqdm import tqdm
25
 
26
+ from utils import flatten_list, remove
27
 
28
 
29
  def parse_rst_file(filepath):
 
184
  return dst
185
 
186
 
187
+ def rst_to_outputs(files, min_len=30, max_len=2048 // 2 - 30):
188
  # account for sequence length (context window) including prompt and input and output
189
 
190
  # os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst')
 
274
  f.write(json.dumps(save_thing, indent=2))
275
 
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  def test_config_to_json():
278
  """
279
  Needs to run from Driverless AI source directory.
 
294
  [
295
  {
296
  'prompt_type': 'plain',
297
+ 'instruction': f"<human>: What does {k} do?\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
298
+ "\n", ""),
299
  },
300
  {
301
  'prompt_type': 'plain',
302
+ 'instruction': f"<human>: Explain {k}.\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
303
+ "\n", ""),
304
  },
305
  {
306
  'prompt_type': 'plain',
307
+ 'instruction': f"<human>: How can I do this: {title}.\n<bot>: Set the {k.replace('_', ' ')} config.toml\n<human>:".replace(
308
+ "\n", ""),
309
  } if title and comment else None,
310
  {
311
  'prompt_type': 'human_bot',
 
407
  from datasets import load_dataset
408
  filename = 'ShareGPT_unfiltered_cleaned_split.json'
409
  if not os.path.exists(filename):
410
+ os.system(
411
+ 'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
412
  data = load_dataset("json", data_files={"train": filename})["train"]
413
  training_rows = []
414
  for i in range(data.num_rows):
 
428
  with open(filename + ".generate_human_bot.train_plain.json", "wt") as f:
429
  f.write(json.dumps(training_rows, indent=2))
430
 
431
+
432
  POSTFIX = ".generate_human_bot.train_plain.json"
433
 
434
  # https://bair.berkeley.edu/blog/2023/04/03/koala/
 
486
  'unified_mathqa_flanv2_kojma_cot.jsonl.parquet',
487
  'unified_merged_code_xp3.jsonl.parquet',
488
  'unified_multi_news.jsonl.parquet',
489
+ # 'unified_multi_sum.jsonl.parquet'
490
  'unified_ni.jsonl.gz.parquet',
491
  'unified_openai_summarize_tldr.jsonl.parquet',
492
+ # 'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
493
  'unified_plot_screenplay_books_dialog.jsonl.parquet',
494
  'unified_soda_dialog.jsonl.parquet',
495
  'unified_unnatural_instructions.jsonl.parquet',
 
535
 
536
  def test_join_jsons():
537
  files = ['config.json'] * 1 + \
538
+ ['dai_docs.train_cleaned.json'] * 2 + \
539
+ ['dai_faq.json'] * 3
540
  print(files)
541
  lst = []
542
  [lst.extend(json.load(open(fil, 'rt'))) for fil in files]
 
559
  f.write(json.dumps(new_rows, indent=2))
560
 
561
 
 
562
  def test_show_prompts():
563
  files = ['config.json'] * 1 + \
564
+ ['dai_docs.train_cleaned.json'] * 1 + \
565
+ ['dai_faq.json'] * 1
566
  file_points = [json.load(open(fil, 'rt')) for fil in files]
567
  from prompter import generate_prompt
568
  for data_points in file_points:
 
588
  'license:openrail++',
589
  'license:openrail',
590
  'license:bigscience-bloom-rail-1.0',
591
+ # 'license:agpl-3.0',
592
  'license:other',
593
  'license:unknown',
594
  # 'license:mpl-2.0', # ok, but would have to include original copyright, license, source, copies in distribution
 
598
  'license:cc-by-3.0',
599
  'license:cc-by-2.0',
600
  'license:cc-by-2.5',
601
+ # 'license:cc-by-sa-4.0', # would require same license
602
  'license:odbl',
603
  'license:pddl',
604
  'license:ms-pl',
605
  'license:zlib',
606
  ]
607
+ # bad license: cc-by-nc-4.0
608
 
609
  from huggingface_hub import list_datasets
610
  datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags])
 
644
  'language:' not in str(x.tags) or
645
  'language:en' in str(x.tags)]
646
  small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if
647
+ 'n<1K' in str(x.tags) or
648
+ '1K<n<10K' in str(x.tags) or
649
+ '1K0<n<100K' in str(x.tags) or
650
+ '100K<n<1M' in str(x.tags) or
651
+ 'size_category' not in str(x.tags)
652
+ ]
653
  # 'aeslc' : email_body, subject -> summarization?
654
  # load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas()
655
  ids = [x.id for x in small_open_english_tasked_datasets]
 
677
  'humarin/chatgpt-paraphrases', # Paraphrase using ChatGPT
678
  'Jeska/vaccinchat', # not useful
679
  'alespalla/chatbot_instruction_prompts', # mixes alpaca
680
+ 'allenai/prosocial-dialog',
681
+ # already exlucded, but wrongly in other datasets that say more permissive license
682
  'AlekseyKorshuk/persona-chat', # low quality
683
  'bavard/personachat_truecased', # low quality
684
  'adamlin/daily_dialog', # medium quality conversations
 
713
  # some ids clearly speech related
714
  small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id]
715
  # HF testing
716
+ small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
717
+ 'hf-internal-testing' not in x.id]
718
  small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
719
  'chinese' not in x.id]
720
 
 
728
  # grep "pip install" getdata9.log
729
  # NOTE: Some datasets have default config, but others are there. Don't know how to access them.
730
 
 
731
  """
732
  https://huggingface.co/datasets/wikihow/blob/main/wikihow.py
733
  https://github.com/mahnazkoupaee/WikiHow-Dataset
 
762
  def do_one(data_id, num_downloads):
763
  from datasets import load_dataset
764
  out_file = "data_%s.parquet" % str(data_id.replace('/', '_'))
765
+ if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024 ** 3:
766
  return
767
  try:
768
  print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True)
 
870
  'lmqg/qg_squad', # context QA
871
  'lmqg/qg_squadshifts', # context QA
872
  'lmqg/qg_subjqa', # context QA
873
+ 'pszemraj/HC3-textgen-qa',
874
+ # QA medium, has human responses -- humans tend to provide links instead of trying to answer
875
  'pythonist/newdata', # long context, QA, brief A
876
  'ropes', # long background, situation, question, A
877
  'wikitablequestions', # table -> QA
878
  'bigscience/p3', # context QA but short answers
879
  ]
880
 
 
 
881
  code_useful = ['0n1xus/codexglue',
882
  'openai_humaneval',
883
  'koutch/staqc',
884
  ]
885
 
 
886
  maybe_useful = ['AlekseyKorshuk/comedy-scripts',
887
+ 'openbookqa', # hard to parse, low reasoning
888
  'qed', # reasonable QA, but low reasoning
889
  'selqa', # candidate answers
890
  'HuggingFaceH4/instruction-pilot-outputs-filtered',
 
892
  'npc-engine/light-batch-summarize-dialogue', # dialog summarize, kinda low specific quality
893
  ]
894
 
 
895
  summary_useful = ['austin/rheum_abstracts',
896
  'CarperAI/openai_summarize_comparisons', # summarize chosen/rejected
897
  'CarperAI/openai_summarize_tldr', # summarize QA
 
914
  'stacked-summaries/stacked-xsum-1024',
915
  ]
916
 
 
917
  math_useful = [
918
+ 'competition_math'
919
+ ]
 
920
 
921
  skipped = ['c4', # maybe useful, used for flan, but skipped due to size
922
+ ]
923
 
924
  """
925
  To get training data from oig:
 
942
  text_list = df[['text']].values.ravel().tolist()
943
  new_text = []
944
  max_len = 2048 # uber cutoff
945
+ MAX_LEN = 2048 // 2 - 30 # max len per question/answer
946
  for text in tqdm(text_list):
947
  human_starts = [m.start() for m in re.finditer('<human>: ', text)]
948
  if len(human_starts) == 1:
949
  human_starts = [0, len(text)] # always go into for loop below
950
  blurb = ''
951
  for i in range(len(human_starts) - 1):
952
+ interaction = text[human_starts[i]: human_starts[i + 1]][:max_len]
953
  blurb += interaction
954
  if len(blurb) >= MAX_LEN:
955
  blurb = get_sentences(blurb, length=MAX_LEN)[0]
 
986
  from profanity_check import predict
987
  df_list = []
988
  for data in useful_oig_files:
989
+ # for data in useful_oig_files[:5]:
990
+ # for data in ['unified_openai_summarize_tldr.jsonl.parquet']:
991
  print("Processing %s" % data, flush=True)
992
  df = pd.read_parquet(data)
993
  df = df.reset_index(drop=True)
994
  # NOTE: Not correct if multiple human-bot interactions, but those dialogs even more desired
995
+ # avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
996
+ df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot)) / 2.0)
997
  df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot))
998
+ # df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
999
+ # low_quality_patterns = ['Write the rest of this wikipedia article']
1000
  res = predict(df['text'])
1001
  df['bad_words'] = res
1002
  df = df.reset_index(drop=True)
 
1199
  assert len(text)
1200
  list_what = []
1201
  for ii in range(len(starts) - 1):
1202
+ interaction = text[starts[ii]: starts[ii + 1]]
1203
  if other in interaction:
1204
  interaction = interaction[:interaction.find(other)]
1205
  interaction.strip()
 
1400
  conv2['message_id'] = None
1401
  conversations = [c for c in conversations if c['message_id']]
1402
  if only_personality:
1403
+ all_rows.extend(
1404
+ [dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
1405
+ 'h2oGPT' in c['text']])
1406
  else:
1407
+ all_rows.extend(
1408
+ [dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
1409
+ "What is H2O.ai" not in c['text']])
1410
  unhelpful = get_unhelpful_list()
1411
  all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)]
1412
  personality = create_personality_data()
 
1472
  n_jobs=-1,
1473
  )
1474
  return df[(df['profanity'] == 0)].reset_index(drop=True)
1475
+
1476
  print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1477
  df = final_clean(df)
1478
  print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
 
1710
  # file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json'
1711
 
1712
  unhelpful = get_unhelpful_list()
1713
+ # data = json.load(open(file, 'rt'))
1714
  df = pd.read_json(file)
1715
 
1716
  use_reward_score_threshold = False
 
1722
  from nltk.translate.bleu_score import sentence_bleu
1723
 
1724
  def get_bleu(actual, expected_list):
1725
+ # return bleu.sentence_score(actual, expected_list).score
1726
  return sentence_bleu(expected_list, actual)
1727
 
1728
  threshold = 0.0
 
1759
  # pip install sentence_transformers-2.2.2
1760
  from sentence_transformers import SentenceTransformer
1761
  # sent_model = 'bert-base-nli-mean-tokens'
1762
+ # sent_model = 'nli-distilroberta-base-v2'
1763
  sent_model = 'all-MiniLM-L6-v2'
1764
  model = SentenceTransformer(sent_model)
1765
  sentence_embeddings = model.encode(unhelpful)
1766
  from sklearn.metrics.pairwise import cosine_similarity
1767
+ bots = [x for x in tqdm(bots) if
1768
+ np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold]
1769
 
1770
  bads_bots = {}
1771
  string_all = str(bots)
 
1777
  pp.pprint(bads_bots)
1778
 
1779
  total_bads_bots = sum(list(bads_bots.values()))
1780
+ print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (
1781
+ threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True)
1782
 
1783
  # assert len(bads) == 0, bads
1784
  assert len(bads_bots) == 0, bads_bots
finetune.py CHANGED
@@ -65,7 +65,8 @@ def train(
65
  micro_batch_size: int = 4,
66
  gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
67
  fp16=True,
68
- train_8bit=True,
 
69
 
70
  # general training hyperparams
71
  num_epochs: float = 1,
@@ -185,10 +186,12 @@ def train(
185
  model = model_loader.from_pretrained(
186
  base_model,
187
  load_in_8bit=train_8bit,
 
188
  device_map=device_map,
189
  torch_dtype=torch.float16,
190
  max_memory=max_memory,
191
  local_files_only=local_files_only,
 
192
  resume_download=resume_download,
193
  use_auth_token=use_auth_token,
194
  )
@@ -200,19 +203,12 @@ def train(
200
 
201
  tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
202
 
203
- if train_8bit:
204
  from peft import (
205
- prepare_model_for_int8_training,
206
  )
207
 
208
- if "gpt-neox" not in base_model or True:
209
- model = prepare_model_for_int8_training(model)
210
- else:
211
- model = prepare_model_for_int8_training(
212
- model,
213
- output_embedding_layer_name="embed_out", # keep output logits in float32
214
- layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
215
- )
216
 
217
  from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
218
  try:
 
65
  micro_batch_size: int = 4,
66
  gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
67
  fp16=True,
68
+ train_8bit=False,
69
+ train_4bit=False,
70
 
71
  # general training hyperparams
72
  num_epochs: float = 1,
 
186
  model = model_loader.from_pretrained(
187
  base_model,
188
  load_in_8bit=train_8bit,
189
+ load_in_4bit=train_4bit,
190
  device_map=device_map,
191
  torch_dtype=torch.float16,
192
  max_memory=max_memory,
193
  local_files_only=local_files_only,
194
+ trust_remote_code=True,
195
  resume_download=resume_download,
196
  use_auth_token=use_auth_token,
197
  )
 
203
 
204
  tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
205
 
206
+ if train_8bit or train_4bit:
207
  from peft import (
208
+ prepare_model_for_kbit_training,
209
  )
210
 
211
+ model = prepare_model_for_kbit_training(model)
 
 
 
 
 
 
 
212
 
213
  from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
214
  try:
generate.py CHANGED
@@ -9,24 +9,25 @@ import os
9
  import time
10
  import traceback
11
  import typing
 
12
  from datetime import datetime
13
  import filelock
14
  import psutil
15
 
 
 
 
 
16
  from loaders import get_loaders
17
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
18
- import_matplotlib, get_device, makedirs
19
 
20
  import_matplotlib()
21
- from matplotlib import pyplot as plt
22
 
23
  SEED = 1236
24
  set_seed(SEED)
25
 
26
- os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
27
  from typing import Union
28
- import numpy as np
29
- import pandas as pd
30
 
31
  import fire
32
  import torch
@@ -34,7 +35,7 @@ from peft import PeftModel
34
  from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
35
  from accelerate import init_empty_weights, infer_auto_device_map
36
 
37
- from prompter import Prompter, inv_prompt_type_to_model_lower
38
  from stopping import get_stopping
39
 
40
  eval_extra_columns = ['prompt', 'response', 'score']
@@ -47,12 +48,14 @@ scratch_base_dir = '/tmp/'
47
 
48
  def main(
49
  load_8bit: bool = False,
 
50
  load_half: bool = True,
51
  infer_devices: bool = True,
52
  base_model: str = '',
53
  tokenizer_base_model: str = '',
54
  lora_weights: str = "",
55
  gpu_id: int = 0,
 
56
 
57
  prompt_type: Union[int, str] = None,
58
  # input to generation
@@ -68,6 +71,7 @@ def main(
68
  early_stopping: Union[bool, str] = None,
69
  max_time: float = None,
70
 
 
71
  debug: bool = False,
72
  save_dir: str = None,
73
  share: bool = True,
@@ -80,15 +84,18 @@ def main(
80
  src_lang: str = "English",
81
  tgt_lang: str = "Russian",
82
 
 
 
83
  gradio: bool = True,
84
  gradio_avoid_processing_markdown: bool = False,
 
85
  chat: bool = True,
86
  chat_context: bool = False,
87
  stream_output: bool = True,
88
  show_examples: bool = None,
89
  verbose: bool = False,
90
- h2ocolors: bool = True,
91
- height: int = 400,
92
  show_lora: bool = True,
93
  login_mode_if_model0: bool = False,
94
  block_gradio_exit: bool = True,
@@ -107,13 +114,16 @@ def main(
107
  score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
108
  auto_score: bool = True,
109
 
110
- eval_sharegpt_prompts_only: int = 0,
111
- eval_sharegpt_prompts_only_seed: int = 1234,
112
- eval_sharegpt_as_output: bool = False,
 
113
 
114
  langchain_mode: str = 'Disabled',
115
  visible_langchain_modes: list = ['UserData', 'MyData'],
 
116
  user_path: str = None,
 
117
  load_db_if_exists: bool = True,
118
  keep_sources_in_context: bool = False,
119
  db_type: str = 'chroma',
@@ -127,7 +137,7 @@ def main(
127
  enable_sources_list: bool = True,
128
  chunk: bool = True,
129
  chunk_size: int = 512,
130
- k: int = 4,
131
  n_jobs: int = -1,
132
  enable_captions: bool = True,
133
  captions_model: str = "Salesforce/blip-image-captioning-base",
@@ -138,12 +148,14 @@ def main(
138
  """
139
 
140
  :param load_8bit: load model in 8-bit using bitsandbytes
 
141
  :param load_half: load model in float16
142
  :param infer_devices: whether to control devices with gpu_id. If False, then spread across GPUs
143
- :param base_model: model HF-type name
144
- :param tokenizer_base_model: tokenizer HF-type name
145
  :param lora_weights: LORA weights path/HF link
146
  :param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
 
147
  :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
148
  :param temperature: generation temperature
149
  :param top_p: generation top_p
@@ -156,6 +168,7 @@ def main(
156
  :param min_new_tokens: generation min tokens
157
  :param early_stopping: generation early stopping
158
  :param max_time: maximum time to allow for generation
 
159
  :param debug: enable debug mode
160
  :param save_dir: directory chat data is saved to
161
  :param share: whether to share the gradio app with sharable URL
@@ -166,8 +179,16 @@ def main(
166
  :param offload_folder: path for spilling model onto disk
167
  :param src_lang: source languages to include if doing translation (None = all)
168
  :param tgt_lang: target languages to include if doing translation (None = all)
 
 
169
  :param gradio: whether to enable gradio, or to enable benchmark mode
170
  :param gradio_avoid_processing_markdown:
 
 
 
 
 
 
171
  :param chat: whether to enable chat mode with chat history
172
  :param chat_context: whether to use extra helpful context if human_bot
173
  :param stream_output: whether to stream output from generate
@@ -190,32 +211,37 @@ def main(
190
  :param extra_lora_options: extra LORA to show in list in gradio
191
  :param score_model: which model to score responses (None means no scoring)
192
  :param auto_score: whether to automatically score responses
193
- :param eval_sharegpt_prompts_only: for no gradio benchmark, if using ShareGPT prompts for eval
194
- :param eval_sharegpt_prompts_only_seed: for no gradio benchmark, if seed for ShareGPT sampling
195
- :param eval_sharegpt_as_output: for no gradio benchmark, whether to test ShareGPT output itself
 
196
  :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
197
  WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
198
- :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode
 
 
 
199
  :param visible_langchain_modes: dbs to generate at launch to be ready for LLM
200
  Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
201
  But wiki_full is expensive and requires preparation
202
  To allow scratch space only live in session, add 'MyData' to list
203
  Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
204
  FIXME: Avoid 'All' for now, not implemented
 
205
  :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
206
  :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
207
- :param db_type: 'faiss' for in-memory or 'chroma' for persisted on disk
208
  :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
209
  :param use_openai_model: Whether to use OpenAI model for use with vector db
210
  :param hf_embedding_model: Which HF embedding model to use for vector db
211
  :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
212
  :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
213
  :param enable_url_upload: Whether to allow upload from URL
214
- :param enable_text_upload: Whether to allow uplaod of text
215
  :param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
216
  :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
217
  :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length
218
- :param k: number of chunks to give LLM
219
  :param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
220
  :param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
221
  :param captions_model: Which model to use for captions.
@@ -233,7 +259,10 @@ def main(
233
  is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
234
  is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
235
  is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
236
- is_low_mem = is_hf # assumes run on 24GB consumer GPU
 
 
 
237
  admin_pass = os.getenv("ADMIN_PASS")
238
  # will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
239
  # but becomes unrecoverable sometimes if raise, so just be silent for now
@@ -265,21 +294,23 @@ def main(
265
  # by default don't sample, too chatty
266
  do_sample = False if do_sample is None else do_sample
267
 
268
- if is_low_mem:
269
  if not base_model:
270
  base_model = 'h2oai/h2ogpt-oasst1-512-12b'
271
  # don't set load_8bit if passed base_model, doesn't always work so can't just override
272
  load_8bit = True
 
273
  else:
274
  base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model
275
- if is_low_mem:
276
  load_8bit = True
 
277
  if is_hf:
278
  # must override share if in spaces
279
  share = False
280
  save_dir = os.getenv('SAVE_DIR', save_dir)
281
  score_model = os.getenv('SCORE_MODEL', score_model)
282
- if score_model == 'None':
283
  score_model = ''
284
  concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
285
  api_open = bool(int(os.getenv('API_OPEN', api_open)))
@@ -289,6 +320,7 @@ def main(
289
  if n_gpus == 0:
290
  gpu_id = None
291
  load_8bit = False
 
292
  load_half = False
293
  infer_devices = False
294
  torch.backends.cudnn.benchmark = True
@@ -328,12 +360,15 @@ def main(
328
  max_new_tokens, min_new_tokens, early_stopping, max_time,
329
  repetition_penalty, num_return_sequences,
330
  do_sample,
 
 
331
  )
332
 
333
  locals_dict = locals()
334
  locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
335
- print(f"Generating model with params:\n{locals_print}", flush=True)
336
- print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
 
337
 
338
  if langchain_mode != "Disabled":
339
  # SECOND PLACE where LangChain referenced, but all imports are kept local so not required
@@ -353,7 +388,9 @@ def main(
353
  # FIXME: All should be avoided until scans over each db, shouldn't be separate db
354
  continue
355
  persist_directory1 = 'db_dir_%s' % langchain_mode1 # single place, no special names for each case
356
- db = prep_langchain(persist_directory1, load_db_if_exists, db_type, use_openai_embedding,
 
 
357
  langchain_mode1, user_path,
358
  hf_embedding_model,
359
  kwargs_make_db=locals())
@@ -367,174 +404,30 @@ def main(
367
  assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
368
  assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
369
 
370
- if not gradio:
371
- if eval_sharegpt_prompts_only > 0:
372
- # override default examples with shareGPT ones for human-level eval purposes only
373
- eval_filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
374
- if not os.path.isfile(eval_filename):
375
- os.system(
376
- 'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % eval_filename)
377
- import json
378
- data = json.load(open(eval_filename, 'rt'))
379
- # focus on data that starts with human, else likely chopped from other data
380
- turn_start = 0 # odd in general
381
- data = [x for x in data if len(x['conversations']) > turn_start + 1 and
382
- x['conversations'][turn_start]['from'] == 'human' and
383
- x['conversations'][turn_start + 1]['from'] == 'gpt']
384
- np.random.seed(eval_sharegpt_prompts_only_seed)
385
- example1 = examples[-1] # pick reference example
386
- examples = []
387
- responses = []
388
- for i in list(np.random.randint(0, len(data), size=eval_sharegpt_prompts_only)):
389
- assert data[i]['conversations'][turn_start]['from'] == 'human'
390
- instruction = data[i]['conversations'][turn_start]['value']
391
- assert data[i]['conversations'][turn_start + 1]['from'] == 'gpt'
392
- output = data[i]['conversations'][turn_start + 1]['value']
393
- examplenew = example1.copy()
394
- assert not chat, "No gradio must use chat=False, uses nochat instruct"
395
- examplenew[eval_func_param_names.index('instruction_nochat')] = instruction
396
- examplenew[eval_func_param_names.index('iinput_nochat')] = '' # no input
397
- examplenew[eval_func_param_names.index('context')] = get_context(chat_context, prompt_type)
398
- examples.append(examplenew)
399
- responses.append(output)
400
-
401
- num_examples = len(examples)
402
- scoring_path = 'scoring'
403
- os.makedirs(scoring_path, exist_ok=True)
404
- if eval_sharegpt_as_output:
405
- used_base_model = 'gpt35'
406
- used_lora_weights = ''
407
- else:
408
- used_base_model = str(base_model.split('/')[-1])
409
- used_lora_weights = str(lora_weights.split('/')[-1])
410
- eval_filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
411
- eval_sharegpt_prompts_only_seed,
412
- eval_sharegpt_as_output,
413
- used_base_model,
414
- used_lora_weights)
415
- eval_filename = os.path.join(scoring_path, eval_filename)
416
-
417
- # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
418
- device = 'cpu' if n_gpus == 0 else 'cuda'
419
- context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device
420
-
421
- with context_class(device):
422
- # ensure was set right above before examples generated
423
- assert not stream_output, "stream_output=True does not make sense with example loop"
424
- import time
425
- from functools import partial
426
-
427
- # get score model
428
- smodel, stokenizer, sdevice = get_score_model(**locals())
429
-
430
- if not eval_sharegpt_as_output:
431
- model, tokenizer, device = get_model(**locals())
432
- model_state = [model, tokenizer, device, base_model]
433
- kwargs_evaluate = {k: v for k, v in locals().items() if k in inputs_kwargs_list}
434
- my_db_state = [None]
435
- fun = partial(evaluate, model_state, my_db_state, **kwargs_evaluate)
436
- else:
437
- assert eval_sharegpt_prompts_only > 0
438
-
439
- def get_response(*args, exi=0):
440
- # assumes same ordering of examples and responses
441
- yield responses[exi]
442
-
443
- fun = get_response
444
- t0 = time.time()
445
- score_dump = []
446
-
447
- for exi, ex in enumerate(examples):
448
- instruction = ex[eval_func_param_names.index('instruction_nochat')]
449
- iinput = ex[eval_func_param_names.index('iinput_nochat')]
450
- context = ex[eval_func_param_names.index('context')]
451
- clear_torch_cache()
452
- print("")
453
- print("START" + "=" * 100)
454
- print("Question: %s %s" % (instruction, ('input=%s' % iinput if iinput else '')))
455
- print("-" * 105)
456
- # fun yields as generator, so have to iterate over it
457
- # Also means likely do NOT want --stream_output=True, else would show all generations
458
- gener = fun(*tuple(ex), exi=exi) if eval_sharegpt_as_output else fun(*tuple(ex))
459
- for res in gener:
460
- print(res)
461
- if smodel:
462
- score_with_prompt = False
463
- if score_with_prompt:
464
- data_point = dict(instruction=instruction, input=iinput, context=context)
465
- prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
466
- prompt = prompter.generate_prompt(data_point)
467
- else:
468
- # just raw input and output
469
- if eval_sharegpt_prompts_only > 0:
470
- # only our own examples have this filled at moment
471
- assert iinput in [None, ''], iinput # should be no iinput
472
- if not (chat_context and prompt_type == 'human_bot'):
473
- assert context in [None, ''], context # should be no context
474
- prompt = instruction
475
- cutoff_len = 768 if is_low_mem else 2048
476
- inputs = stokenizer(prompt, res,
477
- return_tensors="pt",
478
- truncation=True,
479
- max_length=cutoff_len)
480
- try:
481
- score = torch.sigmoid(smodel(**inputs).logits[0].float()).cpu().detach().numpy()[0]
482
- except torch.cuda.OutOfMemoryError as e:
483
- print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
484
- flush=True)
485
- traceback.print_exc()
486
- score = 0.0
487
- clear_torch_cache()
488
- except (Exception, RuntimeError) as e:
489
- if 'Expected all tensors to be on the same device' in str(e) or \
490
- 'expected scalar type Half but found Float' in str(e) or \
491
- 'probability tensor contains either' in str(e) or \
492
- 'cublasLt ran into an error!' in str(e):
493
- print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
494
- flush=True)
495
- traceback.print_exc()
496
- score = 0.0
497
- clear_torch_cache()
498
- else:
499
- raise
500
- print("SCORE %s: %s" % (exi, score), flush=True)
501
- score_dump.append(ex + [prompt, res, score])
502
- # dump every score in case abort
503
- df_scores = pd.DataFrame(score_dump,
504
- columns=eval_func_param_names + eval_extra_columns)
505
- df_scores.to_parquet(eval_filename, index=False)
506
- # plot histogram so far
507
- plt.figure(figsize=(10, 10))
508
- plt.hist(df_scores['score'], bins=20)
509
- score_avg = np.mean(df_scores['score'])
510
- score_median = np.median(df_scores['score'])
511
- plt.title("Score avg: %s median: %s" % (score_avg, score_median))
512
- plt.savefig(eval_filename.replace('.parquet', '.png'))
513
- plt.close()
514
-
515
- print("END" + "=" * 102)
516
- print("")
517
- t2 = time.time()
518
- print("Time taken so far: %.4f about %.4g per example" % (t2 - t0, (t2 - t0) / (1 + exi)))
519
- t1 = time.time()
520
- print("Total time taken: %.4f about %.4g per example" % (t1 - t0, (t1 - t0) / num_examples))
521
- return eval_filename
522
-
523
- if gradio:
524
  # imported here so don't require gradio to run generate
525
  from gradio_runner import go_gradio
526
 
527
  # get default model
528
  all_kwargs = locals().copy()
529
  if all_kwargs.get('base_model') and not all_kwargs['login_mode_if_model0']:
530
- model0, tokenizer0, device = get_model(**all_kwargs)
 
531
  else:
532
  # if empty model, then don't load anything, just get gradio up
533
  model0, tokenizer0, device = None, None, None
534
  model_state0 = [model0, tokenizer0, device, all_kwargs['base_model']]
535
 
536
  # get score model
537
- smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
 
 
538
  score_model_state0 = [smodel, stokenizer, sdevice, score_model]
539
 
540
  if enable_captions:
@@ -546,6 +439,7 @@ def main(
546
  else:
547
  caption_loader = False
548
 
 
549
  go_gradio(**locals())
550
 
551
 
@@ -624,12 +518,15 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
624
  else:
625
  device_map = {'': 'cpu'}
626
  model_kwargs['load_in_8bit'] = False
 
627
  print('device_map: %s' % device_map, flush=True)
628
 
629
  load_in_8bit = model_kwargs.get('load_in_8bit', False)
 
630
  model_kwargs['device_map'] = device_map
 
631
 
632
- if load_in_8bit or not load_half:
633
  model = model_loader.from_pretrained(
634
  base_model,
635
  config=config,
@@ -646,6 +543,7 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
646
 
647
  def get_model(
648
  load_8bit: bool = False,
 
649
  load_half: bool = True,
650
  infer_devices: bool = True,
651
  base_model: str = '',
@@ -659,12 +557,14 @@ def get_model(
659
  use_auth_token: Union[str, bool] = False,
660
  trust_remote_code: bool = True,
661
  offload_folder: str = None,
662
- compile: bool = True,
663
- **kwargs,
 
664
  ):
665
  """
666
 
667
  :param load_8bit: load model in 8-bit, not supported by all models
 
668
  :param load_half: load model in 16-bit
669
  :param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
670
  For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
@@ -679,26 +579,29 @@ def get_model(
679
  :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
680
  :param trust_remote_code: trust code needed by model
681
  :param offload_folder: offload folder
682
- :param compile: whether to compile torch model
683
- :param kwargs:
684
  :return:
685
  """
686
- print("Get %s model" % base_model, flush=True)
687
- if base_model in ['llama', 'gptj']:
 
688
  from gpt4all_llm import get_model_tokenizer_gpt4all
689
  model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
690
  return model, tokenizer, device
691
 
692
  if lora_weights is not None and lora_weights.strip():
693
- print("Get %s lora weights" % lora_weights, flush=True)
 
694
  device = get_device()
695
 
696
  if 'gpt2' in base_model.lower():
697
  # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
698
  load_8bit = False
 
699
 
700
  assert base_model.strip(), (
701
- "Please choose a base model with --base_model (CLI) or in Models Tab (gradio)"
702
  )
703
 
704
  from transformers import AutoConfig
@@ -709,8 +612,9 @@ def get_model(
709
  llama_type_from_name = "llama" in base_model.lower()
710
  llama_type = llama_type_from_config or llama_type_from_name
711
  if llama_type:
712
- print("Detected as llama type from"
713
- " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
 
714
 
715
  model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
716
  if not tokenizer_base_model:
@@ -744,7 +648,8 @@ def get_model(
744
  )
745
  if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
746
  model_kwargs.update(dict(load_in_8bit=load_8bit,
747
- device_map={"": 0} if load_8bit and device == 'cuda' else "auto",
 
748
  ))
749
  if 'mpt-' in base_model.lower() and gpu_id >= 0:
750
  model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
@@ -753,6 +658,7 @@ def get_model(
753
  # FIXME: could put on other GPUs
754
  model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
755
  model_kwargs.pop('torch_dtype', None)
 
756
 
757
  if not lora_weights:
758
  with torch.device(device):
@@ -764,7 +670,7 @@ def get_model(
764
  offload_folder=offload_folder,
765
  )
766
  else:
767
- if load_half and not load_8bit:
768
  model = model_loader.from_pretrained(
769
  base_model,
770
  **model_kwargs).half()
@@ -772,7 +678,7 @@ def get_model(
772
  model = model_loader.from_pretrained(
773
  base_model,
774
  **model_kwargs)
775
- elif load_8bit:
776
  model = model_loader.from_pretrained(
777
  base_model,
778
  **model_kwargs
@@ -821,24 +727,62 @@ def get_model(
821
 
822
  if not isinstance(tokenizer, str):
823
  model.eval()
824
- if torch.__version__ >= "2" and sys.platform != "win32" and compile:
825
  model = torch.compile(model)
826
 
 
 
 
 
 
 
827
  return model, tokenizer, device
828
 
829
 
830
- def get_score_model(**kwargs):
831
- # score model
832
- if kwargs.get('score_model') is not None and kwargs.get('score_model').strip():
833
- score_all_kwargs = kwargs.copy()
834
- score_all_kwargs['load_8bit'] = False
835
- score_all_kwargs['load_half'] = False
836
- score_all_kwargs['base_model'] = kwargs.get('score_model').strip()
837
- score_all_kwargs['tokenizer_base_model'] = ''
838
- score_all_kwargs['lora_weights'] = ''
839
- score_all_kwargs['llama_type'] = False
840
- score_all_kwargs['compile'] = False
841
- smodel, stokenizer, sdevice = get_model(**score_all_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842
  else:
843
  smodel, stokenizer, sdevice = None, None, None
844
  return smodel, stokenizer, sdevice
@@ -864,6 +808,7 @@ eval_func_param_names = ['instruction',
864
  'instruction_nochat',
865
  'iinput_nochat',
866
  'langchain_mode',
 
867
  'document_choice',
868
  ]
869
 
@@ -892,6 +837,7 @@ def evaluate(
892
  instruction_nochat,
893
  iinput_nochat,
894
  langchain_mode,
 
895
  document_choice,
896
  # END NOTE: Examples must have same order of parameters
897
  src_lang=None,
@@ -901,27 +847,29 @@ def evaluate(
901
  save_dir=None,
902
  sanitize_bot_response=True,
903
  model_state0=None,
904
- is_low_mem=None,
905
  raise_generate_gpu_exceptions=None,
906
  chat_context=None,
907
  lora_weights=None,
908
  load_db_if_exists=True,
909
  dbs=None,
910
  user_path=None,
 
911
  use_openai_embedding=None,
912
  use_openai_model=None,
913
  hf_embedding_model=None,
914
  chunk=None,
915
  chunk_size=None,
916
  db_type=None,
917
- k=None,
918
  n_jobs=None,
919
  first_para=None,
920
  text_limit=None,
 
 
921
  ):
922
  # ensure passed these
923
  assert concurrency_count is not None
924
- assert is_low_mem is not None
925
  assert raise_generate_gpu_exceptions is not None
926
  assert chat_context is not None
927
  assert use_openai_embedding is not None
@@ -930,7 +878,7 @@ def evaluate(
930
  assert chunk is not None
931
  assert chunk_size is not None
932
  assert db_type is not None
933
- assert k is not None
934
  assert n_jobs is not None
935
  assert first_para is not None
936
 
@@ -940,7 +888,7 @@ def evaluate(
940
  locals_dict.pop('model_state0', None)
941
  print(locals_dict)
942
 
943
- no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
944
 
945
  if model_state0 is None:
946
  # e.g. for no gradio case, set dummy value, else should be set
@@ -990,7 +938,7 @@ def evaluate(
990
  db1 = dbs[langchain_mode]
991
  else:
992
  db1 = None
993
- if langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and db1 is not None or base_model in ['llama', 'gptj']:
994
  query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
995
  outr = ""
996
  # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
@@ -1002,6 +950,7 @@ def evaluate(
1002
  load_db_if_exists=load_db_if_exists,
1003
  db=db1,
1004
  user_path=user_path,
 
1005
  max_new_tokens=max_new_tokens,
1006
  cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
1007
  use_openai_embedding=use_openai_embedding,
@@ -1014,21 +963,28 @@ def evaluate(
1014
  langchain_mode=langchain_mode,
1015
  document_choice=document_choice,
1016
  db_type=db_type,
1017
- k=k,
1018
  temperature=temperature,
1019
  repetition_penalty=repetition_penalty,
1020
  top_k=top_k,
1021
  top_p=top_p,
1022
  prompt_type=prompt_type,
1023
  n_jobs=n_jobs,
 
 
1024
  ):
1025
- outr = r # doesn't accumulate, new answer every yield, so only save that full answer
1026
- yield r
1027
  if save_dir:
1028
  save_generate_output(output=outr, base_model=base_model, save_dir=save_dir)
1029
- print('Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
1030
- flush=True)
1031
- if outr:
 
 
 
 
 
1032
  return
1033
 
1034
  if isinstance(tokenizer, str):
@@ -1038,7 +994,7 @@ def evaluate(
1038
  else:
1039
  raise RuntimeError("No such task type %s" % tokenizer)
1040
  # NOTE: uses max_length only
1041
- yield model(prompt, max_length=max_new_tokens)[0][key]
1042
 
1043
  if 'mbart-' in base_model.lower():
1044
  assert src_lang is not None
@@ -1048,7 +1004,7 @@ def evaluate(
1048
  # override, ignore user change
1049
  num_return_sequences = 1
1050
  stopping_criteria = get_stopping(prompt_type, tokenizer, device)
1051
- _, _, max_length_tokenize, max_prompt_length = get_cutoffs(is_low_mem)
1052
  prompt = prompt[-max_prompt_length:]
1053
  inputs = tokenizer(prompt,
1054
  return_tensors="pt",
@@ -1059,6 +1015,10 @@ def evaluate(
1059
  if debug and len(inputs["input_ids"]) > 0:
1060
  print('input_ids length', len(inputs["input_ids"][0]), flush=True)
1061
  input_ids = inputs["input_ids"].to(device)
 
 
 
 
1062
  generation_config = GenerationConfig(
1063
  temperature=float(temperature),
1064
  top_p=float(top_p),
@@ -1111,10 +1071,12 @@ def evaluate(
1111
  # https://github.com/h2oai/h2ogpt/issues/104
1112
  # but only makes sense if concurrency_count == 1
1113
  context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
1114
- print('Pre-Generate: %s' % str(datetime.now()), flush=True)
 
1115
  decoded_output = None
1116
  with context_class("generate.lock"):
1117
- print('Generate: %s' % str(datetime.now()), flush=True)
 
1118
  # decoded tokenized prompt can deviate from prompt due to special characters
1119
  inputs_decoded = decoder(input_ids[0])
1120
  inputs_decoded_raw = decoder_raw(input_ids[0])
@@ -1136,7 +1098,8 @@ def evaluate(
1136
  decoder = decoder_raw
1137
  decoder_kwargs = decoder_raw_kwargs
1138
  else:
1139
- print("WARNING: Special characters in prompt", flush=True)
 
1140
  if stream_output:
1141
  skip_prompt = False
1142
  streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False,
@@ -1155,8 +1118,9 @@ def evaluate(
1155
  if bucket.qsize() > 0 or thread.exc:
1156
  thread.join()
1157
  outputs += new_text
1158
- yield prompter.get_response(outputs, prompt=inputs_decoded,
1159
- sanitize_bot_response=sanitize_bot_response)
 
1160
  except BaseException:
1161
  # if any exception, raise that exception if was from thread, first
1162
  if thread.exc:
@@ -1173,14 +1137,15 @@ def evaluate(
1173
  else:
1174
  outputs = model.generate(**gen_kwargs)
1175
  outputs = [decoder(s) for s in outputs.sequences]
1176
- yield prompter.get_response(outputs, prompt=inputs_decoded,
1177
- sanitize_bot_response=sanitize_bot_response)
1178
  if outputs and len(outputs) >= 1:
1179
  decoded_output = prompt + outputs[0]
1180
  if save_dir and decoded_output:
1181
  save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
1182
- print('Post-Generate: %s decoded_output: %s' % (
1183
- str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
 
1184
 
1185
 
1186
  inputs_list_names = list(inspect.signature(evaluate).parameters)
@@ -1188,12 +1153,15 @@ state_names = ['model_state', 'my_db_state']
1188
  inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
1189
 
1190
 
1191
- def get_cutoffs(is_low_mem, for_context=False):
1192
  # help to avoid errors like:
1193
  # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
1194
  # RuntimeError: expected scalar type Half but found Float
1195
  # with - 256
1196
- max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
 
 
 
1197
  cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
1198
  output_smallest = 30 * 4
1199
  max_prompt_length = cutoff_len - output_smallest
@@ -1286,7 +1254,7 @@ def get_generate_params(model_lower, chat,
1286
  prompt_type, temperature, top_p, top_k, num_beams,
1287
  max_new_tokens, min_new_tokens, early_stopping, max_time,
1288
  repetition_penalty, num_return_sequences,
1289
- do_sample):
1290
  use_defaults = False
1291
  use_default_examples = True
1292
  examples = []
@@ -1303,7 +1271,8 @@ def get_generate_params(model_lower, chat,
1303
 
1304
  if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
1305
  prompt_type = inv_prompt_type_to_model_lower[model_lower]
1306
- print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True)
 
1307
 
1308
  # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
1309
  if show_examples is None:
@@ -1366,9 +1335,6 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
1366
  prompt_type = prompt_type or 'plain'
1367
  else:
1368
  prompt_type = ''
1369
- examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
1370
- stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1,
1371
- False]]
1372
  task_info = "No task"
1373
  if prompt_type == 'instruct':
1374
  task_info = "Answer question or follow imperative as instruction with optionally input."
@@ -1443,13 +1409,15 @@ y = np.random.randint(0, 1, 100)
1443
 
1444
  # fit random forest classifier with 20 estimators""", ''] + params_list,
1445
  ]
 
 
1446
 
1447
  src_lang = "English"
1448
  tgt_lang = "Russian"
1449
 
1450
  # move to correct position
1451
  for example in examples:
1452
- example += [chat, '', '', 'Disabled', ['All']]
1453
  # adjust examples if non-chat mode
1454
  if not chat:
1455
  example[eval_func_param_names.index('instruction_nochat')] = example[
@@ -1521,6 +1489,32 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l
1521
  return score
1522
 
1523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1524
  if __name__ == "__main__":
1525
  """
1526
  Examples:
 
9
  import time
10
  import traceback
11
  import typing
12
+ import warnings
13
  from datetime import datetime
14
  import filelock
15
  import psutil
16
 
17
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
18
+ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
19
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
20
+
21
  from loaders import get_loaders
22
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
23
+ import_matplotlib, get_device, makedirs, get_kwargs
24
 
25
  import_matplotlib()
 
26
 
27
  SEED = 1236
28
  set_seed(SEED)
29
 
 
30
  from typing import Union
 
 
31
 
32
  import fire
33
  import torch
 
35
  from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
36
  from accelerate import init_empty_weights, infer_auto_device_map
37
 
38
+ from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types
39
  from stopping import get_stopping
40
 
41
  eval_extra_columns = ['prompt', 'response', 'score']
 
48
 
49
  def main(
50
  load_8bit: bool = False,
51
+ load_4bit: bool = False,
52
  load_half: bool = True,
53
  infer_devices: bool = True,
54
  base_model: str = '',
55
  tokenizer_base_model: str = '',
56
  lora_weights: str = "",
57
  gpu_id: int = 0,
58
+ compile_model: bool = True,
59
 
60
  prompt_type: Union[int, str] = None,
61
  # input to generation
 
71
  early_stopping: Union[bool, str] = None,
72
  max_time: float = None,
73
 
74
+ memory_restriction_level: int = None,
75
  debug: bool = False,
76
  save_dir: str = None,
77
  share: bool = True,
 
84
  src_lang: str = "English",
85
  tgt_lang: str = "Russian",
86
 
87
+ cli: bool = False,
88
+ cli_loop: bool = True,
89
  gradio: bool = True,
90
  gradio_avoid_processing_markdown: bool = False,
91
+ gradio_offline_level: int = 0,
92
  chat: bool = True,
93
  chat_context: bool = False,
94
  stream_output: bool = True,
95
  show_examples: bool = None,
96
  verbose: bool = False,
97
+ h2ocolors: bool = False,
98
+ height: int = 600,
99
  show_lora: bool = True,
100
  login_mode_if_model0: bool = False,
101
  block_gradio_exit: bool = True,
 
114
  score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
115
  auto_score: bool = True,
116
 
117
+ eval_filename: str = None,
118
+ eval_prompts_only_num: int = 0,
119
+ eval_prompts_only_seed: int = 1234,
120
+ eval_as_output: bool = False,
121
 
122
  langchain_mode: str = 'Disabled',
123
  visible_langchain_modes: list = ['UserData', 'MyData'],
124
+ document_choice: list = ['All'],
125
  user_path: str = None,
126
+ detect_user_path_changes_every_query: bool = False,
127
  load_db_if_exists: bool = True,
128
  keep_sources_in_context: bool = False,
129
  db_type: str = 'chroma',
 
137
  enable_sources_list: bool = True,
138
  chunk: bool = True,
139
  chunk_size: int = 512,
140
+ top_k_docs: int = 3, # FIXME: Can go back to 4 once https://github.com/h2oai/h2ogpt/issues/192 fixed
141
  n_jobs: int = -1,
142
  enable_captions: bool = True,
143
  captions_model: str = "Salesforce/blip-image-captioning-base",
 
148
  """
149
 
150
  :param load_8bit: load model in 8-bit using bitsandbytes
151
+ :param load_4bit: load model in 4-bit using bitsandbytes
152
  :param load_half: load model in float16
153
  :param infer_devices: whether to control devices with gpu_id. If False, then spread across GPUs
154
+ :param base_model: model HF-type name. If use --base_model to preload model, cannot unload in gradio in models tab
155
+ :param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model.
156
  :param lora_weights: LORA weights path/HF link
157
  :param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
158
+ :param compile_model Whether to compile the model
159
  :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
160
  :param temperature: generation temperature
161
  :param top_p: generation top_p
 
168
  :param min_new_tokens: generation min tokens
169
  :param early_stopping: generation early stopping
170
  :param max_time: maximum time to allow for generation
171
+ :param memory_restriction_level: 0 = no restriction to tokens or model, 1 = some restrictions on token 2 = HF like restriction 3 = very low memory case
172
  :param debug: enable debug mode
173
  :param save_dir: directory chat data is saved to
174
  :param share: whether to share the gradio app with sharable URL
 
179
  :param offload_folder: path for spilling model onto disk
180
  :param src_lang: source languages to include if doing translation (None = all)
181
  :param tgt_lang: target languages to include if doing translation (None = all)
182
+ :param cli: whether to use CLI (non-gradio) interface.
183
+ :param cli_loop: whether to loop for CLI (False usually only for testing)
184
  :param gradio: whether to enable gradio, or to enable benchmark mode
185
  :param gradio_avoid_processing_markdown:
186
+ :param gradio_offline_level: > 0, then change fonts so full offline
187
+ == 1 means backend won't need internet for fonts, but front-end UI might if font not cached
188
+ == 2 means backend and frontend don't need internet to download any fonts.
189
+ Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading.
190
+ This option further disables google fonts for downloading, which is less intrusive than uploading,
191
+ but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior.
192
  :param chat: whether to enable chat mode with chat history
193
  :param chat_context: whether to use extra helpful context if human_bot
194
  :param stream_output: whether to stream output from generate
 
211
  :param extra_lora_options: extra LORA to show in list in gradio
212
  :param score_model: which model to score responses (None means no scoring)
213
  :param auto_score: whether to automatically score responses
214
+ :param eval_filename: json file to use for evaluation, if None is sharegpt
215
+ :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
216
+ :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
217
+ :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
218
  :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
219
  WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
220
+ :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
221
+ If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
222
+ :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
223
+ Expensive for large number of files, so not done by default. By default only detect changes during db loading.
224
  :param visible_langchain_modes: dbs to generate at launch to be ready for LLM
225
  Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
226
  But wiki_full is expensive and requires preparation
227
  To allow scratch space only live in session, add 'MyData' to list
228
  Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
229
  FIXME: Avoid 'All' for now, not implemented
230
+ :param document_choice: Default document choice when taking subset of collection
231
  :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
232
  :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
233
+ :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
234
  :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
235
  :param use_openai_model: Whether to use OpenAI model for use with vector db
236
  :param hf_embedding_model: Which HF embedding model to use for vector db
237
  :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
238
  :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
239
  :param enable_url_upload: Whether to allow upload from URL
240
+ :param enable_text_upload: Whether to allow upload of text
241
  :param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
242
  :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
243
  :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length
244
+ :param top_k_docs: number of chunks to give LLM
245
  :param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
246
  :param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
247
  :param captions_model: Which model to use for captions.
 
259
  is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
260
  is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
261
  is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
262
+ if memory_restriction_level is None:
263
+ memory_restriction_level = 2 if is_hf else 0 # 2 assumes run on 24GB consumer GPU
264
+ else:
265
+ assert 0 <= memory_restriction_level <= 3, "Bad memory_restriction_level=%s" % memory_restriction_level
266
  admin_pass = os.getenv("ADMIN_PASS")
267
  # will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
268
  # but becomes unrecoverable sometimes if raise, so just be silent for now
 
294
  # by default don't sample, too chatty
295
  do_sample = False if do_sample is None else do_sample
296
 
297
+ if memory_restriction_level == 2:
298
  if not base_model:
299
  base_model = 'h2oai/h2ogpt-oasst1-512-12b'
300
  # don't set load_8bit if passed base_model, doesn't always work so can't just override
301
  load_8bit = True
302
+ load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
303
  else:
304
  base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model
305
+ if memory_restriction_level >= 2:
306
  load_8bit = True
307
+ load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
308
  if is_hf:
309
  # must override share if in spaces
310
  share = False
311
  save_dir = os.getenv('SAVE_DIR', save_dir)
312
  score_model = os.getenv('SCORE_MODEL', score_model)
313
+ if score_model == 'None' or score_model is None:
314
  score_model = ''
315
  concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
316
  api_open = bool(int(os.getenv('API_OPEN', api_open)))
 
320
  if n_gpus == 0:
321
  gpu_id = None
322
  load_8bit = False
323
+ load_4bit = False
324
  load_half = False
325
  infer_devices = False
326
  torch.backends.cudnn.benchmark = True
 
360
  max_new_tokens, min_new_tokens, early_stopping, max_time,
361
  repetition_penalty, num_return_sequences,
362
  do_sample,
363
+ top_k_docs,
364
+ verbose,
365
  )
366
 
367
  locals_dict = locals()
368
  locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
369
+ if verbose:
370
+ print(f"Generating model with params:\n{locals_print}", flush=True)
371
+ print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
372
 
373
  if langchain_mode != "Disabled":
374
  # SECOND PLACE where LangChain referenced, but all imports are kept local so not required
 
388
  # FIXME: All should be avoided until scans over each db, shouldn't be separate db
389
  continue
390
  persist_directory1 = 'db_dir_%s' % langchain_mode1 # single place, no special names for each case
391
+ db = prep_langchain(persist_directory1,
392
+ load_db_if_exists,
393
+ db_type, use_openai_embedding,
394
  langchain_mode1, user_path,
395
  hf_embedding_model,
396
  kwargs_make_db=locals())
 
404
  assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
405
  assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
406
 
407
+ if cli:
408
+ from cli import run_cli
409
+ return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals()))
410
+ elif not gradio:
411
+ from eval import run_eval
412
+ return run_eval(**get_kwargs(run_eval, exclude_names=['model_state0'], **locals()))
413
+ elif gradio:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  # imported here so don't require gradio to run generate
415
  from gradio_runner import go_gradio
416
 
417
  # get default model
418
  all_kwargs = locals().copy()
419
  if all_kwargs.get('base_model') and not all_kwargs['login_mode_if_model0']:
420
+ model0, tokenizer0, device = get_model(reward_type=False,
421
+ **get_kwargs(get_model, exclude_names=['reward_type'], **all_kwargs))
422
  else:
423
  # if empty model, then don't load anything, just get gradio up
424
  model0, tokenizer0, device = None, None, None
425
  model_state0 = [model0, tokenizer0, device, all_kwargs['base_model']]
426
 
427
  # get score model
428
+ smodel, stokenizer, sdevice = get_score_model(reward_type=True,
429
+ **get_kwargs(get_score_model, exclude_names=['reward_type'],
430
+ **all_kwargs))
431
  score_model_state0 = [smodel, stokenizer, sdevice, score_model]
432
 
433
  if enable_captions:
 
439
  else:
440
  caption_loader = False
441
 
442
+ # assume gradio needs everything
443
  go_gradio(**locals())
444
 
445
 
 
518
  else:
519
  device_map = {'': 'cpu'}
520
  model_kwargs['load_in_8bit'] = False
521
+ model_kwargs['load_in_4bit'] = False
522
  print('device_map: %s' % device_map, flush=True)
523
 
524
  load_in_8bit = model_kwargs.get('load_in_8bit', False)
525
+ load_in_4bit = model_kwargs.get('load_in_4bit', False)
526
  model_kwargs['device_map'] = device_map
527
+ pop_unused_model_kwargs(model_kwargs)
528
 
529
+ if load_in_8bit or load_in_4bit or not load_half:
530
  model = model_loader.from_pretrained(
531
  base_model,
532
  config=config,
 
543
 
544
  def get_model(
545
  load_8bit: bool = False,
546
+ load_4bit: bool = False,
547
  load_half: bool = True,
548
  infer_devices: bool = True,
549
  base_model: str = '',
 
557
  use_auth_token: Union[str, bool] = False,
558
  trust_remote_code: bool = True,
559
  offload_folder: str = None,
560
+ compile_model: bool = True,
561
+
562
+ verbose: bool = False,
563
  ):
564
  """
565
 
566
  :param load_8bit: load model in 8-bit, not supported by all models
567
+ :param load_4bit: load model in 4-bit, not supported by all models
568
  :param load_half: load model in 16-bit
569
  :param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
570
  For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
 
579
  :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
580
  :param trust_remote_code: trust code needed by model
581
  :param offload_folder: offload folder
582
+ :param compile_model: whether to compile torch model
583
+ :param verbose:
584
  :return:
585
  """
586
+ if verbose:
587
+ print("Get %s model" % base_model, flush=True)
588
+ if base_model in non_hf_types:
589
  from gpt4all_llm import get_model_tokenizer_gpt4all
590
  model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
591
  return model, tokenizer, device
592
 
593
  if lora_weights is not None and lora_weights.strip():
594
+ if verbose:
595
+ print("Get %s lora weights" % lora_weights, flush=True)
596
  device = get_device()
597
 
598
  if 'gpt2' in base_model.lower():
599
  # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
600
  load_8bit = False
601
+ load_4bit = False
602
 
603
  assert base_model.strip(), (
604
+ "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
605
  )
606
 
607
  from transformers import AutoConfig
 
612
  llama_type_from_name = "llama" in base_model.lower()
613
  llama_type = llama_type_from_config or llama_type_from_name
614
  if llama_type:
615
+ if verbose:
616
+ print("Detected as llama type from"
617
+ " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
618
 
619
  model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
620
  if not tokenizer_base_model:
 
648
  )
649
  if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
650
  model_kwargs.update(dict(load_in_8bit=load_8bit,
651
+ load_in_4bit=load_4bit,
652
+ device_map={"": 0} if (load_8bit or load_4bit) and device == 'cuda' else "auto",
653
  ))
654
  if 'mpt-' in base_model.lower() and gpu_id >= 0:
655
  model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
 
658
  # FIXME: could put on other GPUs
659
  model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
660
  model_kwargs.pop('torch_dtype', None)
661
+ pop_unused_model_kwargs(model_kwargs)
662
 
663
  if not lora_weights:
664
  with torch.device(device):
 
670
  offload_folder=offload_folder,
671
  )
672
  else:
673
+ if load_half and not (load_8bit or load_4bit):
674
  model = model_loader.from_pretrained(
675
  base_model,
676
  **model_kwargs).half()
 
678
  model = model_loader.from_pretrained(
679
  base_model,
680
  **model_kwargs)
681
+ elif load_8bit or load_4bit:
682
  model = model_loader.from_pretrained(
683
  base_model,
684
  **model_kwargs
 
727
 
728
  if not isinstance(tokenizer, str):
729
  model.eval()
730
+ if torch.__version__ >= "2" and sys.platform != "win32" and compile_model:
731
  model = torch.compile(model)
732
 
733
+ if hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int):
734
+ # help automatically limit inputs to generate
735
+ tokenizer.model_max_length = config.max_position_embeddings
736
+ else:
737
+ tokenizer.model_max_length = 2048
738
+
739
  return model, tokenizer, device
740
 
741
 
742
+ def pop_unused_model_kwargs(model_kwargs):
743
+ """
744
+ in-place pop unused kwargs that are not dependency-upgrade friendly
745
+ no point passing in False, is default, and helps avoid needing to update requirements for new deps
746
+ :param model_kwargs:
747
+ :return:
748
+ """
749
+ check_list = ['load_in_8bit', 'load_in_4bit']
750
+ for k in check_list:
751
+ if k in model_kwargs and not model_kwargs[k]:
752
+ model_kwargs.pop(k)
753
+
754
+
755
+ def get_score_model(score_model: str = None,
756
+ load_8bit: bool = False,
757
+ load_4bit: bool = False,
758
+ load_half: bool = True,
759
+ infer_devices: bool = True,
760
+ base_model: str = '',
761
+ tokenizer_base_model: str = '',
762
+ lora_weights: str = "",
763
+ gpu_id: int = 0,
764
+
765
+ reward_type: bool = None,
766
+ local_files_only: bool = False,
767
+ resume_download: bool = True,
768
+ use_auth_token: Union[str, bool] = False,
769
+ trust_remote_code: bool = True,
770
+ offload_folder: str = None,
771
+ compile_model: bool = True,
772
+
773
+ verbose: bool = False,
774
+ ):
775
+ if score_model is not None and score_model.strip():
776
+ load_8bit = False
777
+ load_4bit = False
778
+ load_half = False
779
+ base_model = score_model.strip()
780
+ tokenizer_base_model = ''
781
+ lora_weights = ''
782
+ llama_type = False
783
+ compile_model = False
784
+ smodel, stokenizer, sdevice = get_model(reward_type=True,
785
+ **get_kwargs(get_model, exclude_names=['reward_type'], **locals()))
786
  else:
787
  smodel, stokenizer, sdevice = None, None, None
788
  return smodel, stokenizer, sdevice
 
808
  'instruction_nochat',
809
  'iinput_nochat',
810
  'langchain_mode',
811
+ 'top_k_docs',
812
  'document_choice',
813
  ]
814
 
 
837
  instruction_nochat,
838
  iinput_nochat,
839
  langchain_mode,
840
+ top_k_docs,
841
  document_choice,
842
  # END NOTE: Examples must have same order of parameters
843
  src_lang=None,
 
847
  save_dir=None,
848
  sanitize_bot_response=True,
849
  model_state0=None,
850
+ memory_restriction_level=None,
851
  raise_generate_gpu_exceptions=None,
852
  chat_context=None,
853
  lora_weights=None,
854
  load_db_if_exists=True,
855
  dbs=None,
856
  user_path=None,
857
+ detect_user_path_changes_every_query=None,
858
  use_openai_embedding=None,
859
  use_openai_model=None,
860
  hf_embedding_model=None,
861
  chunk=None,
862
  chunk_size=None,
863
  db_type=None,
 
864
  n_jobs=None,
865
  first_para=None,
866
  text_limit=None,
867
+ verbose=False,
868
+ cli=False,
869
  ):
870
  # ensure passed these
871
  assert concurrency_count is not None
872
+ assert memory_restriction_level is not None
873
  assert raise_generate_gpu_exceptions is not None
874
  assert chat_context is not None
875
  assert use_openai_embedding is not None
 
878
  assert chunk is not None
879
  assert chunk_size is not None
880
  assert db_type is not None
881
+ assert top_k_docs is not None and isinstance(top_k_docs, int)
882
  assert n_jobs is not None
883
  assert first_para is not None
884
 
 
888
  locals_dict.pop('model_state0', None)
889
  print(locals_dict)
890
 
891
+ no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\nThen start New Conversation"
892
 
893
  if model_state0 is None:
894
  # e.g. for no gradio case, set dummy value, else should be set
 
938
  db1 = dbs[langchain_mode]
939
  else:
940
  db1 = None
941
+ if langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and db1 is not None or base_model in non_hf_types:
942
  query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
943
  outr = ""
944
  # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
 
950
  load_db_if_exists=load_db_if_exists,
951
  db=db1,
952
  user_path=user_path,
953
+ detect_user_path_changes_every_query=detect_user_path_changes_every_query,
954
  max_new_tokens=max_new_tokens,
955
  cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
956
  use_openai_embedding=use_openai_embedding,
 
963
  langchain_mode=langchain_mode,
964
  document_choice=document_choice,
965
  db_type=db_type,
966
+ k=top_k_docs,
967
  temperature=temperature,
968
  repetition_penalty=repetition_penalty,
969
  top_k=top_k,
970
  top_p=top_p,
971
  prompt_type=prompt_type,
972
  n_jobs=n_jobs,
973
+ verbose=verbose,
974
+ cli=cli,
975
  ):
976
+ outr, extra = r # doesn't accumulate, new answer every yield, so only save that full answer
977
+ yield dict(response=outr, sources=extra)
978
  if save_dir:
979
  save_generate_output(output=outr, base_model=base_model, save_dir=save_dir)
980
+ if verbose:
981
+ print(
982
+ 'Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
983
+ flush=True)
984
+ if outr or base_model in non_hf_types:
985
+ # if got no response (e.g. not showing sources and got no sources,
986
+ # so nothing to give to LLM), then slip through and ask LLM
987
+ # Or if llama/gptj, then just return since they had no response and can't go down below code path
988
  return
989
 
990
  if isinstance(tokenizer, str):
 
994
  else:
995
  raise RuntimeError("No such task type %s" % tokenizer)
996
  # NOTE: uses max_length only
997
+ yield dict(response=model(prompt, max_length=max_new_tokens)[0][key], sources='')
998
 
999
  if 'mbart-' in base_model.lower():
1000
  assert src_lang is not None
 
1004
  # override, ignore user change
1005
  num_return_sequences = 1
1006
  stopping_criteria = get_stopping(prompt_type, tokenizer, device)
1007
+ _, _, max_length_tokenize, max_prompt_length = get_cutoffs(memory_restriction_level, model_max_length=tokenizer.model_max_length)
1008
  prompt = prompt[-max_prompt_length:]
1009
  inputs = tokenizer(prompt,
1010
  return_tensors="pt",
 
1015
  if debug and len(inputs["input_ids"]) > 0:
1016
  print('input_ids length', len(inputs["input_ids"][0]), flush=True)
1017
  input_ids = inputs["input_ids"].to(device)
1018
+ # CRITICAL LIMIT else will fail
1019
+ max_max_tokens = tokenizer.model_max_length
1020
+ max_input_tokens = max_max_tokens - max_new_tokens
1021
+ input_ids = input_ids[:, -max_input_tokens:]
1022
  generation_config = GenerationConfig(
1023
  temperature=float(temperature),
1024
  top_p=float(top_p),
 
1071
  # https://github.com/h2oai/h2ogpt/issues/104
1072
  # but only makes sense if concurrency_count == 1
1073
  context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
1074
+ if verbose:
1075
+ print('Pre-Generate: %s' % str(datetime.now()), flush=True)
1076
  decoded_output = None
1077
  with context_class("generate.lock"):
1078
+ if verbose:
1079
+ print('Generate: %s' % str(datetime.now()), flush=True)
1080
  # decoded tokenized prompt can deviate from prompt due to special characters
1081
  inputs_decoded = decoder(input_ids[0])
1082
  inputs_decoded_raw = decoder_raw(input_ids[0])
 
1098
  decoder = decoder_raw
1099
  decoder_kwargs = decoder_raw_kwargs
1100
  else:
1101
+ if verbose:
1102
+ print("WARNING: Special characters in prompt", flush=True)
1103
  if stream_output:
1104
  skip_prompt = False
1105
  streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False,
 
1118
  if bucket.qsize() > 0 or thread.exc:
1119
  thread.join()
1120
  outputs += new_text
1121
+ yield dict(response=prompter.get_response(outputs, prompt=inputs_decoded,
1122
+ sanitize_bot_response=sanitize_bot_response),
1123
+ sources='')
1124
  except BaseException:
1125
  # if any exception, raise that exception if was from thread, first
1126
  if thread.exc:
 
1137
  else:
1138
  outputs = model.generate(**gen_kwargs)
1139
  outputs = [decoder(s) for s in outputs.sequences]
1140
+ yield dict(response=prompter.get_response(outputs, prompt=inputs_decoded,
1141
+ sanitize_bot_response=sanitize_bot_response), sources='')
1142
  if outputs and len(outputs) >= 1:
1143
  decoded_output = prompt + outputs[0]
1144
  if save_dir and decoded_output:
1145
  save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
1146
+ if verbose:
1147
+ print('Post-Generate: %s decoded_output: %s' % (
1148
+ str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
1149
 
1150
 
1151
  inputs_list_names = list(inspect.signature(evaluate).parameters)
 
1153
  inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
1154
 
1155
 
1156
+ def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=2048):
1157
  # help to avoid errors like:
1158
  # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
1159
  # RuntimeError: expected scalar type Half but found Float
1160
  # with - 256
1161
+ if memory_restriction_level > 0:
1162
+ max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
1163
+ else:
1164
+ max_length_tokenize = model_max_length - 256
1165
  cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
1166
  output_smallest = 30 * 4
1167
  max_prompt_length = cutoff_len - output_smallest
 
1254
  prompt_type, temperature, top_p, top_k, num_beams,
1255
  max_new_tokens, min_new_tokens, early_stopping, max_time,
1256
  repetition_penalty, num_return_sequences,
1257
+ do_sample, k, verbose):
1258
  use_defaults = False
1259
  use_default_examples = True
1260
  examples = []
 
1271
 
1272
  if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
1273
  prompt_type = inv_prompt_type_to_model_lower[model_lower]
1274
+ if verbose:
1275
+ print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True)
1276
 
1277
  # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
1278
  if show_examples is None:
 
1335
  prompt_type = prompt_type or 'plain'
1336
  else:
1337
  prompt_type = ''
 
 
 
1338
  task_info = "No task"
1339
  if prompt_type == 'instruct':
1340
  task_info = "Answer question or follow imperative as instruction with optionally input."
 
1409
 
1410
  # fit random forest classifier with 20 estimators""", ''] + params_list,
1411
  ]
1412
+ # add summary example
1413
+ examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else ''] + params_list]
1414
 
1415
  src_lang = "English"
1416
  tgt_lang = "Russian"
1417
 
1418
  # move to correct position
1419
  for example in examples:
1420
+ example += [chat, '', '', 'Disabled', k, ['All']]
1421
  # adjust examples if non-chat mode
1422
  if not chat:
1423
  example[eval_func_param_names.index('instruction_nochat')] = example[
 
1489
  return score
1490
 
1491
 
1492
+ def check_locals(**kwargs):
1493
+ # ensure everything in evaluate is here
1494
+ can_skip_because_locally_generated = [ # evaluate
1495
+ 'instruction',
1496
+ 'iinput',
1497
+ 'context',
1498
+ 'instruction_nochat',
1499
+ 'iinput_nochat',
1500
+ # get_model:
1501
+ 'reward_type'
1502
+ ]
1503
+ for k in eval_func_param_names:
1504
+ if k in can_skip_because_locally_generated:
1505
+ continue
1506
+ assert k in kwargs, "Missing %s" % k
1507
+ for k in inputs_kwargs_list:
1508
+ if k in can_skip_because_locally_generated:
1509
+ continue
1510
+ assert k in kwargs, "Missing %s" % k
1511
+
1512
+ for k in list(inspect.signature(get_model).parameters):
1513
+ if k in can_skip_because_locally_generated:
1514
+ continue
1515
+ assert k in kwargs, "Missing %s" % k
1516
+
1517
+
1518
  if __name__ == "__main__":
1519
  """
1520
  Examples:
gpt4all_llm.py CHANGED
@@ -1,5 +1,6 @@
1
  import inspect
2
  import os
 
3
  from typing import Dict, Any, Optional, List
4
  from langchain.callbacks.manager import CallbackManagerForLLMRun
5
  from pydantic import root_validator
@@ -21,11 +22,11 @@ class FakeTokenizer:
21
 
22
  def get_model_tokenizer_gpt4all(base_model, **kwargs):
23
  # defaults (some of these are generation parameters, so need to be passed in at generation time)
24
- model_kwargs = dict(n_ctx=kwargs.get('max_new_tokens', 256),
25
- n_threads=os.cpu_count() // 2,
26
  temp=kwargs.get('temperature', 0.2),
27
  top_p=kwargs.get('top_p', 0.75),
28
- top_k=kwargs.get('top_k', 40))
 
29
  env_gpt4all_file = ".env_gpt4all"
30
  model_kwargs.update(dotenv_values(env_gpt4all_file))
31
 
@@ -33,43 +34,103 @@ def get_model_tokenizer_gpt4all(base_model, **kwargs):
33
  if 'model_path_llama' not in model_kwargs:
34
  raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
35
  model_path = model_kwargs.pop('model_path_llama')
 
 
 
 
 
 
 
 
 
 
 
 
36
  from gpt4all import GPT4All as GPT4AllModel
37
- elif base_model == "gptj":
38
- if 'model_path_gptj' not in model_kwargs:
39
- raise ValueError("No model_path_gptj in %s" % env_gpt4all_file)
40
- model_path = model_kwargs.pop('model_path_gptj')
 
 
41
  from gpt4all import GPT4All as GPT4AllModel
 
42
  else:
43
  raise ValueError("No such base_model %s" % base_model)
44
- func_names = list(inspect.signature(GPT4AllModel).parameters)
45
- model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
46
- model = GPT4AllModel(model_path, **model_kwargs)
47
  return model, FakeTokenizer(), 'cpu'
48
 
49
 
50
- def get_llm_gpt4all(model_name, model=None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  max_new_tokens=256,
52
  temperature=0.1,
53
  repetition_penalty=1.0,
54
  top_k=40,
55
- top_p=0.7):
 
56
  env_gpt4all_file = ".env_gpt4all"
57
- model_kwargs = dotenv_values(env_gpt4all_file)
58
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
59
- callbacks = [StreamingStdOutCallbackHandler()]
60
- n_ctx = model_kwargs.pop('n_ctx', 1024)
61
- default_params = {'context_erase': 0.5, 'n_batch': 1, 'n_ctx': n_ctx, 'n_predict': max_new_tokens,
62
- 'repeat_last_n': 64 if repetition_penalty != 1.0 else 0, 'repeat_penalty': repetition_penalty,
63
- 'temp': temperature, 'top_k': top_k, 'top_p': top_p}
 
 
 
 
 
 
 
 
64
  if model_name == 'llama':
65
- from langchain.llms import LlamaCpp
66
- model_path = model_kwargs.pop('model_path_llama') if model is None else model
67
- llm = LlamaCpp(model_path=model_path, n_ctx=n_ctx, callbacks=callbacks, verbose=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  else:
69
- model_path = model_kwargs.pop('model_path_gptj') if model is None else model
70
- llm = H2OGPT4All(model=model_path, backend='gptj', callbacks=callbacks,
71
- verbose=False, **default_params,
72
- )
73
  return llm
74
 
75
 
@@ -117,3 +178,78 @@ class H2OGPT4All(gpt4all.GPT4All):
117
  if verbose:
118
  print("_call prompt: %s" % prompt, flush=True)
119
  return super()._call(prompt, stop=stop, run_manager=run_manager)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import inspect
2
  import os
3
+ import sys
4
  from typing import Dict, Any, Optional, List
5
  from langchain.callbacks.manager import CallbackManagerForLLMRun
6
  from pydantic import root_validator
 
22
 
23
  def get_model_tokenizer_gpt4all(base_model, **kwargs):
24
  # defaults (some of these are generation parameters, so need to be passed in at generation time)
25
+ model_kwargs = dict(n_threads=os.cpu_count() // 2,
 
26
  temp=kwargs.get('temperature', 0.2),
27
  top_p=kwargs.get('top_p', 0.75),
28
+ top_k=kwargs.get('top_k', 40),
29
+ n_ctx=2048 - 256)
30
  env_gpt4all_file = ".env_gpt4all"
31
  model_kwargs.update(dotenv_values(env_gpt4all_file))
32
 
 
34
  if 'model_path_llama' not in model_kwargs:
35
  raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
36
  model_path = model_kwargs.pop('model_path_llama')
37
+ # FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
38
+ from llama_cpp import Llama
39
+ # llama sets some things at init model time, not generation time
40
+ func_names = list(inspect.signature(Llama.__init__).parameters)
41
+ model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
42
+ model_kwargs['n_ctx'] = int(model_kwargs['n_ctx'])
43
+ model = Llama(model_path=model_path, **model_kwargs)
44
+ elif base_model in "gpt4all_llama":
45
+ if 'model_name_gpt4all_llama' not in model_kwargs and 'model_path_gpt4all_llama' not in model_kwargs:
46
+ raise ValueError("No model_name_gpt4all_llama or model_path_gpt4all_llama in %s" % env_gpt4all_file)
47
+ model_name = model_kwargs.pop('model_name_gpt4all_llama')
48
+ model_type = 'llama'
49
  from gpt4all import GPT4All as GPT4AllModel
50
+ model = GPT4AllModel(model_name=model_name, model_type=model_type)
51
+ elif base_model in "gptj":
52
+ if 'model_name_gptj' not in model_kwargs and 'model_path_gptj' not in model_kwargs:
53
+ raise ValueError("No model_name_gpt4j or model_path_gpt4j in %s" % env_gpt4all_file)
54
+ model_name = model_kwargs.pop('model_name_gptj')
55
+ model_type = 'gptj'
56
  from gpt4all import GPT4All as GPT4AllModel
57
+ model = GPT4AllModel(model_name=model_name, model_type=model_type)
58
  else:
59
  raise ValueError("No such base_model %s" % base_model)
 
 
 
60
  return model, FakeTokenizer(), 'cpu'
61
 
62
 
63
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
64
+
65
+
66
+ class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
67
+
68
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
69
+ """Run on new LLM token. Only available when streaming is enabled."""
70
+ # streaming to std already occurs without this
71
+ # sys.stdout.write(token)
72
+ # sys.stdout.flush()
73
+ pass
74
+
75
+
76
+ def get_model_kwargs(env_kwargs, default_kwargs, cls):
77
+ # default from class
78
+ model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items()}
79
+ # from our defaults
80
+ model_kwargs.update(default_kwargs)
81
+ # from user defaults
82
+ model_kwargs.update(env_kwargs)
83
+ # ensure only valid keys
84
+ func_names = list(inspect.signature(cls).parameters)
85
+ model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
86
+ return model_kwargs
87
+
88
+
89
+ def get_llm_gpt4all(model_name,
90
+ model=None,
91
  max_new_tokens=256,
92
  temperature=0.1,
93
  repetition_penalty=1.0,
94
  top_k=40,
95
+ top_p=0.7,
96
+ verbose=False):
97
  env_gpt4all_file = ".env_gpt4all"
98
+ env_kwargs = dotenv_values(env_gpt4all_file)
99
+ callbacks = [H2OStreamingStdOutCallbackHandler()]
100
+ n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
101
+ default_kwargs = dict(context_erase=0.5,
102
+ n_batch=1,
103
+ n_ctx=n_ctx,
104
+ n_predict=max_new_tokens,
105
+ repeat_last_n=64 if repetition_penalty != 1.0 else 0,
106
+ repeat_penalty=repetition_penalty,
107
+ temp=temperature,
108
+ temperature=temperature,
109
+ top_k=top_k,
110
+ top_p=top_p,
111
+ use_mlock=True,
112
+ verbose=verbose)
113
  if model_name == 'llama':
114
+ cls = H2OLlamaCpp
115
+ model_path = env_kwargs.pop('model_path_llama') if model is None else model
116
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
117
+ model_kwargs.update(dict(model_path=model_path, callbacks=callbacks))
118
+ llm = cls(**model_kwargs)
119
+ llm.client.verbose = verbose
120
+ elif model_name == 'gpt4all_llama':
121
+ cls = H2OGPT4All
122
+ model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
123
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
124
+ model_kwargs.update(dict(model=model_path, backend='llama', callbacks=callbacks))
125
+ llm = cls(**model_kwargs)
126
+ elif model_name == 'gptj':
127
+ cls = H2OGPT4All
128
+ model_path = env_kwargs.pop('model_path_gptj') if model is None else model
129
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
130
+ model_kwargs.update(dict(model=model_path, backend='gptj', callbacks=callbacks))
131
+ llm = cls(**model_kwargs)
132
  else:
133
+ raise RuntimeError("No such model_name %s" % model_name)
 
 
 
134
  return llm
135
 
136
 
 
178
  if verbose:
179
  print("_call prompt: %s" % prompt, flush=True)
180
  return super()._call(prompt, stop=stop, run_manager=run_manager)
181
+
182
+
183
+ from langchain.llms import LlamaCpp
184
+
185
+
186
+ class H2OLlamaCpp(LlamaCpp):
187
+ model_path: Any
188
+ """Path to the pre-trained GPT4All model file."""
189
+
190
+ @root_validator()
191
+ def validate_environment(cls, values: Dict) -> Dict:
192
+ """Validate that llama-cpp-python library is installed."""
193
+ if isinstance(values["model_path"], str):
194
+ model_path = values["model_path"]
195
+ model_param_names = [
196
+ "lora_path",
197
+ "lora_base",
198
+ "n_ctx",
199
+ "n_parts",
200
+ "seed",
201
+ "f16_kv",
202
+ "logits_all",
203
+ "vocab_only",
204
+ "use_mlock",
205
+ "n_threads",
206
+ "n_batch",
207
+ "use_mmap",
208
+ "last_n_tokens_size",
209
+ ]
210
+ model_params = {k: values[k] for k in model_param_names}
211
+ # For backwards compatibility, only include if non-null.
212
+ if values["n_gpu_layers"] is not None:
213
+ model_params["n_gpu_layers"] = values["n_gpu_layers"]
214
+
215
+ try:
216
+ from llama_cpp import Llama
217
+
218
+ values["client"] = Llama(model_path, **model_params)
219
+ except ImportError:
220
+ raise ModuleNotFoundError(
221
+ "Could not import llama-cpp-python library. "
222
+ "Please install the llama-cpp-python library to "
223
+ "use this embedding model: pip install llama-cpp-python"
224
+ )
225
+ except Exception as e:
226
+ raise ValueError(
227
+ f"Could not load Llama model from path: {model_path}. "
228
+ f"Received error {e}"
229
+ )
230
+ else:
231
+ values["client"] = values["model_path"]
232
+ return values
233
+
234
+ def _call(
235
+ self,
236
+ prompt: str,
237
+ stop: Optional[List[str]] = None,
238
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
239
+ ) -> str:
240
+ verbose = False
241
+ # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
242
+ prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
243
+ num_prompt_tokens = len(prompt_tokens)
244
+ if num_prompt_tokens > self.n_ctx:
245
+ # conservative by using int()
246
+ chars_per_token = int(len(prompt) / num_prompt_tokens)
247
+ prompt = prompt[-self.n_ctx * chars_per_token:]
248
+ if verbose:
249
+ print("reducing tokens, assuming average of %s chars/token: %s" % chars_per_token, flush=True)
250
+ prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
251
+ num_prompt_tokens2 = len(prompt_tokens2)
252
+ print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
253
+ if verbose:
254
+ print("_call prompt: %s" % prompt, flush=True)
255
+ return super()._call(prompt, stop=stop, run_manager=run_manager)
gpt_langchain.py CHANGED
@@ -3,6 +3,7 @@ import inspect
3
  import os
4
  import pathlib
5
  import pickle
 
6
  import shutil
7
  import subprocess
8
  import sys
@@ -16,9 +17,11 @@ from functools import reduce
16
  from operator import concat
17
 
18
  from joblib import Parallel, delayed
 
19
 
 
20
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
21
- get_device
22
 
23
  import_matplotlib()
24
 
@@ -35,7 +38,6 @@ from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, Pytho
35
  EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
36
  UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
37
  from langchain.text_splitter import RecursiveCharacterTextSplitter
38
- from langchain.vectorstores import FAISS
39
  from langchain.chains.question_answering import load_qa_chain
40
  from langchain.docstore.document import Document
41
  from langchain import PromptTemplate
@@ -43,17 +45,36 @@ from langchain.vectorstores import Chroma
43
 
44
 
45
  def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directory="db_dir", langchain_mode='notset',
 
46
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
47
  if not sources:
48
  return None
49
  # get embedding model
50
  embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
 
 
 
51
 
52
  # Create vector database
53
  if db_type == 'faiss':
 
54
  db = FAISS.from_documents(sources, embedding)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  elif db_type == 'chroma':
56
- collection_name = langchain_mode.replace(' ', '_')
57
  os.makedirs(persist_directory, exist_ok=True)
58
  db = Chroma.from_documents(documents=sources,
59
  embedding=embedding,
@@ -61,34 +82,121 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo
61
  collection_name=collection_name,
62
  anonymized_telemetry=False)
63
  db.persist()
64
- # FIXME: below just proves can load persistent dir, regenerates its embedding files, so a bit wasteful
65
- if False:
66
- db = Chroma(embedding_function=embedding,
67
- persist_directory=persist_directory,
68
- collection_name=collection_name)
69
  else:
70
  raise RuntimeError("No such db_type=%s" % db_type)
71
 
72
  return db
73
 
74
 
75
- def add_to_db(db, sources, db_type='faiss', avoid_dup=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  if not sources:
77
- return db
78
  if db_type == 'faiss':
79
  db.add_documents(sources)
 
 
 
 
 
 
 
 
 
80
  elif db_type == 'chroma':
81
- if avoid_dup:
82
- collection = db.get()
83
- metadata_sources = set([x['source'] for x in collection['metadatas']])
84
- sources = [x for x in sources if x.metadata['source'] not in metadata_sources]
85
- if len(sources) == 0:
86
- return db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  db.add_documents(documents=sources)
88
  db.persist()
89
  else:
90
  raise RuntimeError("No such db_type=%s" % db_type)
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  return db
93
 
94
 
@@ -126,19 +234,23 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
126
  top_k=40,
127
  top_p=0.7,
128
  prompt_type=None,
 
 
129
  ):
130
  if use_openai_model:
131
  from langchain.llms import OpenAI
132
  llm = OpenAI(temperature=0)
133
  model_name = 'openai'
134
  streamer = None
135
- elif model_name in ['gptj', 'llama']:
 
136
  from gpt4all_llm import get_llm_gpt4all
137
  llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
138
  temperature=temperature,
139
  repetition_penalty=repetition_penalty,
140
  top_k=top_k,
141
  top_p=top_p,
 
142
  )
143
  streamer = None
144
  prompt_type = 'plain'
@@ -149,6 +261,7 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
149
  # only used if didn't pass model in
150
  assert model_name is None
151
  assert tokenizer is None
 
152
  model_name = 'h2oai/h2ogpt-oasst1-512-12b'
153
  # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
154
  # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
@@ -165,7 +278,12 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
165
  torch_dtype=torch_dtype,
166
  load_in_8bit=load_8bit)
167
 
168
- gen_kwargs = dict(max_new_tokens=max_new_tokens, return_full_text=True, early_stopping=False)
 
 
 
 
 
169
  if stream_output:
170
  skip_prompt = False
171
  from generate import H2OTextIteratorStreamer
@@ -175,17 +293,19 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
175
  else:
176
  streamer = None
177
 
178
- if 'h2ogpt' in model_name or prompt_type == 'human_bot':
179
- from h2oai_pipeline import H2OTextGenerationPipeline
180
- pipe = H2OTextGenerationPipeline(model=model, tokenizer=tokenizer, **gen_kwargs)
181
- # pipe.task = "text-generation"
182
- # below makes it listen only to our prompt removal, not built in prompt removal that is less general and not specific for our model
183
- pipe.task = "text2text-generation"
184
- prompt_type = 'human_bot'
185
- else:
186
- # only for non-instruct tuned cases when ok with just normal next token prediction
187
- from transformers import pipeline
188
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **gen_kwargs)
 
 
189
 
190
  from langchain.llms import HuggingFacePipeline
191
  llm = HuggingFacePipeline(pipeline=pipe)
@@ -341,6 +461,12 @@ try:
341
  except (pkg_resources.DistributionNotFound, AssertionError):
342
  have_arxiv = False
343
 
 
 
 
 
 
 
344
  image_types = ["png", "jpg", "jpeg"]
345
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
346
  "md", "html",
@@ -357,9 +483,10 @@ file_types = non_image_types + image_types
357
 
358
  def add_meta(docs1, file):
359
  file_extension = pathlib.Path(file).suffix
 
360
  if not isinstance(docs1, list):
361
  docs1 = [docs1]
362
- [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now))) for x in docs1]
363
 
364
 
365
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, chunk=True, chunk_size=512,
@@ -409,42 +536,45 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
409
  f.write(file)
410
  metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
411
  doc1 = Document(page_content=file, metadata=metadata)
412
- elif file.endswith('.html') or file.endswith('.mhtml'):
413
  docs1 = UnstructuredHTMLLoader(file_path=file).load()
414
  add_meta(docs1, file)
415
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
416
- elif (file.endswith('.docx') or file.endswith('.doc')) and have_libreoffice:
417
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
418
  add_meta(docs1, file)
419
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
420
- elif file.endswith('.odt'):
421
  docs1 = UnstructuredODTLoader(file_path=file).load()
422
  add_meta(docs1, file)
423
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
424
- elif file.endswith('pptx') or file.endswith('ppt'):
425
  docs1 = UnstructuredPowerPointLoader(file_path=file).load()
426
  add_meta(docs1, file)
427
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
428
- elif file.endswith('.txt'):
429
  # use UnstructuredFileLoader ?
430
- doc1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
 
 
431
  add_meta(doc1, file)
432
- elif file.endswith('.rtf'):
433
  docs1 = UnstructuredRTFLoader(file).load()
434
  add_meta(docs1, file)
435
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
436
- elif file.endswith('.md'):
437
  docs1 = UnstructuredMarkdownLoader(file).load()
438
  add_meta(docs1, file)
439
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
440
- elif file.endswith('.enex'):
441
- doc1 = EverNoteLoader(file).load()
442
  add_meta(doc1, file)
443
- elif file.endswith('.epub'):
 
444
  docs1 = UnstructuredEPubLoader(file).load()
445
  add_meta(docs1, file)
446
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
447
- elif file.endswith('.jpeg') or file.endswith('.jpg') or file.endswith('.png'):
448
  docs1 = []
449
  if have_tesseract and enable_ocr:
450
  # OCR, somewhat works, but not great
@@ -471,13 +601,14 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
471
  docs1.extend(docs1c)
472
  for doci in docs1:
473
  doci.metadata['source'] = doci.metadata['image_path']
 
474
  if docs1:
475
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
476
- elif file.endswith('.msg'):
477
  raise RuntimeError("Not supported, GPL3 license")
478
  # docs1 = OutlookMessageLoader(file).load()
479
  # docs1[0].metadata['source'] = file
480
- elif file.endswith('.eml'):
481
  try:
482
  docs1 = UnstructuredEmailLoader(file).load()
483
  add_meta(docs1, file)
@@ -491,34 +622,43 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
491
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
492
  else:
493
  raise
494
- # elif file.endswith('.gcsdir'):
495
  # doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
496
- # elif file.endswith('.gcsfile'):
497
  # doc1 = GCSFileLoader(project_name, bucket, blob).load()
498
- elif file.endswith('.rst'):
499
  with open(file, "r") as f:
500
  doc1 = Document(page_content=f.read(), metadata={"source": file})
501
  add_meta(doc1, file)
502
- elif file.endswith('.pdf'):
 
 
 
 
 
 
 
 
 
 
 
503
  # Some PDFs return nothing or junk from PDFMinerLoader
504
- # e.g. Beyond fine-tuning_ Classifying high resolution mammograms using function-preserving transformations _ Elsevier Enhanced Reader.pdf
505
- doc1 = PyPDFLoader(file).load_and_split()
506
  add_meta(doc1, file)
507
- elif file.endswith('.csv'):
508
  doc1 = CSVLoader(file).load()
509
  add_meta(doc1, file)
510
- elif file.endswith('.py'):
511
  doc1 = PythonLoader(file).load()
512
  add_meta(doc1, file)
513
- elif file.endswith('.toml'):
514
  doc1 = TomlLoader(file).load()
515
  add_meta(doc1, file)
516
- elif file.endswith('.urls'):
517
  with open(file, "r") as f:
518
  docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
519
  add_meta(docs1, file)
520
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
521
- elif file.endswith('.zip'):
522
  with zipfile.ZipFile(file, 'r') as zip_ref:
523
  # don't put into temporary path, since want to keep references to docs inside zip
524
  # so just extract in path where
@@ -529,11 +669,17 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
529
  raise RuntimeError("No file handler for %s" % os.path.basename(file))
530
 
531
  # allow doc1 to be list or not. If not list, did not chunk yet, so chunk now
 
532
  if not isinstance(doc1, list):
533
  if chunk:
534
  docs = chunk_sources([doc1], chunk_size=chunk_size)
535
  else:
536
  docs = [doc1]
 
 
 
 
 
537
  else:
538
  docs = doc1
539
 
@@ -590,6 +736,8 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
590
  captions_model=None,
591
  caption_loader=None,
592
  enable_ocr=False,
 
 
593
  ):
594
  globs_image_types = []
595
  globs_non_image_types = []
@@ -617,6 +765,28 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
617
  # But instead, allow fail so can collect unsupported too
618
  set_globs_image_types = set(globs_image_types)
619
  globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
620
  # could use generator, but messes up metadata handling in recursive case
621
  if caption_loader and not isinstance(caption_loader, (bool, str)) and \
622
  caption_loader.device != 'cpu' or \
@@ -643,21 +813,21 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
643
  if n_jobs != 1 and len(globs_non_image_types) > 1:
644
  # avoid nesting, e.g. upload 1 zip and then inside many files
645
  # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
646
- documents = Parallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
647
  delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types
648
  )
649
  else:
650
- documents = [path_to_doc1(file, **kwargs) for file in globs_non_image_types]
651
 
652
  # do images separately since can't fork after cuda in parent, so can't be parallel
653
  if n_jobs_image != 1 and len(globs_image_types) > 1:
654
  # avoid nesting, e.g. upload 1 zip and then inside many files
655
  # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
656
- image_documents = Parallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
657
  delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types
658
  )
659
  else:
660
- image_documents = [path_to_doc1(file, **kwargs) for file in globs_image_types]
661
 
662
  # add image docs in
663
  documents += image_documents
@@ -676,7 +846,9 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
676
  return documents
677
 
678
 
679
- def prep_langchain(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode, user_path,
 
 
680
  hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
681
  """
682
  do prep first time, involving downloads
@@ -685,12 +857,18 @@ def prep_langchain(persist_directory, load_db_if_exists, db_type, use_openai_emb
685
  """
686
  assert langchain_mode not in ['MyData'], "Should not prep scratch data"
687
 
688
- if os.path.isdir(persist_directory):
 
 
689
  print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
690
  db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
691
  hf_embedding_model)
692
  else:
693
- print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
 
 
 
 
694
  db = None
695
  if langchain_mode in ['All', 'DriverlessAI docs']:
696
  # FIXME: Could also just use dai_docs.pickle directly and upload that
@@ -701,19 +879,52 @@ def prep_langchain(persist_directory, load_db_if_exists, db_type, use_openai_emb
701
 
702
  langchain_kwargs = kwargs_make_db.copy()
703
  langchain_kwargs.update(locals())
704
- db = make_db(**langchain_kwargs)
705
 
706
  return db
707
 
708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
709
  def get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
710
  hf_embedding_model):
711
  if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
712
  os.path.join(persist_directory, 'index')):
713
  print("DO Loading db: %s" % langchain_mode, flush=True)
714
  embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
 
 
 
 
715
  db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
716
- collection_name=langchain_mode.replace(' ', '_'))
 
717
  print("DONE Loading db: %s" % langchain_mode, flush=True)
718
  return db
719
  return None
@@ -740,21 +951,40 @@ def _make_db(use_openai_embedding=False,
740
  langchain_mode=None,
741
  user_path=None,
742
  db_type='faiss',
743
- load_db_if_exists=False,
744
  db=None,
745
- n_jobs=-1):
 
746
  persist_directory = 'db_dir_%s' % langchain_mode # single place, no special names for each case
747
  if not db and load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
748
  os.path.join(persist_directory, 'index')):
749
  assert langchain_mode not in ['MyData'], "Should not load MyData db this way"
750
- print("Loading db", flush=True)
751
  embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
 
 
 
 
752
  db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
753
- collection_name=langchain_mode.replace(' ', '_'))
754
- elif not db:
 
 
 
 
 
755
  assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
756
- sources = []
757
- print("Generating sources", flush=True)
 
 
 
 
 
 
 
 
 
758
  if langchain_mode in ['wiki_full', 'All', "'All'"]:
759
  from read_wiki_full import get_all_documents
760
  small_test = None
@@ -783,9 +1013,25 @@ def _make_db(use_openai_embedding=False,
783
  sources.extend(sources1)
784
  if langchain_mode in ['All', 'UserData']:
785
  if user_path:
 
 
 
 
 
 
 
 
 
786
  # chunk internally for speed over multiple docs
787
- sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size)
 
 
 
 
 
 
788
  sources.extend(sources1)
 
789
  else:
790
  print("Chose UserData but user_path is empty/None", flush=True)
791
  if False and langchain_mode in ['urls', 'All', "'All'"]:
@@ -797,14 +1043,48 @@ def _make_db(use_openai_embedding=False,
797
  sources1 = loader.load()
798
  sources.extend(sources1)
799
  if not sources:
800
- print("langchain_mode %s has no sources, not making db" % langchain_mode, flush=True)
801
- return None
802
- print("Generating db", flush=True)
803
- db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
804
- persist_directory=persist_directory, langchain_mode=langchain_mode,
805
- hf_embedding_model=hf_embedding_model)
806
- print("Generated db", flush=True)
807
- return db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
808
 
809
 
810
  source_prefix = "Sources [Score | Link]:"
@@ -828,6 +1108,7 @@ def _run_qa_db(query=None,
828
  use_openai_model=False, use_openai_embedding=False,
829
  first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
830
  user_path=None,
 
831
  db_type='faiss',
832
  model_name=None, model=None, tokenizer=None,
833
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
@@ -847,7 +1128,9 @@ def _run_qa_db(query=None,
847
  top_p=0.7,
848
  langchain_mode=None,
849
  document_choice=['All'],
850
- n_jobs=-1):
 
 
851
  """
852
 
853
  :param query:
@@ -859,17 +1142,19 @@ def _run_qa_db(query=None,
859
  :param chunk:
860
  :param chunk_size:
861
  :param user_path: user path to glob recursively from
862
- :param db_type: 'faiss' for in-memory db or 'chroma' for persistent db
863
  :param model_name: model name, used to switch behaviors
864
  :param model: pre-initialized model, else will make new one
865
  :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
866
  :param answer_with_sources
867
  :return:
868
  """
869
-
870
- # FIXME: For All just go over all dbs instead of a separate db for All
871
- db = make_db(**locals())
872
- prompt_type = prompter.prompt_type if prompter is not None else prompt_type
 
 
873
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
874
  model=model, tokenizer=tokenizer,
875
  stream_output=stream_output,
@@ -879,74 +1164,173 @@ def _run_qa_db(query=None,
879
  top_k=top_k,
880
  top_p=top_p,
881
  prompt_type=prompt_type,
 
 
882
  )
883
 
884
- if model_name in ['llama', 'gptj']:
885
  # FIXME: for now, streams to stdout/stderr currently
886
  stream_output = False
887
 
888
- if not use_openai_model and prompt_type not in ['plain'] or model_name in ['llama', 'gptj']:
889
- # instruct-like, rather than few-shot prompt_type='plain' as default
890
- # but then sources confuse the model with how inserted among rest of text, so avoid
891
- prefix = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
892
  if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
893
  use_context = False
894
- template = """%s{context}{question}""" % prefix
895
  else:
896
  use_context = True
897
- template = """%s
898
- ==
899
- {context}
900
- ==
901
- {question}""" % prefix
902
- prompt = PromptTemplate(
903
- # input_variables=["summaries", "question"],
904
- input_variables=["context", "question"],
905
- template=template,
906
- )
907
- chain = load_qa_chain(llm, prompt=prompt)
908
  else:
909
- chain = load_qa_with_sources_chain(llm)
910
  use_context = True
911
 
912
- if query is None:
913
- query = "What are the main differences between Linux and Windows?"
914
  # https://github.com/hwchase17/langchain/issues/1946
915
  # FIXME: Seems to way to get size of chroma db to limit k to avoid
916
  # Chroma collection MyData contains fewer than 4 elements.
917
  # type logger error
918
  k_db = 1000 if db_type == 'chroma' else k # k=100 works ok too for
919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
920
  if db and use_context:
921
  if isinstance(document_choice, str):
922
  # support string as well
923
  document_choice = [document_choice]
924
- if not isinstance(db, Chroma) or len(document_choice) <= 1 and document_choice[0].lower() == 'all':
 
 
925
  # treat empty list as All for now, not 'None'
926
  filter_kwargs = {}
 
 
 
927
  else:
928
  if len(document_choice) >= 2:
929
  or_filter = [{"source": {"$eq": x}} for x in document_choice]
930
  filter_kwargs = dict(filter={"$or": or_filter})
931
- else:
932
  one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
933
  filter_kwargs = dict(filter=one_filter)
934
- if len(document_choice) == 1 and document_choice[0].lower() == 'none':
 
 
935
  k_db = 1
936
  k = 0
937
  docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:k]
938
  # cut off so no high distance docs/sources considered
939
  docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
940
  scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
941
- if len(scores) > 0:
942
  print("Distance: min: %s max: %s mean: %s median: %s" %
943
  (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
944
  else:
945
  docs = []
946
  scores = []
947
 
948
- if not docs and use_context:
949
- return None
 
 
 
 
 
950
 
951
  common_words_file = "data/NGSL_1.2_stats.csv.zip"
952
  if os.path.isfile(common_words_file):
@@ -958,88 +1342,82 @@ def _run_qa_db(query=None,
958
  num_common = len([x.lower() in set_common for x in reduced_query_words])
959
  frac_common = num_common / len(reduced_query) if reduced_query else 0
960
  # FIXME: report to user bad query that uses too many common words
961
- print("frac_common: %s" % frac_common, flush=True)
 
 
 
 
 
962
 
963
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
964
  chain_kwargs = dict(input_documents=[], question=query)
965
  else:
966
  chain_kwargs = dict(input_documents=docs, question=query)
967
 
968
- if stream_output:
969
- answer = None
970
- assert streamer is not None
971
- target = wrapped_partial(chain, chain_kwargs)
972
- import queue
973
- bucket = queue.Queue()
974
- thread = EThread(target=target, streamer=streamer, bucket=bucket)
975
- thread.start()
976
- outputs = ""
977
- prompt = None # FIXME
978
- try:
979
- for new_text in streamer:
980
- # print("new_text: %s" % new_text, flush=True)
981
- if bucket.qsize() > 0 or thread.exc:
982
- thread.join()
983
- outputs += new_text
984
- if prompter: # and False: # FIXME: pipeline can already use prompter
985
- output1 = prompter.get_response(outputs, prompt=prompt,
986
- sanitize_bot_response=sanitize_bot_response)
987
- yield output1
988
- else:
989
- yield outputs
990
- except BaseException:
991
- # if any exception, raise that exception if was from thread, first
992
- if thread.exc:
993
- raise thread.exc
994
- raise
995
- finally:
996
- # in case no exception and didn't join with thread yet, then join
997
- if not thread.exc:
998
- answer = thread.join()
999
- # in case raise StopIteration or broke queue loop in streamer, but still have exception
1000
- if thread.exc:
1001
- raise thread.exc
1002
- # FIXME: answer is not string outputs from streamer. How to get actual final output?
1003
- # answer = outputs
1004
- else:
1005
- answer = chain(chain_kwargs)
1006
 
1007
- if not use_context:
1008
- ret = answer['output_text']
1009
- yield ret
1010
- elif answer is not None:
1011
  print("query: %s" % query, flush=True)
1012
  print("answer: %s" % answer['output_text'], flush=True)
1013
- # link
1014
- answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in
1015
- zip(scores, answer['input_documents'])]
1016
- answer_sources_dict = defaultdict(list)
1017
- [answer_sources_dict[url].append(score) for score, url in answer_sources]
1018
- answers_dict = {}
1019
- for url, scores_url in answer_sources_dict.items():
1020
- answers_dict[url] = np.max(scores_url)
1021
- answer_sources = [(score, url) for url, score in answers_dict.items()]
1022
- answer_sources.sort(key=lambda x: x[0], reverse=True)
1023
- if show_rank:
1024
- # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
1025
- # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
1026
- answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
1027
- sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
1028
- else:
1029
- answer_sources = ['<li>%.2g | %s</li>' % (score, url) for score, url in answer_sources]
1030
- sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources)
1031
- sorted_sources_urls += f"</ul></p>{source_postfix}"
1032
 
1033
- if not answer['output_text'].endswith('\n'):
1034
- answer['output_text'] += '\n'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1035
 
1036
- if answer_with_sources:
1037
- ret = answer['output_text'] + '\n' + sorted_sources_urls
1038
- else:
1039
- ret = answer['output_text']
1040
 
1041
- yield ret
1042
- return
 
 
 
 
1043
 
1044
 
1045
  def chunk_sources(sources, chunk_size=1024):
 
3
  import os
4
  import pathlib
5
  import pickle
6
+ import queue
7
  import shutil
8
  import subprocess
9
  import sys
 
17
  from operator import concat
18
 
19
  from joblib import Parallel, delayed
20
+ from tqdm import tqdm
21
 
22
+ from prompter import non_hf_types
23
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
24
+ get_device, ProgressParallel, remove, hash_file
25
 
26
  import_matplotlib()
27
 
 
38
  EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
39
  UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
40
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
41
  from langchain.chains.question_answering import load_qa_chain
42
  from langchain.docstore.document import Document
43
  from langchain import PromptTemplate
 
45
 
46
 
47
  def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directory="db_dir", langchain_mode='notset',
48
+ collection_name=None,
49
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
50
  if not sources:
51
  return None
52
  # get embedding model
53
  embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
54
+ assert collection_name is not None or langchain_mode != 'notset'
55
+ if collection_name is None:
56
+ collection_name = langchain_mode.replace(' ', '_')
57
 
58
  # Create vector database
59
  if db_type == 'faiss':
60
+ from langchain.vectorstores import FAISS
61
  db = FAISS.from_documents(sources, embedding)
62
+
63
+ elif db_type == 'weaviate':
64
+ import weaviate
65
+ from weaviate.embedded import EmbeddedOptions
66
+ from langchain.vectorstores import Weaviate
67
+
68
+ # TODO: add support for connecting via docker compose
69
+ client = weaviate.Client(
70
+ embedded_options=EmbeddedOptions()
71
+ )
72
+ index_name = collection_name.capitalize()
73
+ db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
74
+ index_name=index_name)
75
+
76
  elif db_type == 'chroma':
77
+ assert persist_directory is not None
78
  os.makedirs(persist_directory, exist_ok=True)
79
  db = Chroma.from_documents(documents=sources,
80
  embedding=embedding,
 
82
  collection_name=collection_name,
83
  anonymized_telemetry=False)
84
  db.persist()
 
 
 
 
 
85
  else:
86
  raise RuntimeError("No such db_type=%s" % db_type)
87
 
88
  return db
89
 
90
 
91
+ def _get_unique_sources_in_weaviate(db):
92
+ batch_size = 100
93
+ id_source_list = []
94
+ result = db._client.data_object.get(class_name=db._index_name, limit=batch_size)
95
+
96
+ while result['objects']:
97
+ id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']]
98
+ last_id = id_source_list[-1][0]
99
+ result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id)
100
+
101
+ unique_sources = {source for _, source in id_source_list}
102
+ return unique_sources
103
+
104
+
105
+ def add_to_db(db, sources, db_type='faiss',
106
+ avoid_dup_by_file=False,
107
+ avoid_dup_by_content=True):
108
+ num_new_sources = len(sources)
109
  if not sources:
110
+ return db, num_new_sources, []
111
  if db_type == 'faiss':
112
  db.add_documents(sources)
113
+ elif db_type == 'weaviate':
114
+ # FIXME: only control by file name, not hash yet
115
+ if avoid_dup_by_file or avoid_dup_by_content:
116
+ unique_sources = _get_unique_sources_in_weaviate(db)
117
+ sources = [x for x in sources if x.metadata['source'] not in unique_sources]
118
+ num_new_sources = len(sources)
119
+ if num_new_sources == 0:
120
+ return db, num_new_sources, []
121
+ db.add_documents(documents=sources)
122
  elif db_type == 'chroma':
123
+ collection = db.get()
124
+ # files we already have:
125
+ metadata_files = set([x['source'] for x in collection['metadatas']])
126
+ if avoid_dup_by_file:
127
+ # Too weak in case file changed content, assume parent shouldn't pass true for this for now
128
+ raise RuntimeError("Not desired code path")
129
+ sources = [x for x in sources if x.metadata['source'] not in metadata_files]
130
+ if avoid_dup_by_content:
131
+ # look at hash, instead of page_content
132
+ # migration: If no hash previously, avoid updating,
133
+ # since don't know if need to update and may be expensive to redo all unhashed files
134
+ metadata_hash_ids = set(
135
+ [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
136
+ # avoid sources with same hash
137
+ sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
138
+ # get new file names that match existing file names. delete existing files we are overridding
139
+ dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
140
+ print("Removing %s duplicate files from db because ingesting those as new documents" % len(
141
+ dup_metadata_files), flush=True)
142
+ client_collection = db._client.get_collection(name=db._collection.name)
143
+ for dup_file in dup_metadata_files:
144
+ dup_file_meta = dict(source=dup_file)
145
+ try:
146
+ client_collection.delete(where=dup_file_meta)
147
+ except KeyError:
148
+ pass
149
+ num_new_sources = len(sources)
150
+ if num_new_sources == 0:
151
+ return db, num_new_sources, []
152
  db.add_documents(documents=sources)
153
  db.persist()
154
  else:
155
  raise RuntimeError("No such db_type=%s" % db_type)
156
 
157
+ new_sources_metadata = [x.metadata for x in sources]
158
+
159
+ return db, num_new_sources, new_sources_metadata
160
+
161
+
162
+ def create_or_update_db(db_type, persist_directory, collection_name,
163
+ sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model):
164
+ if db_type == 'weaviate':
165
+ import weaviate
166
+ from weaviate.embedded import EmbeddedOptions
167
+
168
+ # TODO: add support for connecting via docker compose
169
+ client = weaviate.Client(
170
+ embedded_options=EmbeddedOptions()
171
+ )
172
+ index_name = collection_name.replace(' ', '_').capitalize()
173
+ if client.schema.exists(index_name) and not add_if_exists:
174
+ client.schema.delete_class(index_name)
175
+ if verbose:
176
+ print("Removing %s" % index_name, flush=True)
177
+ elif db_type == 'chroma':
178
+ if not os.path.isdir(persist_directory) or not add_if_exists:
179
+ if os.path.isdir(persist_directory):
180
+ if verbose:
181
+ print("Removing %s" % persist_directory, flush=True)
182
+ remove(persist_directory)
183
+ if verbose:
184
+ print("Generating db", flush=True)
185
+
186
+ if not add_if_exists:
187
+ if verbose:
188
+ print("Generating db", flush=True)
189
+ else:
190
+ if verbose:
191
+ print("Loading and updating db", flush=True)
192
+
193
+ db = get_db(sources,
194
+ use_openai_embedding=use_openai_embedding,
195
+ db_type=db_type,
196
+ persist_directory=persist_directory,
197
+ langchain_mode=collection_name,
198
+ hf_embedding_model=hf_embedding_model)
199
+
200
  return db
201
 
202
 
 
234
  top_k=40,
235
  top_p=0.7,
236
  prompt_type=None,
237
+ prompter=None,
238
+ verbose=False,
239
  ):
240
  if use_openai_model:
241
  from langchain.llms import OpenAI
242
  llm = OpenAI(temperature=0)
243
  model_name = 'openai'
244
  streamer = None
245
+ prompt_type = 'plain'
246
+ elif model_name in non_hf_types:
247
  from gpt4all_llm import get_llm_gpt4all
248
  llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
249
  temperature=temperature,
250
  repetition_penalty=repetition_penalty,
251
  top_k=top_k,
252
  top_p=top_p,
253
+ verbose=verbose,
254
  )
255
  streamer = None
256
  prompt_type = 'plain'
 
261
  # only used if didn't pass model in
262
  assert model_name is None
263
  assert tokenizer is None
264
+ prompt_type = 'human_bot'
265
  model_name = 'h2oai/h2ogpt-oasst1-512-12b'
266
  # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
267
  # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
 
278
  torch_dtype=torch_dtype,
279
  load_in_8bit=load_8bit)
280
 
281
+ max_max_tokens = tokenizer.model_max_length
282
+ gen_kwargs = dict(max_new_tokens=max_new_tokens,
283
+ return_full_text=True,
284
+ early_stopping=False,
285
+ handle_long_generation='hole')
286
+
287
  if stream_output:
288
  skip_prompt = False
289
  from generate import H2OTextIteratorStreamer
 
293
  else:
294
  streamer = None
295
 
296
+ from h2oai_pipeline import H2OTextGenerationPipeline
297
+ pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
298
+ prompter=prompter,
299
+ prompt_type=prompt_type,
300
+ sanitize_bot_response=True,
301
+ chat=False, stream_output=stream_output,
302
+ tokenizer=tokenizer,
303
+ max_input_tokens=max_max_tokens - max_new_tokens,
304
+ **gen_kwargs)
305
+ # pipe.task = "text-generation"
306
+ # below makes it listen only to our prompt removal,
307
+ # not built in prompt removal that is less general and not specific for our model
308
+ pipe.task = "text2text-generation"
309
 
310
  from langchain.llms import HuggingFacePipeline
311
  llm = HuggingFacePipeline(pipeline=pipe)
 
461
  except (pkg_resources.DistributionNotFound, AssertionError):
462
  have_arxiv = False
463
 
464
+ try:
465
+ assert pkg_resources.get_distribution('pymupdf') is not None
466
+ have_pymupdf = True
467
+ except (pkg_resources.DistributionNotFound, AssertionError):
468
+ have_pymupdf = False
469
+
470
  image_types = ["png", "jpg", "jpeg"]
471
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
472
  "md", "html",
 
483
 
484
  def add_meta(docs1, file):
485
  file_extension = pathlib.Path(file).suffix
486
+ hashid = hash_file(file)
487
  if not isinstance(docs1, list):
488
  docs1 = [docs1]
489
+ [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
490
 
491
 
492
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, chunk=True, chunk_size=512,
 
536
  f.write(file)
537
  metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
538
  doc1 = Document(page_content=file, metadata=metadata)
539
+ elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
540
  docs1 = UnstructuredHTMLLoader(file_path=file).load()
541
  add_meta(docs1, file)
542
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
543
+ elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
544
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
545
  add_meta(docs1, file)
546
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
547
+ elif file.lower().endswith('.odt'):
548
  docs1 = UnstructuredODTLoader(file_path=file).load()
549
  add_meta(docs1, file)
550
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
551
+ elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
552
  docs1 = UnstructuredPowerPointLoader(file_path=file).load()
553
  add_meta(docs1, file)
554
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
555
+ elif file.lower().endswith('.txt'):
556
  # use UnstructuredFileLoader ?
557
+ docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
558
+ # makes just one, but big one
559
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
560
  add_meta(doc1, file)
561
+ elif file.lower().endswith('.rtf'):
562
  docs1 = UnstructuredRTFLoader(file).load()
563
  add_meta(docs1, file)
564
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
565
+ elif file.lower().endswith('.md'):
566
  docs1 = UnstructuredMarkdownLoader(file).load()
567
  add_meta(docs1, file)
568
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
569
+ elif file.lower().endswith('.enex'):
570
+ docs1 = EverNoteLoader(file).load()
571
  add_meta(doc1, file)
572
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
573
+ elif file.lower().endswith('.epub'):
574
  docs1 = UnstructuredEPubLoader(file).load()
575
  add_meta(docs1, file)
576
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
577
+ elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'):
578
  docs1 = []
579
  if have_tesseract and enable_ocr:
580
  # OCR, somewhat works, but not great
 
601
  docs1.extend(docs1c)
602
  for doci in docs1:
603
  doci.metadata['source'] = doci.metadata['image_path']
604
+ doci.metadata['hash'] = hash_file(doci.metadata['source'])
605
  if docs1:
606
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
607
+ elif file.lower().endswith('.msg'):
608
  raise RuntimeError("Not supported, GPL3 license")
609
  # docs1 = OutlookMessageLoader(file).load()
610
  # docs1[0].metadata['source'] = file
611
+ elif file.lower().endswith('.eml'):
612
  try:
613
  docs1 = UnstructuredEmailLoader(file).load()
614
  add_meta(docs1, file)
 
622
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
623
  else:
624
  raise
625
+ # elif file.lower().endswith('.gcsdir'):
626
  # doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
627
+ # elif file.lower().endswith('.gcsfile'):
628
  # doc1 = GCSFileLoader(project_name, bucket, blob).load()
629
+ elif file.lower().endswith('.rst'):
630
  with open(file, "r") as f:
631
  doc1 = Document(page_content=f.read(), metadata={"source": file})
632
  add_meta(doc1, file)
633
+ elif file.lower().endswith('.pdf'):
634
+ env_gpt4all_file = ".env_gpt4all"
635
+ from dotenv import dotenv_values
636
+ env_kwargs = dotenv_values(env_gpt4all_file)
637
+ pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
638
+ if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
639
+ # GPL, only use if installed
640
+ from langchain.document_loaders import PyMuPDFLoader
641
+ doc1 = PyMuPDFLoader(file).load_and_split()
642
+ else:
643
+ # open-source fallback
644
+ doc1 = PyPDFLoader(file).load_and_split()
645
  # Some PDFs return nothing or junk from PDFMinerLoader
 
 
646
  add_meta(doc1, file)
647
+ elif file.lower().endswith('.csv'):
648
  doc1 = CSVLoader(file).load()
649
  add_meta(doc1, file)
650
+ elif file.lower().endswith('.py'):
651
  doc1 = PythonLoader(file).load()
652
  add_meta(doc1, file)
653
+ elif file.lower().endswith('.toml'):
654
  doc1 = TomlLoader(file).load()
655
  add_meta(doc1, file)
656
+ elif file.lower().endswith('.urls'):
657
  with open(file, "r") as f:
658
  docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
659
  add_meta(docs1, file)
660
  doc1 = chunk_sources(docs1, chunk_size=chunk_size)
661
+ elif file.lower().endswith('.zip'):
662
  with zipfile.ZipFile(file, 'r') as zip_ref:
663
  # don't put into temporary path, since want to keep references to docs inside zip
664
  # so just extract in path where
 
669
  raise RuntimeError("No file handler for %s" % os.path.basename(file))
670
 
671
  # allow doc1 to be list or not. If not list, did not chunk yet, so chunk now
672
+ # if list of length one, don't trust and chunk it
673
  if not isinstance(doc1, list):
674
  if chunk:
675
  docs = chunk_sources([doc1], chunk_size=chunk_size)
676
  else:
677
  docs = [doc1]
678
+ elif isinstance(doc1, list) and len(doc1) == 1:
679
+ if chunk:
680
+ docs = chunk_sources(doc1, chunk_size=chunk_size)
681
+ else:
682
+ docs = doc1
683
  else:
684
  docs = doc1
685
 
 
736
  captions_model=None,
737
  caption_loader=None,
738
  enable_ocr=False,
739
+ existing_files=[],
740
+ existing_hash_ids={},
741
  ):
742
  globs_image_types = []
743
  globs_non_image_types = []
 
765
  # But instead, allow fail so can collect unsupported too
766
  set_globs_image_types = set(globs_image_types)
767
  globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types])
768
+
769
+ # filter out any files to skip (e.g. if already processed them)
770
+ # this is easy, but too aggressive in case a file changed, so parent probably passed existing_files=[]
771
+ assert not existing_files, "DEV: assume not using this approach"
772
+ if existing_files:
773
+ set_skip_files = set(existing_files)
774
+ globs_image_types = [x for x in globs_image_types if x not in set_skip_files]
775
+ globs_non_image_types = [x for x in globs_non_image_types if x not in set_skip_files]
776
+ if existing_hash_ids:
777
+ # assume consistent with add_meta() use of hash_file(file)
778
+ # also assume consistent with get_existing_hash_ids for dict creation
779
+ # assume hashable values
780
+ existing_hash_ids_set = set(existing_hash_ids.items())
781
+ hash_ids_all_image = set({x: hash_file(x) for x in globs_image_types}.items())
782
+ hash_ids_all_non_image = set({x: hash_file(x) for x in globs_non_image_types}.items())
783
+ # don't use symmetric diff. If file is gone, ignore and don't remove or something
784
+ # just consider existing files (key) having new hash or not (value)
785
+ new_files_image = set(dict(hash_ids_all_image - existing_hash_ids_set).keys())
786
+ new_files_non_image = set(dict(hash_ids_all_non_image - existing_hash_ids_set).keys())
787
+ globs_image_types = [x for x in globs_image_types if x in new_files_image]
788
+ globs_non_image_types = [x for x in globs_non_image_types if x in new_files_non_image]
789
+
790
  # could use generator, but messes up metadata handling in recursive case
791
  if caption_loader and not isinstance(caption_loader, (bool, str)) and \
792
  caption_loader.device != 'cpu' or \
 
813
  if n_jobs != 1 and len(globs_non_image_types) > 1:
814
  # avoid nesting, e.g. upload 1 zip and then inside many files
815
  # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
816
+ documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
817
  delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types
818
  )
819
  else:
820
+ documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_non_image_types)]
821
 
822
  # do images separately since can't fork after cuda in parent, so can't be parallel
823
  if n_jobs_image != 1 and len(globs_image_types) > 1:
824
  # avoid nesting, e.g. upload 1 zip and then inside many files
825
  # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
826
+ image_documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
827
  delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types
828
  )
829
  else:
830
+ image_documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_image_types)]
831
 
832
  # add image docs in
833
  documents += image_documents
 
846
  return documents
847
 
848
 
849
+ def prep_langchain(persist_directory,
850
+ load_db_if_exists,
851
+ db_type, use_openai_embedding, langchain_mode, user_path,
852
  hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
853
  """
854
  do prep first time, involving downloads
 
857
  """
858
  assert langchain_mode not in ['MyData'], "Should not prep scratch data"
859
 
860
+ db_dir_exists = os.path.isdir(persist_directory)
861
+
862
+ if db_dir_exists and user_path is None:
863
  print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
864
  db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
865
  hf_embedding_model)
866
  else:
867
+ if db_dir_exists and user_path is not None:
868
+ print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
869
+ persist_directory, user_path), flush=True)
870
+ elif not db_dir_exists:
871
+ print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
872
  db = None
873
  if langchain_mode in ['All', 'DriverlessAI docs']:
874
  # FIXME: Could also just use dai_docs.pickle directly and upload that
 
879
 
880
  langchain_kwargs = kwargs_make_db.copy()
881
  langchain_kwargs.update(locals())
882
+ db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs)
883
 
884
  return db
885
 
886
 
887
+ import posthog
888
+
889
+ posthog.disabled = True
890
+
891
+
892
+ class FakeConsumer(object):
893
+ def __init__(self, *args, **kwargs):
894
+ pass
895
+
896
+ def run(self):
897
+ pass
898
+
899
+ def pause(self):
900
+ pass
901
+
902
+ def upload(self):
903
+ pass
904
+
905
+ def next(self):
906
+ pass
907
+
908
+ def request(self, batch):
909
+ pass
910
+
911
+
912
+ posthog.Consumer = FakeConsumer
913
+
914
+
915
  def get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
916
  hf_embedding_model):
917
  if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
918
  os.path.join(persist_directory, 'index')):
919
  print("DO Loading db: %s" % langchain_mode, flush=True)
920
  embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
921
+ from chromadb.config import Settings
922
+ client_settings = Settings(anonymized_telemetry=False,
923
+ chroma_db_impl="duckdb+parquet",
924
+ persist_directory=persist_directory)
925
  db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
926
+ collection_name=langchain_mode.replace(' ', '_'),
927
+ client_settings=client_settings)
928
  print("DONE Loading db: %s" % langchain_mode, flush=True)
929
  return db
930
  return None
 
951
  langchain_mode=None,
952
  user_path=None,
953
  db_type='faiss',
954
+ load_db_if_exists=True,
955
  db=None,
956
+ n_jobs=-1,
957
+ verbose=False):
958
  persist_directory = 'db_dir_%s' % langchain_mode # single place, no special names for each case
959
  if not db and load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
960
  os.path.join(persist_directory, 'index')):
961
  assert langchain_mode not in ['MyData'], "Should not load MyData db this way"
962
+ print("Loading existing db", flush=True)
963
  embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
964
+ from chromadb.config import Settings
965
+ client_settings = Settings(anonymized_telemetry=False,
966
+ chroma_db_impl="duckdb+parquet",
967
+ persist_directory=persist_directory)
968
  db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
969
+ collection_name=langchain_mode.replace(' ', '_'),
970
+ client_settings=client_settings)
971
+ sources = []
972
+ if not db and langchain_mode not in ['MyData'] or \
973
+ user_path is not None and \
974
+ langchain_mode in ['UserData']:
975
+ # Should not make MyData db this way, why avoided, only upload from UI
976
  assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
977
+ if verbose:
978
+ if langchain_mode in ['UserData']:
979
+ if user_path is not None:
980
+ print("Checking if changed or new sources in %s, and generating sources them" % user_path,
981
+ flush=True)
982
+ elif db is None:
983
+ print("user_path not passed and no db, no sources", flush=True)
984
+ else:
985
+ print("user_path not passed, using only existing db, no new sources", flush=True)
986
+ else:
987
+ print("Generating %s sources" % langchain_mode, flush=True)
988
  if langchain_mode in ['wiki_full', 'All', "'All'"]:
989
  from read_wiki_full import get_all_documents
990
  small_test = None
 
1013
  sources.extend(sources1)
1014
  if langchain_mode in ['All', 'UserData']:
1015
  if user_path:
1016
+ if db is not None:
1017
+ # NOTE: Ignore file names for now, only go by hash ids
1018
+ # existing_files = get_existing_files(db)
1019
+ existing_files = []
1020
+ existing_hash_ids = get_existing_hash_ids(db)
1021
+ else:
1022
+ # pretend no existing files so won't filter
1023
+ existing_files = []
1024
+ existing_hash_ids = []
1025
  # chunk internally for speed over multiple docs
1026
+ sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
1027
+ existing_files=existing_files, existing_hash_ids=existing_hash_ids)
1028
+ new_metadata_sources = set([x.metadata['source'] for x in sources1])
1029
+ if new_metadata_sources:
1030
+ print("Loaded %s new files as sources to add to UserData" % len(new_metadata_sources), flush=True)
1031
+ if verbose:
1032
+ print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
1033
  sources.extend(sources1)
1034
+ print("Loaded %s sources for potentially adding to UserData" % len(sources), flush=True)
1035
  else:
1036
  print("Chose UserData but user_path is empty/None", flush=True)
1037
  if False and langchain_mode in ['urls', 'All', "'All'"]:
 
1043
  sources1 = loader.load()
1044
  sources.extend(sources1)
1045
  if not sources:
1046
+ if verbose:
1047
+ if db is not None:
1048
+ print("langchain_mode %s has no new sources, nothing to add to db" % langchain_mode, flush=True)
1049
+ else:
1050
+ print("langchain_mode %s has no sources, not making new db" % langchain_mode, flush=True)
1051
+ return db, 0, []
1052
+ if verbose:
1053
+ if db is not None:
1054
+ print("Generating db", flush=True)
1055
+ else:
1056
+ print("Adding to db", flush=True)
1057
+ if not db:
1058
+ if sources:
1059
+ db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
1060
+ persist_directory=persist_directory, langchain_mode=langchain_mode,
1061
+ hf_embedding_model=hf_embedding_model)
1062
+ if verbose:
1063
+ print("Generated db", flush=True)
1064
+ else:
1065
+ print("Did not generate db since no sources", flush=True)
1066
+ new_sources_metadata = [x.metadata for x in sources]
1067
+ elif user_path is not None and langchain_mode in ['UserData']:
1068
+ print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
1069
+ db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type)
1070
+ print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
1071
+ else:
1072
+ new_sources_metadata = [x.metadata for x in sources]
1073
+
1074
+ return db, len(new_sources_metadata), new_sources_metadata
1075
+
1076
+
1077
+ def get_existing_files(db):
1078
+ collection = db.get()
1079
+ metadata_sources = set([x['source'] for x in collection['metadatas']])
1080
+ return metadata_sources
1081
+
1082
+
1083
+ def get_existing_hash_ids(db):
1084
+ collection = db.get()
1085
+ # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
1086
+ metadata_hash_ids = {x['source']: x.get('hashid') for x in collection['metadatas']}
1087
+ return metadata_hash_ids
1088
 
1089
 
1090
  source_prefix = "Sources [Score | Link]:"
 
1108
  use_openai_model=False, use_openai_embedding=False,
1109
  first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
1110
  user_path=None,
1111
+ detect_user_path_changes_every_query=False,
1112
  db_type='faiss',
1113
  model_name=None, model=None, tokenizer=None,
1114
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
 
1128
  top_p=0.7,
1129
  langchain_mode=None,
1130
  document_choice=['All'],
1131
+ n_jobs=-1,
1132
+ verbose=False,
1133
+ cli=False):
1134
  """
1135
 
1136
  :param query:
 
1142
  :param chunk:
1143
  :param chunk_size:
1144
  :param user_path: user path to glob recursively from
1145
+ :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
1146
  :param model_name: model name, used to switch behaviors
1147
  :param model: pre-initialized model, else will make new one
1148
  :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
1149
  :param answer_with_sources
1150
  :return:
1151
  """
1152
+ assert query is not None
1153
+ assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
1154
+ if prompter is not None:
1155
+ prompt_type = prompter.prompt_type
1156
+ if model is not None:
1157
+ assert prompt_type is not None
1158
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1159
  model=model, tokenizer=tokenizer,
1160
  stream_output=stream_output,
 
1164
  top_k=top_k,
1165
  top_p=top_p,
1166
  prompt_type=prompt_type,
1167
+ prompter=prompter,
1168
+ verbose=verbose,
1169
  )
1170
 
1171
+ if model_name in non_hf_types:
1172
  # FIXME: for now, streams to stdout/stderr currently
1173
  stream_output = False
1174
 
1175
+ use_context = False
1176
+ scores = []
1177
+ chain = None
1178
+
1179
+ func_names = list(inspect.signature(get_similarity_chain).parameters)
1180
+ sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1181
+ missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1182
+ assert not missing_kwargs, "Missing: %s" % missing_kwargs
1183
+ docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
1184
+ if len(document_choice) > 0 and document_choice[0] == 'Only':
1185
+ formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1186
+ yield formatted_doc_chunks, ''
1187
+ return
1188
+ if chain is None and model_name not in non_hf_types:
1189
+ # can only return if HF type
1190
+ return
1191
+
1192
+ if stream_output:
1193
+ answer = None
1194
+ assert streamer is not None
1195
+ import queue
1196
+ bucket = queue.Queue()
1197
+ thread = EThread(target=chain, streamer=streamer, bucket=bucket)
1198
+ thread.start()
1199
+ outputs = ""
1200
+ prompt = None # FIXME
1201
+ try:
1202
+ for new_text in streamer:
1203
+ # print("new_text: %s" % new_text, flush=True)
1204
+ if bucket.qsize() > 0 or thread.exc:
1205
+ thread.join()
1206
+ outputs += new_text
1207
+ if prompter: # and False: # FIXME: pipeline can already use prompter
1208
+ output1 = prompter.get_response(outputs, prompt=prompt,
1209
+ sanitize_bot_response=sanitize_bot_response)
1210
+ yield output1, ''
1211
+ else:
1212
+ yield outputs, ''
1213
+ except BaseException:
1214
+ # if any exception, raise that exception if was from thread, first
1215
+ if thread.exc:
1216
+ raise thread.exc
1217
+ raise
1218
+ finally:
1219
+ # in case no exception and didn't join with thread yet, then join
1220
+ if not thread.exc:
1221
+ answer = thread.join()
1222
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
1223
+ if thread.exc:
1224
+ raise thread.exc
1225
+ # FIXME: answer is not string outputs from streamer. How to get actual final output?
1226
+ # answer = outputs
1227
+ else:
1228
+ answer = chain()
1229
+
1230
+ if not use_context:
1231
+ ret = answer['output_text']
1232
+ extra = ''
1233
+ yield ret, extra
1234
+ elif answer is not None:
1235
+ ret, extra = get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=verbose)
1236
+ yield ret, extra
1237
+ return
1238
+
1239
+
1240
+ def get_similarity_chain(query=None,
1241
+ use_openai_model=False, use_openai_embedding=False,
1242
+ first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
1243
+ user_path=None,
1244
+ detect_user_path_changes_every_query=False,
1245
+ db_type='faiss',
1246
+ model_name=None,
1247
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1248
+ prompt_type=None,
1249
+ cut_distanct=1.1,
1250
+ load_db_if_exists=False,
1251
+ db=None,
1252
+ langchain_mode=None,
1253
+ document_choice=['All'],
1254
+ n_jobs=-1,
1255
+ # beyond run_db_query:
1256
+ llm=None,
1257
+ verbose=False,
1258
+ ):
1259
+ # determine whether use of context out of docs is planned
1260
+ if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1261
  if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
1262
  use_context = False
 
1263
  else:
1264
  use_context = True
 
 
 
 
 
 
 
 
 
 
 
1265
  else:
 
1266
  use_context = True
1267
 
 
 
1268
  # https://github.com/hwchase17/langchain/issues/1946
1269
  # FIXME: Seems to way to get size of chroma db to limit k to avoid
1270
  # Chroma collection MyData contains fewer than 4 elements.
1271
  # type logger error
1272
  k_db = 1000 if db_type == 'chroma' else k # k=100 works ok too for
1273
 
1274
+ # FIXME: For All just go over all dbs instead of a separate db for All
1275
+ if not detect_user_path_changes_every_query and db is not None:
1276
+ # avoid looking at user_path during similarity search db handling,
1277
+ # if already have db and not updating from user_path every query
1278
+ # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
1279
+ user_path = None
1280
+ db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
1281
+ hf_embedding_model=hf_embedding_model,
1282
+ first_para=first_para, text_limit=text_limit, chunk=chunk,
1283
+ chunk_size=chunk_size,
1284
+ langchain_mode=langchain_mode,
1285
+ user_path=user_path,
1286
+ db_type=db_type,
1287
+ load_db_if_exists=load_db_if_exists,
1288
+ db=db,
1289
+ n_jobs=n_jobs,
1290
+ verbose=verbose)
1291
+
1292
  if db and use_context:
1293
  if isinstance(document_choice, str):
1294
  # support string as well
1295
  document_choice = [document_choice]
1296
+ if not isinstance(db, Chroma) or \
1297
+ len(document_choice) == 0 or \
1298
+ len(document_choice) <= 1 and document_choice[0] == 'All':
1299
  # treat empty list as All for now, not 'None'
1300
  filter_kwargs = {}
1301
+ elif len(document_choice) > 0 and document_choice[0] == 'Only':
1302
+ # Only means All docs, but only will return sources, not LLM response
1303
+ filter_kwargs = {}
1304
  else:
1305
  if len(document_choice) >= 2:
1306
  or_filter = [{"source": {"$eq": x}} for x in document_choice]
1307
  filter_kwargs = dict(filter={"$or": or_filter})
1308
+ elif len(document_choice) > 0:
1309
  one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
1310
  filter_kwargs = dict(filter=one_filter)
1311
+ else:
1312
+ filter_kwargs = {}
1313
+ if len(document_choice) == 1 and document_choice[0] == 'None':
1314
  k_db = 1
1315
  k = 0
1316
  docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:k]
1317
  # cut off so no high distance docs/sources considered
1318
  docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
1319
  scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
1320
+ if len(scores) > 0 and verbose:
1321
  print("Distance: min: %s max: %s mean: %s median: %s" %
1322
  (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
1323
  else:
1324
  docs = []
1325
  scores = []
1326
 
1327
+ if not docs and use_context and model_name not in non_hf_types:
1328
+ # if HF type and have no docs, can bail out
1329
+ return docs, None, [], False
1330
+
1331
+ if len(document_choice) > 0 and document_choice[0] == 'Only':
1332
+ # no LLM use
1333
+ return docs, None, [], False
1334
 
1335
  common_words_file = "data/NGSL_1.2_stats.csv.zip"
1336
  if os.path.isfile(common_words_file):
 
1342
  num_common = len([x.lower() in set_common for x in reduced_query_words])
1343
  frac_common = num_common / len(reduced_query) if reduced_query else 0
1344
  # FIXME: report to user bad query that uses too many common words
1345
+ if verbose:
1346
+ print("frac_common: %s" % frac_common, flush=True)
1347
+
1348
+ if len(docs) == 0:
1349
+ # avoid context == in prompt then
1350
+ use_context = False
1351
 
1352
+ if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1353
+ # instruct-like, rather than few-shot prompt_type='plain' as default
1354
+ # but then sources confuse the model with how inserted among rest of text, so avoid
1355
+ prefix = ""
1356
+ if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
1357
+ template = """%s{context}{question}""" % prefix
1358
+ else:
1359
+ template = """%s
1360
+ ==
1361
+ {context}
1362
+ ==
1363
+ {question}""" % prefix
1364
+ prompt = PromptTemplate(
1365
+ # input_variables=["summaries", "question"],
1366
+ input_variables=["context", "question"],
1367
+ template=template,
1368
+ )
1369
+ chain = load_qa_chain(llm, prompt=prompt)
1370
+ else:
1371
+ chain = load_qa_with_sources_chain(llm)
1372
+
1373
+ if not use_context:
1374
  chain_kwargs = dict(input_documents=[], question=query)
1375
  else:
1376
  chain_kwargs = dict(input_documents=docs, question=query)
1377
 
1378
+ target = wrapped_partial(chain, chain_kwargs)
1379
+ return docs, target, scores, use_context
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1380
 
1381
+
1382
+ def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
1383
+ if verbose:
 
1384
  print("query: %s" % query, flush=True)
1385
  print("answer: %s" % answer['output_text'], flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1386
 
1387
+ if len(answer['input_documents']) == 0:
1388
+ extra = ''
1389
+ ret = answer['output_text'] + extra
1390
+ return ret, extra
1391
+
1392
+ # link
1393
+ answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in
1394
+ zip(scores, answer['input_documents'])]
1395
+ answer_sources_dict = defaultdict(list)
1396
+ [answer_sources_dict[url].append(score) for score, url in answer_sources]
1397
+ answers_dict = {}
1398
+ for url, scores_url in answer_sources_dict.items():
1399
+ answers_dict[url] = np.max(scores_url)
1400
+ answer_sources = [(score, url) for url, score in answers_dict.items()]
1401
+ answer_sources.sort(key=lambda x: x[0], reverse=True)
1402
+ if show_rank:
1403
+ # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
1404
+ # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
1405
+ answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
1406
+ sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
1407
+ else:
1408
+ answer_sources = ['<li>%.2g | %s</li>' % (score, url) for score, url in answer_sources]
1409
+ sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources)
1410
+ sorted_sources_urls += f"</ul></p>{source_postfix}"
1411
 
1412
+ if not answer['output_text'].endswith('\n'):
1413
+ answer['output_text'] += '\n'
 
 
1414
 
1415
+ if answer_with_sources:
1416
+ extra = '\n' + sorted_sources_urls
1417
+ else:
1418
+ extra = ''
1419
+ ret = answer['output_text'] + extra
1420
+ return ret, extra
1421
 
1422
 
1423
  def chunk_sources(sources, chunk_size=1024):
gradio_runner.py CHANGED
@@ -9,17 +9,33 @@ import traceback
9
  import uuid
10
  import filelock
11
  import pandas as pd
 
12
  import tabulate
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
15
  from prompter import Prompter, \
16
- prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, generate_prompt
17
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
18
- ping, get_short_name, get_url, makedirs
19
  from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
20
  inputs_kwargs_list, get_cutoffs, scratch_base_dir
21
 
22
- import gradio as gr
23
  from apscheduler.schedulers.background import BackgroundScheduler
24
 
25
 
@@ -27,12 +43,11 @@ def go_gradio(**kwargs):
27
  allow_api = kwargs['allow_api']
28
  is_public = kwargs['is_public']
29
  is_hf = kwargs['is_hf']
30
- is_low_mem = kwargs['is_low_mem']
31
  n_gpus = kwargs['n_gpus']
32
  admin_pass = kwargs['admin_pass']
33
  model_state0 = kwargs['model_state0']
34
  score_model_state0 = kwargs['score_model_state0']
35
- queue = True
36
  dbs = kwargs['dbs']
37
  db_type = kwargs['db_type']
38
  visible_langchain_modes = kwargs['visible_langchain_modes']
@@ -41,7 +56,6 @@ def go_gradio(**kwargs):
41
  enable_sources_list = kwargs['enable_sources_list']
42
  enable_url_upload = kwargs['enable_url_upload']
43
  enable_text_upload = kwargs['enable_text_upload']
44
- allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
45
  use_openai_embedding = kwargs['use_openai_embedding']
46
  hf_embedding_model = kwargs['hf_embedding_model']
47
  enable_captions = kwargs['enable_captions']
@@ -50,6 +64,8 @@ def go_gradio(**kwargs):
50
  caption_loader = kwargs['caption_loader']
51
 
52
  # easy update of kwargs needed for evaluate() etc.
 
 
53
  kwargs.update(locals())
54
 
55
  if 'mbart-' in kwargs['model_lower']:
@@ -76,8 +92,8 @@ def go_gradio(**kwargs):
76
  """
77
  else:
78
  description = more_info
79
- description += "If this host is busy, try [12B](https://gpt.h2o.ai), [30B](http://gpt2.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
80
- description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
81
  if is_hf:
82
  description += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
83
 
@@ -95,6 +111,7 @@ def go_gradio(**kwargs):
95
  else:
96
  css_code = """footer {visibility: hidden}"""
97
  css_code += """
 
98
  body.dark{#warning {background-color: #555555};}
99
  #small_btn {
100
  margin: 0.6em 0em 0.55em 0;
@@ -131,7 +148,19 @@ body.dark{#warning {background-color: #555555};}
131
 
132
  Chatbot._postprocess_chat_messages = _postprocess_chat_messages
133
 
134
- theme = H2oTheme() if kwargs['h2ocolors'] else SoftTheme()
 
 
 
 
 
 
 
 
 
 
 
 
135
  demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
136
  callback = gr.CSVLogger()
137
 
@@ -173,7 +202,11 @@ body.dark{#warning {background-color: #555555};}
173
  lora_options_state = gr.State([lora_options])
174
  my_db_state = gr.State([None, None])
175
  chat_state = gr.State({})
176
- docs_state = gr.State(['All'])
 
 
 
 
177
  gr.Markdown(f"""
178
  {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
179
 
@@ -258,10 +291,10 @@ body.dark{#warning {background-color: #555555};}
258
  radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
259
  type='value')
260
  with gr.Row():
261
- clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
262
- export_chats_btn = gr.Button(value="Export Chats to Download")
263
- remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
264
- add_to_chats_btn = gr.Button("Import Chats from Upload")
265
  with gr.Row():
266
  chats_file = gr.File(interactive=False, label="Download Exported Chats")
267
  chatsup_output = gr.File(label="Upload Chat File(s)",
@@ -269,7 +302,7 @@ body.dark{#warning {background-color: #555555};}
269
  file_count='multiple',
270
  elem_id="warning", elem_classes="feedback")
271
  with gr.TabItem("Data Source"):
272
- langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
273
  from_str=True)
274
  gr.HTML(value=f"""LangChain Support Disabled<p>
275
  Run:<p>
@@ -302,7 +335,7 @@ body.dark{#warning {background-color: #555555};}
302
  with data_row2:
303
  with gr.Column(scale=50):
304
  document_choice = gr.Dropdown(docs_state.value,
305
- label="Choose Subset of Doc(s) in Collection [click get to update]",
306
  value=docs_state.value[0],
307
  interactive=True,
308
  multiselect=True,
@@ -312,6 +345,8 @@ body.dark{#warning {background-color: #555555};}
312
  ).style(full_width=False, size='sm')
313
  show_sources_btn = gr.Button(value="Show Sources",
314
  ).style(full_width=False, size='sm')
 
 
315
 
316
  # import control
317
  if kwargs['langchain_mode'] != 'Disabled':
@@ -375,7 +410,7 @@ body.dark{#warning {background-color: #555555};}
375
  with sources_row3:
376
  with gr.Column(scale=1):
377
  file_source = gr.File(interactive=False,
378
- label="Download File with Sources [click get to make file]")
379
  with gr.Column(scale=2):
380
  pass
381
  sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
@@ -411,14 +446,24 @@ body.dark{#warning {background-color: #555555};}
411
  )
412
  # FIXME: https://github.com/h2oai/h2ogpt/issues/106
413
  if os.getenv('TESTINGFAIL'):
414
- max_beams = 8 if not (is_low_mem or is_public) else 1
415
  else:
416
  max_beams = 1
417
  num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
418
  value=min(max_beams, kwargs['num_beams']), label="Beams",
419
  info="Number of searches for optimal overall probability. "
420
  "Uses more GPU memory/compute")
421
- max_max_new_tokens = 2048 if not is_low_mem else kwargs['max_new_tokens']
 
 
 
 
 
 
 
 
 
 
422
  max_new_tokens = gr.Slider(
423
  minimum=1, maximum=max_max_new_tokens, step=1,
424
  value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
@@ -450,11 +495,19 @@ body.dark{#warning {background-color: #555555};}
450
  visible=not is_public)
451
  chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
452
  visible=not is_public)
 
 
 
 
 
 
 
 
453
 
454
  with gr.TabItem("Models"):
455
- load_msg = "Load-Unload Model/LORA" if not is_public \
456
  else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
457
- load_msg2 = "Load-Unload Model/LORA 2" if not is_public \
458
  else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
459
  compare_checkbox = gr.components.Checkbox(label="Compare Mode",
460
  value=False, visible=not is_public)
@@ -468,7 +521,7 @@ body.dark{#warning {background-color: #555555};}
468
  lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
469
  value=kwargs['lora_weights'], visible=kwargs['show_lora'])
470
  with gr.Column(scale=1):
471
- load_model_button = gr.Button(load_msg)
472
  model_load8bit_checkbox = gr.components.Checkbox(
473
  label="Load 8-bit [requires support]",
474
  value=kwargs['load_8bit'])
@@ -476,19 +529,12 @@ body.dark{#warning {background-color: #555555};}
476
  label="Choose Devices [If not Checked, use all GPUs]",
477
  value=kwargs['infer_devices'])
478
  model_gpu = gr.Dropdown(n_gpus_list,
479
- label="GPU ID 2 [-1 = all GPUs, if Choose is enabled]",
480
  value=kwargs['gpu_id'])
481
  model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'],
482
  interactive=False)
483
  lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
484
  visible=kwargs['show_lora'], interactive=False)
485
- with gr.Row():
486
- with gr.Column(scale=50):
487
- new_model = gr.Textbox(label="New Model HF name/path")
488
- new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
489
- with gr.Column(scale=1):
490
- add_model_button = gr.Button("Add new model name")
491
- add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
492
  col_model2 = gr.Column(visible=False)
493
  with col_model2:
494
  with gr.Row():
@@ -499,7 +545,7 @@ body.dark{#warning {background-color: #555555};}
499
  value=no_lora_str,
500
  visible=kwargs['show_lora'])
501
  with gr.Column(scale=1):
502
- load_model_button2 = gr.Button(load_msg2)
503
  model_load8bit_checkbox2 = gr.components.Checkbox(
504
  label="Load 8-bit 2 [requires support]",
505
  value=kwargs['load_8bit'])
@@ -508,12 +554,22 @@ body.dark{#warning {background-color: #555555};}
508
  value=kwargs[
509
  'infer_devices'])
510
  model_gpu2 = gr.Dropdown(n_gpus_list,
511
- label="GPU ID [-1 = all GPUs, if choose is enabled]",
512
  value=kwargs['gpu_id'])
513
  # no model/lora loaded ever in model2 by default
514
  model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
515
  lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
516
  visible=kwargs['show_lora'])
 
 
 
 
 
 
 
 
 
 
517
  with gr.TabItem("System"):
518
  admin_row = gr.Row()
519
  with admin_row:
@@ -530,7 +586,7 @@ body.dark{#warning {background-color: #555555};}
530
  with gr.Row():
531
  zip_btn = gr.Button("Zip")
532
  zip_text = gr.Textbox(label="Zip file name", interactive=False)
533
- file_output = gr.File(interactive=False)
534
  with gr.Row():
535
  s3up_btn = gr.Button("S3UP")
536
  s3up_text = gr.Textbox(label='S3UP result', interactive=False)
@@ -542,7 +598,7 @@ body.dark{#warning {background-color: #555555};}
542
  description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
543
  if 'h2ogpt-research' in kwargs['base_model']:
544
  description += """<i><li>Research demonstration only, not used for commercial purposes.</i></li>"""
545
- description += """<i><li>By using h2oGPT, you accept our <a href="https://github.com/h2oai/h2ogpt/blob/main/tos.md">Terms of Service</a></i></li></ul></p>"""
546
  gr.Markdown(value=description, show_label=False, interactive=False)
547
 
548
  # Get flagged data
@@ -633,24 +689,37 @@ body.dark{#warning {background-color: #555555};}
633
  api_name='add_txt_to_my' if allow_api else None) \
634
  .then(clear_textbox, outputs=user_text_text, queue=queue)
635
 
636
- get_sources1 = functools.partial(get_sources, dbs=dbs)
637
 
638
  # if change collection source, must clear doc selections from it to avoid inconsistency
639
  def clear_doc_choice():
640
- return gr.Dropdown.update(choices=['All'], value=['All'])
641
 
642
  langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice)
643
 
644
  def update_dropdown(x):
645
- return gr.Dropdown.update(choices=x, value='All')
646
 
647
- show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
648
  get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=[file_source, docs_state],
649
  queue=queue,
650
  api_name='get_sources' if allow_api else None) \
651
  .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
652
  # show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe
653
- show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
 
655
  def check_admin_pass(x):
656
  return gr.update(visible=x == admin_pass)
@@ -661,10 +730,6 @@ body.dark{#warning {background-color: #555555};}
661
  admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \
662
  .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
663
 
664
- # Get inputs to evaluate()
665
- # don't deepcopy, can contain model itself
666
- all_kwargs = kwargs.copy()
667
- all_kwargs.update(locals())
668
  inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
669
  from functools import partial
670
  kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
@@ -714,7 +779,10 @@ body.dark{#warning {background-color: #555555};}
714
  """ Similar to user() """
715
  args_list = list(args)
716
 
717
- max_length_tokenize = 512 if is_low_mem else 2048
 
 
 
718
  cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM
719
  smodel = score_model_state0[0]
720
  stokenizer = score_model_state0[1]
@@ -811,6 +879,8 @@ body.dark{#warning {background-color: #555555};}
811
  # e.g. when user just hits enter in textbox,
812
  # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual
813
  user_message1 = '\n'
 
 
814
 
815
  history = args_list[-1]
816
  if undo and history:
@@ -830,6 +900,43 @@ body.dark{#warning {background-color: #555555};}
830
  # FIXME: compare, same history for now
831
  return history + [[user_message1, None]]
832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
833
  def bot(*args, retry=False):
834
  """
835
  bot that consumes history for user input
@@ -861,47 +968,15 @@ body.dark{#warning {background-color: #555555};}
861
  history = []
862
  yield history, ''
863
  return
864
- # ensure output will be unique to models
865
- _, _, _, max_prompt_length = get_cutoffs(is_low_mem, for_context=True)
866
- history = copy.deepcopy(history)
867
  instruction1 = history[-1][0]
868
  if not instruction1:
869
  # reject empty query, can sometimes go nuts
870
  history = []
871
  yield history, ''
872
  return
873
-
874
- context1 = ''
875
- if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
876
- prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
877
- chat1 = args_list[eval_func_param_names.index('chat')]
878
- context1 = ''
879
- # - 1 below because current instruction already in history from user()
880
- for histi in range(0, len(history) - 1):
881
- data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
882
- prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
883
- chat1, reduced=True)
884
- # md -> back to text, maybe not super important if model trained enough
885
- if not kwargs['keep_sources_in_context']:
886
- from gpt_langchain import source_prefix, source_postfix
887
- import re
888
- prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt,
889
- flags=re.DOTALL)
890
- if prompt.endswith('\n<p>'):
891
- prompt = prompt[:-4]
892
- prompt = prompt.replace('<br>', chat_sep)
893
- if not prompt.endswith(chat_sep):
894
- prompt += chat_sep
895
- # most recent first, add older if can
896
- # only include desired chat history
897
- if len(prompt + context1) > max_prompt_length:
898
- break
899
- context1 = prompt + context1
900
-
901
- _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
902
- reduced=True)
903
- if context1 and not context1.endswith(chat_sep):
904
- context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
905
  args_list[0] = instruction1 # override original instruction with history from user
906
  args_list[2] = context1
907
  fun1 = partial(evaluate,
@@ -909,8 +984,11 @@ body.dark{#warning {background-color: #555555};}
909
  my_db_state1,
910
  **kwargs_evaluate)
911
  try:
912
- for output in fun1(*tuple(args_list)):
913
- bot_message = output
 
 
 
914
  history[-1][1] = bot_message
915
  yield history, ''
916
  except StopIteration:
@@ -1067,11 +1145,11 @@ body.dark{#warning {background-color: #555555};}
1067
  if len(stepy) != 2:
1068
  # something off
1069
  return False
1070
- questionx = stepx[0].replace('<p>', '').replace('</p>', '')
1071
- answerx = stepx[1].replace('<p>', '').replace('</p>', '')
1072
 
1073
- questiony = stepy[0].replace('<p>', '').replace('</p>', '')
1074
- answery = stepy[1].replace('<p>', '').replace('</p>', '')
1075
 
1076
  if questionx != questiony or answerx != answery:
1077
  return False
@@ -1221,7 +1299,9 @@ body.dark{#warning {background-color: #555555};}
1221
  lora_weights = ''
1222
 
1223
  all_kwargs1['lora_weights'] = lora_weights.strip()
1224
- model1, tokenizer1, device1 = get_model(**all_kwargs1)
 
 
1225
  clear_torch_cache()
1226
 
1227
  if kwargs['debug']:
@@ -1242,7 +1322,7 @@ body.dark{#warning {background-color: #555555};}
1242
  chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
1243
  nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
1244
  if not is_public:
1245
- load_model_event = load_model_button.click(**load_model_args) \
1246
  .then(**prompt_update_args) \
1247
  .then(**chatbot_update_args) \
1248
  .then(**nochat_update_args) \
@@ -1255,7 +1335,8 @@ body.dark{#warning {background-color: #555555};}
1255
  prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
1256
  chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
1257
  if not is_public:
1258
- load_model_event2 = load_model_button2.click(**load_model_args2) \
 
1259
  .then(**prompt_update_args2) \
1260
  .then(**chatbot_update_args2) \
1261
  .then(clear_torch_cache)
@@ -1331,6 +1412,27 @@ body.dark{#warning {background-color: #555555};}
1331
  submit_event3d, submit_event3f,
1332
  submit_event_nochat],
1333
  queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1334
  demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)
1335
 
1336
  demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
@@ -1339,7 +1441,7 @@ body.dark{#warning {background-color: #555555};}
1339
  scheduler = BackgroundScheduler()
1340
  scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
1341
  if is_public and \
1342
- kwargs['base_model'] not in ['gptj', 'llama']:
1343
  # FIXME: disable for gptj, langchain or gpt4all modify print itself
1344
  # FIXME: and any multi-threaded/async print will enter model output!
1345
  scheduler.add_job(func=ping, trigger="interval", seconds=60)
@@ -1348,14 +1450,15 @@ body.dark{#warning {background-color: #555555};}
1348
  # import control
1349
  if kwargs['langchain_mode'] == 'Disabled' and \
1350
  os.environ.get("TEST_LANGCHAIN_IMPORT") and \
1351
- kwargs['base_model'] not in ['gptj', 'llama']:
1352
  assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
1353
  assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
1354
 
1355
  demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
1356
  favicon_path=favicon_path, prevent_thread_lock=True,
1357
  auth=kwargs['auth'])
1358
- print("Started GUI", flush=True)
 
1359
  if kwargs['block_gradio_exit']:
1360
  demo.block_thread()
1361
 
@@ -1384,7 +1487,7 @@ def get_inputs_list(inputs_dict, model_lower):
1384
  return inputs_list
1385
 
1386
 
1387
- def get_sources(db1, langchain_mode, dbs=None):
1388
  if langchain_mode in ['ChatLLM', 'LLM']:
1389
  source_files_added = "NA"
1390
  source_list = []
@@ -1407,7 +1510,7 @@ def get_sources(db1, langchain_mode, dbs=None):
1407
  sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
1408
  with open(sources_file, "wt") as f:
1409
  f.write(source_files_added)
1410
- source_list = ['All'] + source_list
1411
  return sources_file, source_list
1412
 
1413
 
@@ -1471,7 +1574,7 @@ def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='Use
1471
  if langchain_mode == 'MyData':
1472
  if db1[0] is not None:
1473
  # then add
1474
- add_to_db(db1[0], sources, db_type=db_type)
1475
  else:
1476
  assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
1477
  # then create
@@ -1486,13 +1589,13 @@ def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='Use
1486
  hf_embedding_model=hf_embedding_model)
1487
  if db1[0] is None:
1488
  db1[1] = None
1489
- source_files_added = get_source_files(db1[0], exceptions=exceptions)
1490
  return db1, x, y, source_files_added
1491
  else:
1492
  persist_directory = 'db_dir_%s' % langchain_mode
1493
  if langchain_mode in dbs and dbs[langchain_mode] is not None:
1494
  # then add
1495
- add_to_db(dbs[langchain_mode], sources, db_type=db_type)
1496
  else:
1497
  # then create
1498
  db = get_db(sources, use_openai_embedding=use_openai_embedding,
@@ -1504,11 +1607,11 @@ def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='Use
1504
  # NOTE we do not return db, because function call always same code path
1505
  # return dbs[langchain_mode], x, y
1506
  # db in this code path is updated in place
1507
- source_files_added = get_source_files(dbs[langchain_mode], exceptions=exceptions)
1508
  return x, y, source_files_added
1509
 
1510
 
1511
- def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
1512
  with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
1513
  if langchain_mode in ['wiki_full']:
1514
  # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
@@ -1519,17 +1622,31 @@ def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=No
1519
  db = dbs[langchain_mode]
1520
  else:
1521
  db = None
1522
- return get_source_files(db, exceptions=None)
 
 
 
 
 
1523
 
1524
 
1525
- def get_source_files(db, exceptions=None):
1526
  if exceptions is None:
1527
  exceptions = []
1528
 
1529
- if db is not None:
1530
- metadatas = db.get()['metadatas']
 
 
 
 
 
 
 
 
1531
  else:
1532
- metadatas = []
 
1533
 
1534
  # below automatically de-dups
1535
  from gpt_langchain import get_url
@@ -1558,28 +1675,28 @@ def get_source_files(db, exceptions=None):
1558
  <html>
1559
  <body>
1560
  <p>
1561
- Sources: <br>
1562
  </p>
1563
  <div style="overflow-y: auto;height:400px">
1564
- {0}
1565
  {1}
 
1566
  </div>
1567
  </body>
1568
  </html>
1569
- """.format(source_files_added, exceptions_html)
1570
  elif metadatas:
1571
  source_files_added = """\
1572
  <html>
1573
  <body>
1574
  <p>
1575
- Sources: <br>
1576
  </p>
1577
  <div style="overflow-y: auto;height:400px">
1578
- {0}
1579
  </div>
1580
  </body>
1581
  </html>
1582
- """.format(source_files_added)
1583
  elif exceptions_html:
1584
  source_files_added = """\
1585
  <html>
@@ -1594,6 +1711,31 @@ def get_source_files(db, exceptions=None):
1594
  </html>
1595
  """.format(exceptions_html)
1596
  else:
1597
- source_files_added = ""
 
 
 
1598
 
1599
  return source_files_added
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import uuid
10
  import filelock
11
  import pandas as pd
12
+ import requests
13
  import tabulate
14
 
15
+ # This is a hack to prevent Gradio from phoning home when it gets imported
16
+ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
17
+
18
+
19
+ def my_get(url, **kwargs):
20
+ print('Gradio HTTP request redirected to localhost :)', flush=True)
21
+ kwargs.setdefault('allow_redirects', True)
22
+ return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
23
+
24
+
25
+ original_get = requests.get
26
+ requests.get = my_get
27
+ import gradio as gr
28
+
29
+ requests.get = original_get
30
+
31
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
32
  from prompter import Prompter, \
33
+ prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, generate_prompt, non_hf_types
34
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
35
+ ping, get_short_name, get_url, makedirs, get_kwargs
36
  from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
37
  inputs_kwargs_list, get_cutoffs, scratch_base_dir
38
 
 
39
  from apscheduler.schedulers.background import BackgroundScheduler
40
 
41
 
 
43
  allow_api = kwargs['allow_api']
44
  is_public = kwargs['is_public']
45
  is_hf = kwargs['is_hf']
46
+ memory_restriction_level = kwargs['memory_restriction_level']
47
  n_gpus = kwargs['n_gpus']
48
  admin_pass = kwargs['admin_pass']
49
  model_state0 = kwargs['model_state0']
50
  score_model_state0 = kwargs['score_model_state0']
 
51
  dbs = kwargs['dbs']
52
  db_type = kwargs['db_type']
53
  visible_langchain_modes = kwargs['visible_langchain_modes']
 
56
  enable_sources_list = kwargs['enable_sources_list']
57
  enable_url_upload = kwargs['enable_url_upload']
58
  enable_text_upload = kwargs['enable_text_upload']
 
59
  use_openai_embedding = kwargs['use_openai_embedding']
60
  hf_embedding_model = kwargs['hf_embedding_model']
61
  enable_captions = kwargs['enable_captions']
 
64
  caption_loader = kwargs['caption_loader']
65
 
66
  # easy update of kwargs needed for evaluate() etc.
67
+ queue = True
68
+ allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
69
  kwargs.update(locals())
70
 
71
  if 'mbart-' in kwargs['model_lower']:
 
92
  """
93
  else:
94
  description = more_info
95
+ description += "If this host is busy, try [12B](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
96
+ description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md)</p>"""
97
  if is_hf:
98
  description += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
99
 
 
111
  else:
112
  css_code = """footer {visibility: hidden}"""
113
  css_code += """
114
+ @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
115
  body.dark{#warning {background-color: #555555};}
116
  #small_btn {
117
  margin: 0.6em 0em 0.55em 0;
 
148
 
149
  Chatbot._postprocess_chat_messages = _postprocess_chat_messages
150
 
151
+ if kwargs['gradio_offline_level'] >= 0:
152
+ # avoid GoogleFont that pulls from internet
153
+ if kwargs['gradio_offline_level'] == 1:
154
+ # front end would still have to download fonts or have cached it at some point
155
+ base_font = 'Source Sans Pro'
156
+ else:
157
+ base_font = 'Helvetica'
158
+ theme_kwargs = dict(font=(base_font, 'ui-sans-serif', 'system-ui', 'sans-serif'),
159
+ font_mono=('IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'))
160
+ else:
161
+ theme_kwargs = dict()
162
+
163
+ theme = H2oTheme(**theme_kwargs) if kwargs['h2ocolors'] else SoftTheme(**theme_kwargs)
164
  demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
165
  callback = gr.CSVLogger()
166
 
 
202
  lora_options_state = gr.State([lora_options])
203
  my_db_state = gr.State([None, None])
204
  chat_state = gr.State({})
205
+ # make user default first and default choice, dedup
206
+ docs_state00 = kwargs['document_choice'] + ['All', 'Only', 'None']
207
+ docs_state0 = []
208
+ [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
209
+ docs_state = gr.State(docs_state0) # first is chosen as default
210
  gr.Markdown(f"""
211
  {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
212
 
 
291
  radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
292
  type='value')
293
  with gr.Row():
294
+ clear_chat_btn = gr.Button(value="Clear Chat", visible=True).style(size='sm')
295
+ export_chats_btn = gr.Button(value="Export Chats to Download").style(size='sm')
296
+ remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True).style(size='sm')
297
+ add_to_chats_btn = gr.Button("Import Chats from Upload").style(size='sm')
298
  with gr.Row():
299
  chats_file = gr.File(interactive=False, label="Download Exported Chats")
300
  chatsup_output = gr.File(label="Upload Chat File(s)",
 
302
  file_count='multiple',
303
  elem_id="warning", elem_classes="feedback")
304
  with gr.TabItem("Data Source"):
305
+ langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/docs/README_LangChain.md',
306
  from_str=True)
307
  gr.HTML(value=f"""LangChain Support Disabled<p>
308
  Run:<p>
 
335
  with data_row2:
336
  with gr.Column(scale=50):
337
  document_choice = gr.Dropdown(docs_state.value,
338
+ label="Choose Subset of Doc(s) in Collection [click get sources to update]",
339
  value=docs_state.value[0],
340
  interactive=True,
341
  multiselect=True,
 
345
  ).style(full_width=False, size='sm')
346
  show_sources_btn = gr.Button(value="Show Sources",
347
  ).style(full_width=False, size='sm')
348
+ refresh_sources_btn = gr.Button(value="Refresh Sources",
349
+ ).style(full_width=False, size='sm')
350
 
351
  # import control
352
  if kwargs['langchain_mode'] != 'Disabled':
 
410
  with sources_row3:
411
  with gr.Column(scale=1):
412
  file_source = gr.File(interactive=False,
413
+ label="Download File w/Sources [click get sources to make file]")
414
  with gr.Column(scale=2):
415
  pass
416
  sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
 
446
  )
447
  # FIXME: https://github.com/h2oai/h2ogpt/issues/106
448
  if os.getenv('TESTINGFAIL'):
449
+ max_beams = 8 if not (memory_restriction_level or is_public) else 1
450
  else:
451
  max_beams = 1
452
  num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
453
  value=min(max_beams, kwargs['num_beams']), label="Beams",
454
  info="Number of searches for optimal overall probability. "
455
  "Uses more GPU memory/compute")
456
+ # FIXME: 2048 should be tokenizer.model_max_length, but may not even have model yet
457
+ if kwargs['max_new_tokens']:
458
+ max_max_new_tokens = kwargs['max_new_tokens']
459
+ elif memory_restriction_level == 1:
460
+ max_max_new_tokens = 768
461
+ elif memory_restriction_level == 2:
462
+ max_max_new_tokens = 512
463
+ elif memory_restriction_level >= 3:
464
+ max_max_new_tokens = 256
465
+ else:
466
+ max_max_new_tokens = 2048
467
  max_new_tokens = gr.Slider(
468
  minimum=1, maximum=max_max_new_tokens, step=1,
469
  value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
 
495
  visible=not is_public)
496
  chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
497
  visible=not is_public)
498
+ count_chat_tokens_btn = gr.Button(value="Count Chat Tokens", visible=not is_public)
499
+ chat_token_count = gr.Textbox(label="Chat Token Count", value=None,
500
+ visible=not is_public, interactive=False)
501
+ top_k_docs = gr.Slider(minimum=0, maximum=20, step=1,
502
+ value=kwargs['top_k_docs'],
503
+ label="Number of document chunks",
504
+ info="For LangChain",
505
+ visible=not is_public)
506
 
507
  with gr.TabItem("Models"):
508
+ load_msg = "Load-Unload Model/LORA [unload works if did not use --base_model]" if not is_public \
509
  else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
510
+ load_msg2 = "Load-Unload Model/LORA 2 [unload works if did not use --base_model]" if not is_public \
511
  else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
512
  compare_checkbox = gr.components.Checkbox(label="Compare Mode",
513
  value=False, visible=not is_public)
 
521
  lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
522
  value=kwargs['lora_weights'], visible=kwargs['show_lora'])
523
  with gr.Column(scale=1):
524
+ load_model_button = gr.Button(load_msg).style(full_width=False, size='sm')
525
  model_load8bit_checkbox = gr.components.Checkbox(
526
  label="Load 8-bit [requires support]",
527
  value=kwargs['load_8bit'])
 
529
  label="Choose Devices [If not Checked, use all GPUs]",
530
  value=kwargs['infer_devices'])
531
  model_gpu = gr.Dropdown(n_gpus_list,
532
+ label="GPU ID [-1 = all GPUs, if Choose is enabled]",
533
  value=kwargs['gpu_id'])
534
  model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'],
535
  interactive=False)
536
  lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
537
  visible=kwargs['show_lora'], interactive=False)
 
 
 
 
 
 
 
538
  col_model2 = gr.Column(visible=False)
539
  with col_model2:
540
  with gr.Row():
 
545
  value=no_lora_str,
546
  visible=kwargs['show_lora'])
547
  with gr.Column(scale=1):
548
+ load_model_button2 = gr.Button(load_msg2).style(full_width=False, size='sm')
549
  model_load8bit_checkbox2 = gr.components.Checkbox(
550
  label="Load 8-bit 2 [requires support]",
551
  value=kwargs['load_8bit'])
 
554
  value=kwargs[
555
  'infer_devices'])
556
  model_gpu2 = gr.Dropdown(n_gpus_list,
557
+ label="GPU ID 2 [-1 = all GPUs, if choose is enabled]",
558
  value=kwargs['gpu_id'])
559
  # no model/lora loaded ever in model2 by default
560
  model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
561
  lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
562
  visible=kwargs['show_lora'])
563
+ with gr.Row():
564
+ with gr.Column(scale=50):
565
+ new_model = gr.Textbox(label="New Model HF name/path")
566
+ with gr.Row():
567
+ add_model_button = gr.Button("Add new model name").style(full_width=False, size='sm')
568
+ with gr.Column(scale=50):
569
+ new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
570
+ with gr.Row():
571
+ add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora']).style(
572
+ full_width=False, size='sm')
573
  with gr.TabItem("System"):
574
  admin_row = gr.Row()
575
  with admin_row:
 
586
  with gr.Row():
587
  zip_btn = gr.Button("Zip")
588
  zip_text = gr.Textbox(label="Zip file name", interactive=False)
589
+ file_output = gr.File(interactive=False, label="Zip file to Download")
590
  with gr.Row():
591
  s3up_btn = gr.Button("S3UP")
592
  s3up_text = gr.Textbox(label='S3UP result', interactive=False)
 
598
  description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
599
  if 'h2ogpt-research' in kwargs['base_model']:
600
  description += """<i><li>Research demonstration only, not used for commercial purposes.</i></li>"""
601
+ description += """<i><li>By using h2oGPT, you accept our <a href="https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md">Terms of Service</a></i></li></ul></p>"""
602
  gr.Markdown(value=description, show_label=False, interactive=False)
603
 
604
  # Get flagged data
 
689
  api_name='add_txt_to_my' if allow_api else None) \
690
  .then(clear_textbox, outputs=user_text_text, queue=queue)
691
 
692
+ get_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=docs_state0)
693
 
694
  # if change collection source, must clear doc selections from it to avoid inconsistency
695
  def clear_doc_choice():
696
+ return gr.Dropdown.update(choices=docs_state0, value=[docs_state0[0]])
697
 
698
  langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice)
699
 
700
  def update_dropdown(x):
701
+ return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
702
 
 
703
  get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=[file_source, docs_state],
704
  queue=queue,
705
  api_name='get_sources' if allow_api else None) \
706
  .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
707
  # show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe
708
+ show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
709
+ show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text,
710
+ api_name='show_sources' if allow_api else None)
711
+
712
+ # Get inputs to evaluate() and make_db()
713
+ # don't deepcopy, can contain model itself
714
+ all_kwargs = kwargs.copy()
715
+ all_kwargs.update(locals())
716
+
717
+ refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
718
+ **get_kwargs(update_and_get_source_files_given_langchain_mode,
719
+ exclude_names=['db1', 'langchain_mode'],
720
+ **all_kwargs))
721
+ refresh_sources_btn.click(fn=refresh_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text,
722
+ api_name='refresh_sources' if allow_api else None)
723
 
724
  def check_admin_pass(x):
725
  return gr.update(visible=x == admin_pass)
 
730
  admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \
731
  .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
732
 
 
 
 
 
733
  inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
734
  from functools import partial
735
  kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
 
779
  """ Similar to user() """
780
  args_list = list(args)
781
 
782
+ if memory_restriction_level > 0:
783
+ max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
784
+ else:
785
+ max_length_tokenize = 2048 - 256
786
  cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM
787
  smodel = score_model_state0[0]
788
  stokenizer = score_model_state0[1]
 
879
  # e.g. when user just hits enter in textbox,
880
  # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual
881
  user_message1 = '\n'
882
+ # ensure good visually, else markdown ignores multiple \n
883
+ user_message1 = user_message1.replace('\n', '<br>')
884
 
885
  history = args_list[-1]
886
  if undo and history:
 
900
  # FIXME: compare, same history for now
901
  return history + [[user_message1, None]]
902
 
903
+ def history_to_context(history, langchain_mode1, prompt_type1, chat1):
904
+ # ensure output will be unique to models
905
+ # FIXME: hard-coded 2048 implicitly passed:
906
+ _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level, for_context=True)
907
+ history = copy.deepcopy(history)
908
+
909
+ context1 = ''
910
+ if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
911
+ context1 = ''
912
+ # - 1 below because current instruction already in history from user()
913
+ for histi in range(0, len(history) - 1):
914
+ data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
915
+ prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
916
+ chat1, reduced=True)
917
+ # md -> back to text, maybe not super important if model trained enough
918
+ if not kwargs['keep_sources_in_context']:
919
+ from gpt_langchain import source_prefix, source_postfix
920
+ import re
921
+ prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt,
922
+ flags=re.DOTALL)
923
+ if prompt.endswith('\n<p>'):
924
+ prompt = prompt[:-4]
925
+ prompt = prompt.replace('<br>', chat_sep)
926
+ if not prompt.endswith(chat_sep):
927
+ prompt += chat_sep
928
+ # most recent first, add older if can
929
+ # only include desired chat history
930
+ if len(prompt + context1) > max_prompt_length:
931
+ break
932
+ context1 = prompt + context1
933
+
934
+ _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
935
+ reduced=True)
936
+ if context1 and not context1.endswith(chat_sep):
937
+ context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
938
+ return context1
939
+
940
  def bot(*args, retry=False):
941
  """
942
  bot that consumes history for user input
 
968
  history = []
969
  yield history, ''
970
  return
 
 
 
971
  instruction1 = history[-1][0]
972
  if not instruction1:
973
  # reject empty query, can sometimes go nuts
974
  history = []
975
  yield history, ''
976
  return
977
+ prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
978
+ chat1 = args_list[eval_func_param_names.index('chat')]
979
+ context1 = history_to_context(history, langchain_mode1, prompt_type1, chat1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
980
  args_list[0] = instruction1 # override original instruction with history from user
981
  args_list[2] = context1
982
  fun1 = partial(evaluate,
 
984
  my_db_state1,
985
  **kwargs_evaluate)
986
  try:
987
+ for output_fun in fun1(*tuple(args_list)):
988
+ output = output_fun['response']
989
+ extra = output_fun['sources'] # FIXME: can show sources in separate text box etc.
990
+ # ensure good visually, else markdown ignores multiple \n
991
+ bot_message = output.replace('\n', '<br>')
992
  history[-1][1] = bot_message
993
  yield history, ''
994
  except StopIteration:
 
1145
  if len(stepy) != 2:
1146
  # something off
1147
  return False
1148
+ questionx = stepx[0].replace('<p>', '').replace('</p>', '') if stepx[0] is not None else None
1149
+ answerx = stepx[1].replace('<p>', '').replace('</p>', '') if stepx[1] is not None else None
1150
 
1151
+ questiony = stepy[0].replace('<p>', '').replace('</p>', '') if stepy[0] is not None else None
1152
+ answery = stepy[1].replace('<p>', '').replace('</p>', '') if stepy[1] is not None else None
1153
 
1154
  if questionx != questiony or answerx != answery:
1155
  return False
 
1299
  lora_weights = ''
1300
 
1301
  all_kwargs1['lora_weights'] = lora_weights.strip()
1302
+ model1, tokenizer1, device1 = get_model(reward_type=False,
1303
+ **get_kwargs(get_model, exclude_names=['reward_type'],
1304
+ **all_kwargs1))
1305
  clear_torch_cache()
1306
 
1307
  if kwargs['debug']:
 
1322
  chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
1323
  nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
1324
  if not is_public:
1325
+ load_model_event = load_model_button.click(**load_model_args, api_name='load_model' if allow_api else None) \
1326
  .then(**prompt_update_args) \
1327
  .then(**chatbot_update_args) \
1328
  .then(**nochat_update_args) \
 
1335
  prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
1336
  chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
1337
  if not is_public:
1338
+ load_model_event2 = load_model_button2.click(**load_model_args2,
1339
+ api_name='load_model2' if allow_api else None) \
1340
  .then(**prompt_update_args2) \
1341
  .then(**chatbot_update_args2) \
1342
  .then(clear_torch_cache)
 
1412
  submit_event3d, submit_event3f,
1413
  submit_event_nochat],
1414
  queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
1415
+
1416
+ def count_chat_tokens(model_state1, chat1, prompt_type1):
1417
+ if model_state1 and not isinstance(model_state1[1], str):
1418
+ tokenizer = model_state1[1]
1419
+ elif model_state0 and not isinstance(model_state0[1], str):
1420
+ tokenizer = model_state0[1]
1421
+ else:
1422
+ tokenizer = None
1423
+ if tokenizer is not None:
1424
+ langchain_mode1 = 'ChatLLM'
1425
+ # fake user message to mimic bot()
1426
+ chat1 = copy.deepcopy(chat1)
1427
+ chat1 = chat1 + [['user_message1', None]]
1428
+ context1 = history_to_context(chat1, langchain_mode1, prompt_type1, chat1)
1429
+ return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])
1430
+ else:
1431
+ return "N/A"
1432
+
1433
+ count_chat_tokens_btn.click(fn=count_chat_tokens, inputs=[model_state, text_output, prompt_type],
1434
+ outputs=chat_token_count, api_name='count_tokens' if allow_api else None)
1435
+
1436
  demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)
1437
 
1438
  demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
 
1441
  scheduler = BackgroundScheduler()
1442
  scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
1443
  if is_public and \
1444
+ kwargs['base_model'] not in non_hf_types:
1445
  # FIXME: disable for gptj, langchain or gpt4all modify print itself
1446
  # FIXME: and any multi-threaded/async print will enter model output!
1447
  scheduler.add_job(func=ping, trigger="interval", seconds=60)
 
1450
  # import control
1451
  if kwargs['langchain_mode'] == 'Disabled' and \
1452
  os.environ.get("TEST_LANGCHAIN_IMPORT") and \
1453
+ kwargs['base_model'] not in non_hf_types:
1454
  assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
1455
  assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
1456
 
1457
  demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
1458
  favicon_path=favicon_path, prevent_thread_lock=True,
1459
  auth=kwargs['auth'])
1460
+ if kwargs['verbose']:
1461
+ print("Started GUI", flush=True)
1462
  if kwargs['block_gradio_exit']:
1463
  demo.block_thread()
1464
 
 
1487
  return inputs_list
1488
 
1489
 
1490
+ def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
1491
  if langchain_mode in ['ChatLLM', 'LLM']:
1492
  source_files_added = "NA"
1493
  source_list = []
 
1510
  sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
1511
  with open(sources_file, "wt") as f:
1512
  f.write(source_files_added)
1513
+ source_list = docs_state0 + source_list
1514
  return sources_file, source_list
1515
 
1516
 
 
1574
  if langchain_mode == 'MyData':
1575
  if db1[0] is not None:
1576
  # then add
1577
+ db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type)
1578
  else:
1579
  assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
1580
  # then create
 
1589
  hf_embedding_model=hf_embedding_model)
1590
  if db1[0] is None:
1591
  db1[1] = None
1592
+ source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
1593
  return db1, x, y, source_files_added
1594
  else:
1595
  persist_directory = 'db_dir_%s' % langchain_mode
1596
  if langchain_mode in dbs and dbs[langchain_mode] is not None:
1597
  # then add
1598
+ db, num_new_sources, new_sources_metadata = add_to_db(dbs[langchain_mode], sources, db_type=db_type)
1599
  else:
1600
  # then create
1601
  db = get_db(sources, use_openai_embedding=use_openai_embedding,
 
1607
  # NOTE we do not return db, because function call always same code path
1608
  # return dbs[langchain_mode], x, y
1609
  # db in this code path is updated in place
1610
+ source_files_added = get_source_files(db=dbs[langchain_mode], exceptions=exceptions)
1611
  return x, y, source_files_added
1612
 
1613
 
1614
+ def get_db(db1, langchain_mode, dbs=None):
1615
  with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
1616
  if langchain_mode in ['wiki_full']:
1617
  # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
 
1622
  db = dbs[langchain_mode]
1623
  else:
1624
  db = None
1625
+ return db
1626
+
1627
+
1628
+ def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
1629
+ db = get_db(db1, langchain_mode, dbs=dbs)
1630
+ return get_source_files(db=db, exceptions=None)
1631
 
1632
 
1633
+ def get_source_files(db=None, exceptions=None, metadatas=None):
1634
  if exceptions is None:
1635
  exceptions = []
1636
 
1637
+ # only should be one source, not confused
1638
+ assert db is not None or metadatas is not None
1639
+
1640
+ if metadatas is None:
1641
+ source_label = "Sources:"
1642
+ if db is not None:
1643
+ metadatas = db.get()['metadatas']
1644
+ else:
1645
+ metadatas = []
1646
+ adding_new = False
1647
  else:
1648
+ source_label = "New Sources:"
1649
+ adding_new = True
1650
 
1651
  # below automatically de-dups
1652
  from gpt_langchain import get_url
 
1675
  <html>
1676
  <body>
1677
  <p>
1678
+ {0} <br>
1679
  </p>
1680
  <div style="overflow-y: auto;height:400px">
 
1681
  {1}
1682
+ {2}
1683
  </div>
1684
  </body>
1685
  </html>
1686
+ """.format(source_label, source_files_added, exceptions_html)
1687
  elif metadatas:
1688
  source_files_added = """\
1689
  <html>
1690
  <body>
1691
  <p>
1692
+ {0} <br>
1693
  </p>
1694
  <div style="overflow-y: auto;height:400px">
1695
+ {1}
1696
  </div>
1697
  </body>
1698
  </html>
1699
+ """.format(source_label, source_files_added)
1700
  elif exceptions_html:
1701
  source_files_added = """\
1702
  <html>
 
1711
  </html>
1712
  """.format(exceptions_html)
1713
  else:
1714
+ if adding_new:
1715
+ source_files_added = "No New Sources"
1716
+ else:
1717
+ source_files_added = "No Sources"
1718
 
1719
  return source_files_added
1720
+
1721
+
1722
+ def update_and_get_source_files_given_langchain_mode(db1, langchain_mode, dbs=None, first_para=None,
1723
+ text_limit=None, chunk=None, chunk_size=None,
1724
+ user_path=None, db_type=None, load_db_if_exists=None,
1725
+ n_jobs=None, verbose=None):
1726
+ db = get_db(db1, langchain_mode, dbs=dbs)
1727
+
1728
+ from gpt_langchain import make_db
1729
+ db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
1730
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1731
+ first_para=first_para, text_limit=text_limit, chunk=chunk,
1732
+ chunk_size=chunk_size,
1733
+ langchain_mode=langchain_mode,
1734
+ user_path=user_path,
1735
+ db_type=db_type,
1736
+ load_db_if_exists=load_db_if_exists,
1737
+ db=db,
1738
+ n_jobs=n_jobs,
1739
+ verbose=verbose)
1740
+ # return only new sources with text saying such
1741
+ return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
gradio_themes.py CHANGED
@@ -1,7 +1,10 @@
1
  from __future__ import annotations
 
 
 
2
  from gradio.themes.soft import Soft
3
  from gradio.themes import Color
4
- from gradio.themes.utils import colors, sizes
5
 
6
  h2o_yellow = Color(
7
  name="yellow",
@@ -43,6 +46,22 @@ class H2oTheme(Soft):
43
  spacing_size: sizes.Size | str = sizes.spacing_md,
44
  radius_size: sizes.Size | str = sizes.radius_md,
45
  text_size: sizes.Size | str = sizes.text_lg,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  ):
47
  super().__init__(
48
  primary_hue=primary_hue,
@@ -51,6 +70,8 @@ class H2oTheme(Soft):
51
  spacing_size=spacing_size,
52
  radius_size=radius_size,
53
  text_size=text_size,
 
 
54
  )
55
  super().set(
56
  link_text_color="#3344DD",
@@ -89,6 +110,22 @@ class SoftTheme(Soft):
89
  spacing_size: sizes.Size | str = sizes.spacing_md,
90
  radius_size: sizes.Size | str = sizes.radius_md,
91
  text_size: sizes.Size | str = sizes.text_md,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  ):
93
  super().__init__(
94
  primary_hue=primary_hue,
@@ -97,6 +134,8 @@ class SoftTheme(Soft):
97
  spacing_size=spacing_size,
98
  radius_size=radius_size,
99
  text_size=text_size,
 
 
100
  )
101
 
102
 
@@ -125,7 +164,7 @@ def get_h2o_title(title):
125
  <h1 style="line-height:60px">{title}</h1>
126
  </div>
127
  <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
128
- <img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/h2o-qr.png></img>
129
  </div>
130
  """
131
 
 
1
  from __future__ import annotations
2
+
3
+ from typing import Iterable
4
+
5
  from gradio.themes.soft import Soft
6
  from gradio.themes import Color
7
+ from gradio.themes.utils import colors, sizes, fonts
8
 
9
  h2o_yellow = Color(
10
  name="yellow",
 
46
  spacing_size: sizes.Size | str = sizes.spacing_md,
47
  radius_size: sizes.Size | str = sizes.radius_md,
48
  text_size: sizes.Size | str = sizes.text_lg,
49
+ font: fonts.Font
50
+ | str
51
+ | Iterable[fonts.Font | str] = (
52
+ fonts.GoogleFont("Montserrat"),
53
+ "ui-sans-serif",
54
+ "system-ui",
55
+ "sans-serif",
56
+ ),
57
+ font_mono: fonts.Font
58
+ | str
59
+ | Iterable[fonts.Font | str] = (
60
+ fonts.GoogleFont("IBM Plex Mono"),
61
+ "ui-monospace",
62
+ "Consolas",
63
+ "monospace",
64
+ ),
65
  ):
66
  super().__init__(
67
  primary_hue=primary_hue,
 
70
  spacing_size=spacing_size,
71
  radius_size=radius_size,
72
  text_size=text_size,
73
+ font=font,
74
+ font_mono=font_mono,
75
  )
76
  super().set(
77
  link_text_color="#3344DD",
 
110
  spacing_size: sizes.Size | str = sizes.spacing_md,
111
  radius_size: sizes.Size | str = sizes.radius_md,
112
  text_size: sizes.Size | str = sizes.text_md,
113
+ font: fonts.Font
114
+ | str
115
+ | Iterable[fonts.Font | str] = (
116
+ fonts.GoogleFont("Montserrat"),
117
+ "ui-sans-serif",
118
+ "system-ui",
119
+ "sans-serif",
120
+ ),
121
+ font_mono: fonts.Font
122
+ | str
123
+ | Iterable[fonts.Font | str] = (
124
+ fonts.GoogleFont("IBM Plex Mono"),
125
+ "ui-monospace",
126
+ "Consolas",
127
+ "monospace",
128
+ ),
129
  ):
130
  super().__init__(
131
  primary_hue=primary_hue,
 
134
  spacing_size=spacing_size,
135
  radius_size=radius_size,
136
  text_size=text_size,
137
+ font=font,
138
+ font_mono=font_mono,
139
  )
140
 
141
 
 
164
  <h1 style="line-height:60px">{title}</h1>
165
  </div>
166
  <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
167
+ <img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png></img>
168
  </div>
169
  """
170
 
h2oai_pipeline.py CHANGED
@@ -2,36 +2,57 @@ from transformers import TextGenerationPipeline
2
  from transformers.pipelines.text_generation import ReturnType
3
 
4
  from stopping import get_stopping
5
-
6
- prompt_type = "human_bot"
7
- human = "<human>:"
8
- bot = "<bot>:"
9
-
10
- # human-bot interaction like OIG dataset
11
- prompt = """{human} {instruction}
12
- {bot}""".format(
13
- human=human,
14
- instruction="{instruction}",
15
- bot=bot,
16
- )
17
 
18
 
19
  class H2OTextGenerationPipeline(TextGenerationPipeline):
20
- def __init__(self, *args, use_prompter=False, debug=False, chat=False, stream_output=False,
21
- sanitize_bot_response=True, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  super().__init__(*args, **kwargs)
23
- self.use_prompter = use_prompter
24
  self.prompt_text = None
 
 
 
25
  if self.use_prompter:
26
- from prompter import Prompter
27
- self.prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
 
 
 
 
 
28
  else:
29
  self.prompter = None
 
 
 
30
  self.sanitize_bot_response = sanitize_bot_response
 
31
 
32
  def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
33
- prompt_text = prompt.format(instruction=prompt_text)
 
 
34
  self.prompt_text = prompt_text
 
 
 
35
  return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
36
  **generate_kwargs)
37
 
@@ -43,12 +64,65 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
43
  outputs = rec['generated_text']
44
  outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
45
  sanitize_bot_response=self.sanitize_bot_response)
 
 
46
  else:
47
- outputs = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
48
  rec['generated_text'] = outputs
49
  return records
50
 
51
  def _forward(self, model_inputs, **generate_kwargs):
52
- stopping_criteria = get_stopping(prompt_type, self.tokenizer, self.device, human=human, bot=bot)
53
- generate_kwargs['stopping_criteria'] = stopping_criteria
54
- return super()._forward(model_inputs, **generate_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from transformers.pipelines.text_generation import ReturnType
3
 
4
  from stopping import get_stopping
5
+ from prompter import Prompter
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  class H2OTextGenerationPipeline(TextGenerationPipeline):
9
+ def __init__(self, *args, debug=False, chat=False, stream_output=False,
10
+ sanitize_bot_response=True,
11
+ use_prompter=True, prompter=None, prompt_type=None,
12
+ max_input_tokens=2048 - 256, **kwargs):
13
+ """
14
+ HF-like pipeline, but handle instruction prompting and stopping (for some models)
15
+ :param args:
16
+ :param debug:
17
+ :param chat:
18
+ :param stream_output:
19
+ :param sanitize_bot_response:
20
+ :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
21
+ :param prompter: prompter, can pass if have already
22
+ :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
23
+ If use_prompter, then will make prompter and use it.
24
+ :param max_input_tokens:
25
+ :param kwargs:
26
+ """
27
  super().__init__(*args, **kwargs)
 
28
  self.prompt_text = None
29
+ self.use_prompter = use_prompter
30
+ self.prompt_type = prompt_type
31
+ self.prompter = prompter
32
  if self.use_prompter:
33
+ if self.prompter is not None:
34
+ assert self.prompter.prompt_type is not None
35
+ else:
36
+ self.prompter = Prompter(self.prompt_type, debug=debug, chat=chat, stream_output=stream_output)
37
+ self.human = self.prompter.humanstr
38
+ self.bot = self.prompter.botstr
39
+ self.can_stop = True
40
  else:
41
  self.prompter = None
42
+ self.human = None
43
+ self.bot = None
44
+ self.can_stop = False
45
  self.sanitize_bot_response = sanitize_bot_response
46
+ self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
47
 
48
  def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
49
+ data_point = dict(context='', instruction=prompt_text, input='')
50
+ if self.prompter is not None:
51
+ prompt_text = self.prompter.generate_prompt(data_point)
52
  self.prompt_text = prompt_text
53
+ if handle_long_generation is None:
54
+ # forces truncation of inputs to avoid critical failure
55
+ handle_long_generation = 'hole'
56
  return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
57
  **generate_kwargs)
58
 
 
64
  outputs = rec['generated_text']
65
  outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
66
  sanitize_bot_response=self.sanitize_bot_response)
67
+ elif self.bot and self.human:
68
+ outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
69
  else:
70
+ outputs = rec['generated_text']
71
  rec['generated_text'] = outputs
72
  return records
73
 
74
  def _forward(self, model_inputs, **generate_kwargs):
75
+ if self.can_stop:
76
+ stopping_criteria = get_stopping(self.prompt_type, self.tokenizer, self.device, human=self.human,
77
+ bot=self.bot)
78
+ generate_kwargs['stopping_criteria'] = stopping_criteria
79
+ # return super()._forward(model_inputs, **generate_kwargs)
80
+ return self.__forward(model_inputs, **generate_kwargs)
81
+
82
+ # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
83
+ # FIXME: https://github.com/h2oai/h2ogpt/issues/172
84
+ def __forward(self, model_inputs, **generate_kwargs):
85
+ input_ids = model_inputs["input_ids"]
86
+ attention_mask = model_inputs.get("attention_mask", None)
87
+ # Allow empty prompts
88
+ if input_ids.shape[1] == 0:
89
+ input_ids = None
90
+ attention_mask = None
91
+ in_b = 1
92
+ else:
93
+ in_b = input_ids.shape[0]
94
+ prompt_text = model_inputs.pop("prompt_text")
95
+
96
+ ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
97
+ ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
98
+ # generate_kwargs = copy.deepcopy(generate_kwargs)
99
+ prefix_length = generate_kwargs.pop("prefix_length", 0)
100
+ if prefix_length > 0:
101
+ has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
102
+ "generation_config" in generate_kwargs
103
+ and generate_kwargs["generation_config"].max_new_tokens is not None
104
+ )
105
+ if not has_max_new_tokens:
106
+ generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
107
+ generate_kwargs["max_length"] += prefix_length
108
+ has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
109
+ "generation_config" in generate_kwargs
110
+ and generate_kwargs["generation_config"].min_new_tokens is not None
111
+ )
112
+ if not has_min_new_tokens and "min_length" in generate_kwargs:
113
+ generate_kwargs["min_length"] += prefix_length
114
+
115
+ # BS x SL
116
+ generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
117
+ out_b = generated_sequence.shape[0]
118
+ if self.framework == "pt":
119
+ generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
120
+ elif self.framework == "tf":
121
+ from transformers import is_tf_available
122
+ if is_tf_available():
123
+ import tensorflow as tf
124
+ generated_sequence = tf.reshape(generated_sequence,
125
+ (in_b, out_b // in_b, *generated_sequence.shape[1:]))
126
+ else:
127
+ raise ValueError("TF not avaialble.")
128
+ return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
prompter.py CHANGED
@@ -1,6 +1,8 @@
1
  import time
2
  from enum import Enum
3
 
 
 
4
 
5
  class PromptType(Enum):
6
  plain = 0
@@ -17,6 +19,10 @@ class PromptType(Enum):
17
  open_assistant = 11
18
  wizard_lm = 12
19
  wizard_mega = 13
 
 
 
 
20
 
21
 
22
  prompt_type_to_model_name = {
@@ -26,6 +32,7 @@ prompt_type_to_model_name = {
26
  'EleutherAI/pythia-12b',
27
  'EleutherAI/pythia-12b-deduped',
28
  'EleutherAI/gpt-neox-20b',
 
29
  'decapoda-research/llama-7b-hf',
30
  'decapoda-research/llama-13b-hf',
31
  'decapoda-research/llama-30b-hf',
@@ -39,7 +46,8 @@ prompt_type_to_model_name = {
39
  'mosaicml/mpt-7b-instruct', # internal code handles instruct
40
  'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
41
  'gptj', # internally handles prompting
42
- 'llama', # internally handles prompting
 
43
  ],
44
  'prompt_answer': [
45
  'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
@@ -47,6 +55,7 @@ prompt_type_to_model_name = {
47
  'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
48
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
49
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
 
50
  ],
51
  'instruct': [],
52
  'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -58,7 +67,9 @@ prompt_type_to_model_name = {
58
  'h2oai/h2ogpt-oig-oasst1-512-6_9b',
59
  'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
60
  'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
61
- 'h2oai/h2ogpt-research-oasst1-512-30b', # private
 
 
62
  ],
63
  'dai_faq': [],
64
  'summarize': [],
@@ -83,7 +94,8 @@ for p in PromptType:
83
 
84
 
85
  def get_prompt(prompt_type, chat, context, reduced):
86
- if prompt_type in [-1, "-1", "plain"]:
 
87
  promptA = promptB = PreInstruct = PreInput = PreResponse = ''
88
  terminate_response = []
89
  chat_sep = ''
@@ -95,11 +107,14 @@ def get_prompt(prompt_type, chat, context, reduced):
95
  chat_sep = '\n'
96
  humanstr = ''
97
  botstr = ''
98
- elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
 
 
 
99
  promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
100
- chat and reduced) else ''
101
  promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
102
- chat and reduced) else ''
103
 
104
  PreInstruct = """
105
  ### Instruction:
@@ -112,18 +127,20 @@ def get_prompt(prompt_type, chat, context, reduced):
112
  PreResponse = """
113
  ### Response:
114
  """
115
- if prompt_type in [7, "7", "instruct_with_end"]:
 
116
  terminate_response = ['### End']
117
  else:
118
  terminate_response = None
119
  chat_sep = '\n'
120
  humanstr = PreInstruct
121
  botstr = PreResponse
122
- elif prompt_type in [1, "1", "quality"]:
 
123
  promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
124
- chat and reduced) else ''
125
  promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
126
- chat and reduced) else ''
127
 
128
  PreInstruct = """
129
  ### Instruction:
@@ -140,10 +157,14 @@ def get_prompt(prompt_type, chat, context, reduced):
140
  chat_sep = '\n'
141
  humanstr = PreInstruct # first thing human says
142
  botstr = PreResponse # first thing bot says
143
- elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
 
 
 
144
  human = '<human>:'
145
  bot = "<bot>:"
146
- if reduced or context or prompt_type in [2, "2", "human_bot"]:
 
147
  preprompt = ''
148
  else:
149
  cur_date = time.strftime('%Y-%m-%d')
@@ -174,7 +195,8 @@ Current Time: {}
174
  chat_sep = '\n'
175
  humanstr = human # tag before human talks
176
  botstr = bot # tag before bot talks
177
- elif prompt_type in [3, "3", "dai_faq"]:
 
178
  promptA = ''
179
  promptB = 'Answer the following Driverless AI question.\n'
180
 
@@ -191,7 +213,8 @@ Current Time: {}
191
  chat_sep = terminate_response
192
  humanstr = PreInstruct
193
  botstr = PreResponse
194
- elif prompt_type in [5, "5", "summarize"]:
 
195
  promptA = promptB = PreInput = ''
196
  PreInstruct = '## Main Text\n\n'
197
  PreResponse = '\n\n## Summary\n\n'
@@ -199,10 +222,11 @@ Current Time: {}
199
  chat_sep = '\n'
200
  humanstr = PreInstruct
201
  botstr = PreResponse
202
- elif prompt_type in [6, "6", "instruct_vicuna"]:
 
203
  promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
204
  "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
205
- chat and reduced) else ''
206
 
207
  PreInstruct = """
208
  ### Human:
@@ -218,7 +242,8 @@ Current Time: {}
218
  chat_sep = '\n'
219
  humanstr = PreInstruct
220
  botstr = PreResponse
221
- elif prompt_type in [10, "10", "prompt_answer"]:
 
222
  preprompt = ''
223
  prompt_tokens = "<|prompt|>"
224
  answer_tokens = "<|answer|>"
@@ -232,7 +257,8 @@ Current Time: {}
232
  chat_sep = eos
233
  humanstr = prompt_tokens
234
  botstr = answer_tokens
235
- elif prompt_type in [11, "11", "open_assistant"]:
 
236
  # From added_tokens.json
237
  preprompt = ''
238
  prompt_tokens = "<|prompter|>"
@@ -248,20 +274,22 @@ Current Time: {}
248
  chat_sep = eos
249
  humanstr = prompt_tokens
250
  botstr = answer_tokens
251
- elif prompt_type in [12, "12", "wizard_lm"]:
 
252
  # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
253
  preprompt = ''
254
  start = ''
255
  promptB = promptA = '%s%s' % (preprompt, start)
256
  PreInstruct = ""
257
  PreInput = None
258
- PreResponse = "\n\n### Response"
259
  eos = "</s>"
260
  terminate_response = [PreResponse, eos]
261
  chat_sep = eos
262
  humanstr = promptA
263
  botstr = PreResponse
264
- elif prompt_type in [13, "13", "wizard_mega"]:
 
265
  preprompt = ''
266
  start = ''
267
  promptB = promptA = '%s%s' % (preprompt, start)
@@ -276,6 +304,75 @@ Current Time: {}
276
  chat_sep = '\n'
277
  humanstr = PreInstruct
278
  botstr = PreResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  else:
280
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
281
 
@@ -412,7 +509,7 @@ class Prompter(object):
412
  multi_output = len(outputs) > 1
413
 
414
  for oi, output in enumerate(outputs):
415
- if self.prompt_type in [0, '0', 'plain']:
416
  output = clean_response(output)
417
  elif prompt is None:
418
  # then use most basic parsing like pipeline
 
1
  import time
2
  from enum import Enum
3
 
4
+ non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
5
+
6
 
7
  class PromptType(Enum):
8
  plain = 0
 
19
  open_assistant = 11
20
  wizard_lm = 12
21
  wizard_mega = 13
22
+ instruct_vicuna2 = 14
23
+ instruct_vicuna3 = 15
24
+ wizard2 = 16
25
+ wizard3 = 17
26
 
27
 
28
  prompt_type_to_model_name = {
 
32
  'EleutherAI/pythia-12b',
33
  'EleutherAI/pythia-12b-deduped',
34
  'EleutherAI/gpt-neox-20b',
35
+ 'openlm-research/open_llama_7b_700bt_preview',
36
  'decapoda-research/llama-7b-hf',
37
  'decapoda-research/llama-13b-hf',
38
  'decapoda-research/llama-30b-hf',
 
46
  'mosaicml/mpt-7b-instruct', # internal code handles instruct
47
  'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
48
  'gptj', # internally handles prompting
49
+ 'llama', # plain, or need to choose prompt_type for given TheBloke model
50
+ 'gpt4all_llama', # internally handles prompting
51
  ],
52
  'prompt_answer': [
53
  'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
 
55
  'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
56
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
57
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
58
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
59
  ],
60
  'instruct': [],
61
  'instruct_with_end': ['databricks/dolly-v2-12b'],
 
67
  'h2oai/h2ogpt-oig-oasst1-512-6_9b',
68
  'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
69
  'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
70
+ 'h2oai/h2ogpt-research-oasst1-512-30b',
71
+ 'h2oai/h2ogpt-oasst1-falcon-40b',
72
+ 'h2oai/h2ogpt-oig-oasst1-falcon-40b',
73
  ],
74
  'dai_faq': [],
75
  'summarize': [],
 
94
 
95
 
96
  def get_prompt(prompt_type, chat, context, reduced):
97
+ if prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
98
+ PromptType.plain.name]:
99
  promptA = promptB = PreInstruct = PreInput = PreResponse = ''
100
  terminate_response = []
101
  chat_sep = ''
 
107
  chat_sep = '\n'
108
  humanstr = ''
109
  botstr = ''
110
+ elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
111
+ PromptType.instruct.name] + [PromptType.instruct_with_end.value,
112
+ str(PromptType.instruct_with_end.value),
113
+ PromptType.instruct_with_end.name]:
114
  promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
115
+ chat and reduced) else ''
116
  promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
117
+ chat and reduced) else ''
118
 
119
  PreInstruct = """
120
  ### Instruction:
 
127
  PreResponse = """
128
  ### Response:
129
  """
130
+ if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value),
131
+ PromptType.instruct_with_end.name]:
132
  terminate_response = ['### End']
133
  else:
134
  terminate_response = None
135
  chat_sep = '\n'
136
  humanstr = PreInstruct
137
  botstr = PreResponse
138
+ elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
139
+ PromptType.quality.name]:
140
  promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
141
+ chat and reduced) else ''
142
  promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
143
+ chat and reduced) else ''
144
 
145
  PreInstruct = """
146
  ### Instruction:
 
157
  chat_sep = '\n'
158
  humanstr = PreInstruct # first thing human says
159
  botstr = PreResponse # first thing bot says
160
+ elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
161
+ PromptType.human_bot.name] + [PromptType.human_bot_orig.value,
162
+ str(PromptType.human_bot_orig.value),
163
+ PromptType.human_bot_orig.name]:
164
  human = '<human>:'
165
  bot = "<bot>:"
166
+ if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
167
+ PromptType.human_bot.name]:
168
  preprompt = ''
169
  else:
170
  cur_date = time.strftime('%Y-%m-%d')
 
195
  chat_sep = '\n'
196
  humanstr = human # tag before human talks
197
  botstr = bot # tag before bot talks
198
+ elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
199
+ PromptType.dai_faq.name]:
200
  promptA = ''
201
  promptB = 'Answer the following Driverless AI question.\n'
202
 
 
213
  chat_sep = terminate_response
214
  humanstr = PreInstruct
215
  botstr = PreResponse
216
+ elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
217
+ PromptType.summarize.name]:
218
  promptA = promptB = PreInput = ''
219
  PreInstruct = '## Main Text\n\n'
220
  PreResponse = '\n\n## Summary\n\n'
 
222
  chat_sep = '\n'
223
  humanstr = PreInstruct
224
  botstr = PreResponse
225
+ elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
226
+ PromptType.instruct_vicuna.name]:
227
  promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
228
  "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
229
+ chat and reduced) else ''
230
 
231
  PreInstruct = """
232
  ### Human:
 
242
  chat_sep = '\n'
243
  humanstr = PreInstruct
244
  botstr = PreResponse
245
+ elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
246
+ PromptType.prompt_answer.name]:
247
  preprompt = ''
248
  prompt_tokens = "<|prompt|>"
249
  answer_tokens = "<|answer|>"
 
257
  chat_sep = eos
258
  humanstr = prompt_tokens
259
  botstr = answer_tokens
260
+ elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
261
+ PromptType.open_assistant.name]:
262
  # From added_tokens.json
263
  preprompt = ''
264
  prompt_tokens = "<|prompter|>"
 
274
  chat_sep = eos
275
  humanstr = prompt_tokens
276
  botstr = answer_tokens
277
+ elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
278
+ PromptType.wizard_lm.name]:
279
  # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
280
  preprompt = ''
281
  start = ''
282
  promptB = promptA = '%s%s' % (preprompt, start)
283
  PreInstruct = ""
284
  PreInput = None
285
+ PreResponse = "\n\n### Response\n"
286
  eos = "</s>"
287
  terminate_response = [PreResponse, eos]
288
  chat_sep = eos
289
  humanstr = promptA
290
  botstr = PreResponse
291
+ elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
292
+ PromptType.wizard_mega.name]:
293
  preprompt = ''
294
  start = ''
295
  promptB = promptA = '%s%s' % (preprompt, start)
 
304
  chat_sep = '\n'
305
  humanstr = PreInstruct
306
  botstr = PreResponse
307
+ elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
308
+ PromptType.instruct_vicuna2.name]:
309
+ promptA = promptB = "" if not (
310
+ chat and reduced) else ''
311
+
312
+ PreInstruct = """
313
+ HUMAN:
314
+ """
315
+
316
+ PreInput = None
317
+
318
+ PreResponse = """
319
+ ASSISTANT:
320
+ """
321
+ terminate_response = [
322
+ 'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
323
+ chat_sep = '\n'
324
+ humanstr = PreInstruct
325
+ botstr = PreResponse
326
+ elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
327
+ PromptType.instruct_vicuna3.name]:
328
+ promptA = promptB = "" if not (
329
+ chat and reduced) else ''
330
+
331
+ PreInstruct = """
332
+ ### User:
333
+ """
334
+
335
+ PreInput = None
336
+
337
+ PreResponse = """
338
+ ### Assistant:
339
+ """
340
+ terminate_response = [
341
+ '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
342
+ chat_sep = '\n'
343
+ humanstr = PreInstruct
344
+ botstr = PreResponse
345
+ elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
346
+ PromptType.wizard2.name]:
347
+ # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
348
+ preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
349
+ start = ''
350
+ promptB = promptA = '%s%s' % (preprompt, start)
351
+ PreInstruct = """
352
+ ### Instruction:
353
+ """
354
+ PreInput = None
355
+ PreResponse = """
356
+ ### Response:
357
+ """
358
+ terminate_response = [PreResponse]
359
+ chat_sep = '\n'
360
+ humanstr = PreInstruct
361
+ botstr = PreResponse
362
+ elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
363
+ PromptType.wizard3.name]:
364
+ # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
365
+ preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
366
+ start = ''
367
+ promptB = promptA = '%s%s' % (preprompt, start)
368
+ PreInstruct = """USER: """
369
+ PreInput = None
370
+ PreResponse = """ASSISTANT: """
371
+ terminate_response = [PreResponse]
372
+ chat_sep = '\n'
373
+ humanstr = PreInstruct
374
+ botstr = PreResponse
375
+
376
  else:
377
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
378
 
 
509
  multi_output = len(outputs) > 1
510
 
511
  for oi, output in enumerate(outputs):
512
+ if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
513
  output = clean_response(output)
514
  elif prompt is None:
515
  # then use most basic parsing like pipeline
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
  # for generate (gradio server) and finetune
2
  datasets==2.12.0
3
  sentencepiece==0.1.97
4
- accelerate==0.18.0
5
  gradio==3.31.0
6
  huggingface_hub==0.14.1
7
  appdirs==1.4.4
@@ -18,8 +17,9 @@ numpy==1.24.2
18
  pandas==2.0.0
19
  matplotlib==3.7.1
20
  loralib==0.1.1
21
- bitsandbytes==0.38.1
22
- git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
 
23
  transformers==4.28.1
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
@@ -50,18 +50,15 @@ pypandoc_binary==1.11
50
  openpyxl==3.1.2
51
  lm_dataformat==0.0.20
52
  bioc==2.0
53
- # To install with constraints
54
- # grep -v '#\|peft' requirements.txt > req_constraints.txt ; pip install -r requirements_optional_langchain.txt -c req_constraints.txt
55
 
 
 
56
  # optional for chat with PDF
57
- langchain==0.0.178
58
  pypdf==3.8.1
59
  tiktoken==0.3.3
60
  # avoid textract, requires old six
61
  #textract==1.6.5
62
- # choose:
63
- #faiss-cpu
64
- faiss-gpu==1.7.2
65
 
66
  # for HF embeddings
67
  sentence_transformers==2.2.2
@@ -69,7 +66,7 @@ sentence_transformers==2.2.2
69
  openai==0.27.6
70
 
71
  # local vector db
72
- chromadb==0.3.23
73
  # server vector db
74
  #pymilvus==2.2.8
75
 
@@ -92,8 +89,12 @@ requests_file==1.5.1
92
  tabulate==0.9.0
93
  # FYI pandoc already part of requirements.txt
94
 
95
- jq==1.4.1
 
96
 
97
  # to check licenses
98
  # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
99
  pip-licenses==4.3.0
 
 
 
 
1
  # for generate (gradio server) and finetune
2
  datasets==2.12.0
3
  sentencepiece==0.1.97
 
4
  gradio==3.31.0
5
  huggingface_hub==0.14.1
6
  appdirs==1.4.4
 
17
  pandas==2.0.0
18
  matplotlib==3.7.1
19
  loralib==0.1.1
20
+ bitsandbytes==0.39.0
21
+ accelerate==0.19.0
22
+ git+https://github.com/huggingface/peft.git@3714aa2fff158fdfa637b2b65952580801d890b2
23
  transformers==4.28.1
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
 
50
  openpyxl==3.1.2
51
  lm_dataformat==0.0.20
52
  bioc==2.0
 
 
53
 
54
+ # falcon
55
+ einops==0.6.1
56
  # optional for chat with PDF
57
+ langchain==0.0.183
58
  pypdf==3.8.1
59
  tiktoken==0.3.3
60
  # avoid textract, requires old six
61
  #textract==1.6.5
 
 
 
62
 
63
  # for HF embeddings
64
  sentence_transformers==2.2.2
 
66
  openai==0.27.6
67
 
68
  # local vector db
69
+ chromadb==0.3.25
70
  # server vector db
71
  #pymilvus==2.2.8
72
 
 
89
  tabulate==0.9.0
90
  # FYI pandoc already part of requirements.txt
91
 
92
+ # JSONLoader, but makes some trouble for some users
93
+ # jq==1.4.1
94
 
95
  # to check licenses
96
  # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
97
  pip-licenses==4.3.0
98
+
99
+ # weaviate vector db
100
+ weaviate-client==3.19.2
stopping.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
2
  from transformers import StoppingCriteria, StoppingCriteriaList
3
 
 
 
4
 
5
  class StoppingCriteriaSub(StoppingCriteria):
6
 
@@ -24,14 +26,14 @@ class StoppingCriteriaSub(StoppingCriteria):
24
 
25
 
26
  def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
27
- if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
28
- if prompt_type == 'human_bot':
29
  # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
30
  # stopping only starts once output is beyond prompt
31
  # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
32
  stop_words = [human, bot, '\n' + human, '\n' + bot]
33
  encounters = [1, 2]
34
- elif prompt_type == 'instruct_vicuna':
35
  # even below is not enough, generic strings and many ways to encode
36
  stop_words = [
37
  '### Human:',
@@ -58,7 +60,7 @@ def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:")
58
  stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
59
  stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
60
  # avoid padding in front of tokens
61
- if tokenizer.pad_token:
62
  stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
63
  # handle fake \n added
64
  stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
 
1
  import torch
2
  from transformers import StoppingCriteria, StoppingCriteriaList
3
 
4
+ from prompter import PromptType
5
+
6
 
7
  class StoppingCriteriaSub(StoppingCriteria):
8
 
 
26
 
27
 
28
  def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
29
+ if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
30
+ if prompt_type == PromptType.human_bot.name:
31
  # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
32
  # stopping only starts once output is beyond prompt
33
  # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
34
  stop_words = [human, bot, '\n' + human, '\n' + bot]
35
  encounters = [1, 2]
36
+ elif prompt_type == PromptType.instruct_vicuna.name:
37
  # even below is not enough, generic strings and many ways to encode
38
  stop_words = [
39
  '### Human:',
 
60
  stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
61
  stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
62
  # avoid padding in front of tokens
63
+ if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
64
  stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
65
  # handle fake \n added
66
  stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import contextlib
2
  import functools
3
  import hashlib
 
4
  import os
5
  import gc
6
  import pathlib
@@ -16,6 +17,8 @@ from datetime import datetime
16
  import filelock
17
  import requests, uuid
18
  from typing import Tuple, Callable, Dict
 
 
19
  from concurrent.futures import ProcessPoolExecutor
20
  import numpy as np
21
  import pandas as pd
@@ -371,18 +374,15 @@ def sanitize_filename(name):
371
  return name
372
 
373
 
374
- def shutil_rmtree_simple(*args, **kwargs):
375
- path = args[0]
376
- assert not os.path.samefile(path, "./tmp"), "Should not be trying to remove entire data directory: %s" % str(path)
377
- # print("Removing path %s" % args[0]) # for debugging
378
  return shutil.rmtree(*args, **kwargs)
379
 
380
 
381
- def remove_simple(path: str):
382
  try:
383
  if path is not None and os.path.exists(path):
384
  if os.path.isdir(path):
385
- shutil_rmtree_simple(path, ignore_errors=True)
386
  else:
387
  with contextlib.suppress(FileNotFoundError):
388
  os.remove(path)
@@ -408,7 +408,7 @@ def atomic_move_simple(src, dst):
408
  shutil.move(src, dst)
409
  except (shutil.Error, FileExistsError):
410
  pass
411
- remove_simple(src)
412
 
413
 
414
  def download_simple(url, dest=None, print_func=None):
@@ -481,7 +481,7 @@ def download(url, dest=None, dest_path=None):
481
  shutil.move(dest_tmp, dest)
482
  except FileExistsError:
483
  pass
484
- remove_simple(dest_tmp)
485
  return dest
486
 
487
 
@@ -766,3 +766,78 @@ def call_subprocess_onetask(func, args=None, kwargs=None):
766
  with ProcessPoolExecutor(max_workers=1) as executor:
767
  future = executor.submit(_traced_func, *args, **kwargs)
768
  return future.result()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import contextlib
2
  import functools
3
  import hashlib
4
+ import inspect
5
  import os
6
  import gc
7
  import pathlib
 
17
  import filelock
18
  import requests, uuid
19
  from typing import Tuple, Callable, Dict
20
+ from tqdm.auto import tqdm
21
+ from joblib import Parallel
22
  from concurrent.futures import ProcessPoolExecutor
23
  import numpy as np
24
  import pandas as pd
 
374
  return name
375
 
376
 
377
+ def shutil_rmtree(*args, **kwargs):
 
 
 
378
  return shutil.rmtree(*args, **kwargs)
379
 
380
 
381
+ def remove(path: str):
382
  try:
383
  if path is not None and os.path.exists(path):
384
  if os.path.isdir(path):
385
+ shutil_rmtree(path, ignore_errors=True)
386
  else:
387
  with contextlib.suppress(FileNotFoundError):
388
  os.remove(path)
 
408
  shutil.move(src, dst)
409
  except (shutil.Error, FileExistsError):
410
  pass
411
+ remove(src)
412
 
413
 
414
  def download_simple(url, dest=None, print_func=None):
 
481
  shutil.move(dest_tmp, dest)
482
  except FileExistsError:
483
  pass
484
+ remove(dest_tmp)
485
  return dest
486
 
487
 
 
766
  with ProcessPoolExecutor(max_workers=1) as executor:
767
  future = executor.submit(_traced_func, *args, **kwargs)
768
  return future.result()
769
+
770
+
771
+ class ProgressParallel(Parallel):
772
+ def __init__(self, use_tqdm=True, total=None, *args, **kwargs):
773
+ self._use_tqdm = use_tqdm
774
+ self._total = total
775
+ super().__init__(*args, **kwargs)
776
+
777
+ def __call__(self, *args, **kwargs):
778
+ with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
779
+ return Parallel.__call__(self, *args, **kwargs)
780
+
781
+ def print_progress(self):
782
+ if self._total is None:
783
+ self._pbar.total = self.n_dispatched_tasks
784
+ self._pbar.n = self.n_completed_tasks
785
+ self._pbar.refresh()
786
+
787
+
788
+ def get_kwargs(func, exclude_names=None, **kwargs):
789
+ func_names = list(inspect.signature(func).parameters)
790
+ missing_kwargs = [x for x in func_names if x not in kwargs]
791
+ if exclude_names:
792
+ for k in exclude_names:
793
+ if k in missing_kwargs:
794
+ missing_kwargs.remove(k)
795
+ if k in func_names:
796
+ func_names.remove(k)
797
+ assert not missing_kwargs, "Missing %s" % missing_kwargs
798
+ kwargs = {k: v for k, v in kwargs.items() if k in func_names}
799
+ return kwargs
800
+
801
+
802
+ import pkg_resources
803
+ have_faiss = False
804
+
805
+ try:
806
+ assert pkg_resources.get_distribution('faiss') is not None
807
+ have_faiss = True
808
+ except (pkg_resources.DistributionNotFound, AssertionError):
809
+ pass
810
+ try:
811
+ assert pkg_resources.get_distribution('faiss_gpu') is not None
812
+ have_faiss = True
813
+ except (pkg_resources.DistributionNotFound, AssertionError):
814
+ pass
815
+ try:
816
+ assert pkg_resources.get_distribution('faiss_cpu') is not None
817
+ have_faiss = True
818
+ except (pkg_resources.DistributionNotFound, AssertionError):
819
+ pass
820
+
821
+
822
+ def hash_file(file):
823
+ try:
824
+ import hashlib
825
+
826
+ # BUF_SIZE is totally arbitrary, change for your app!
827
+ BUF_SIZE = 65536 # lets read stuff in 64kb chunks!
828
+
829
+ md5 = hashlib.md5()
830
+ #sha1 = hashlib.sha1()
831
+
832
+ with open(file, 'rb') as f:
833
+ while True:
834
+ data = f.read(BUF_SIZE)
835
+ if not data:
836
+ break
837
+ md5.update(data)
838
+ #sha1.update(data)
839
+ except BaseException as e:
840
+ print("Cannot hash %s due to %s" % (file, str(e)))
841
+ traceback.print_exc()
842
+ md5 = None
843
+ return md5.hexdigest()