pseudotensor commited on
Commit
b368114
1 Parent(s): b64f5c9

Update with h2oGPT hash c37e5ee65166e4d964435193d5d8c23aaa8d3f09

Browse files
Files changed (8) hide show
  1. client_test.py +20 -7
  2. enums.py +8 -0
  3. evaluate_params.py +1 -0
  4. gen.py +59 -23
  5. gpt_langchain.py +145 -38
  6. gradio_runner.py +15 -1
  7. prompter.py +25 -0
  8. utils.py +17 -1
client_test.py CHANGED
@@ -69,6 +69,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
69
  top_k_docs=3,
70
  langchain_mode='Disabled',
71
  langchain_action=LangChainAction.QUERY.value,
 
72
  prompt_dict=None):
73
  from collections import OrderedDict
74
  kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
@@ -95,6 +96,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
95
  iinput_nochat='', # only for chat=False
96
  langchain_mode=langchain_mode,
97
  langchain_action=langchain_action,
 
98
  top_k_docs=top_k_docs,
99
  chunk=True,
100
  chunk_size=512,
@@ -203,6 +205,7 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
203
  iinput_nochat='',
204
  langchain_mode='Disabled',
205
  langchain_action=LangChainAction.QUERY.value,
 
206
  top_k_docs=4,
207
  document_subset=DocumentChoices.Relevant.name,
208
  document_choice=[],
@@ -225,23 +228,30 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
225
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
226
  def test_client_chat(prompt_type='human_bot'):
227
  return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
228
- langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
 
 
229
 
230
 
231
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
232
  def test_client_chat_stream(prompt_type='human_bot'):
233
  return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
234
  stream_output=True, max_new_tokens=512,
235
- langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
 
 
236
 
237
 
238
- def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action,
 
239
  prompt_dict=None):
240
  client = get_client(serialize=False)
241
 
242
  kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
243
- max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
 
244
  langchain_action=langchain_action,
 
245
  prompt_dict=prompt_dict)
246
  return run_client(client, prompt, args, kwargs)
247
 
@@ -285,15 +295,18 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
285
  def test_client_nochat_stream(prompt_type='human_bot'):
286
  return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
287
  stream_output=True, max_new_tokens=512,
288
- langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
 
 
289
 
290
 
291
- def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action):
 
292
  client = get_client(serialize=False)
293
 
294
  kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
295
  max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
296
- langchain_action=langchain_action)
297
  return run_client_gen(client, prompt, args, kwargs)
298
 
299
 
 
69
  top_k_docs=3,
70
  langchain_mode='Disabled',
71
  langchain_action=LangChainAction.QUERY.value,
72
+ langchain_agents=[],
73
  prompt_dict=None):
74
  from collections import OrderedDict
75
  kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
 
96
  iinput_nochat='', # only for chat=False
97
  langchain_mode=langchain_mode,
98
  langchain_action=langchain_action,
99
+ langchain_agents=langchain_agents,
100
  top_k_docs=top_k_docs,
101
  chunk=True,
102
  chunk_size=512,
 
205
  iinput_nochat='',
206
  langchain_mode='Disabled',
207
  langchain_action=LangChainAction.QUERY.value,
208
+ langchain_agents=[],
209
  top_k_docs=4,
210
  document_subset=DocumentChoices.Relevant.name,
211
  document_choice=[],
 
228
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
229
  def test_client_chat(prompt_type='human_bot'):
230
  return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
231
+ langchain_mode='Disabled',
232
+ langchain_action=LangChainAction.QUERY.value,
233
+ langchain_agents=[])
234
 
235
 
236
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
237
  def test_client_chat_stream(prompt_type='human_bot'):
238
  return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
239
  stream_output=True, max_new_tokens=512,
240
+ langchain_mode='Disabled',
241
+ langchain_action=LangChainAction.QUERY.value,
242
+ langchain_agents=[])
243
 
244
 
245
+ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens,
246
+ langchain_mode, langchain_action, langchain_agents,
247
  prompt_dict=None):
248
  client = get_client(serialize=False)
249
 
250
  kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
251
+ max_new_tokens=max_new_tokens,
252
+ langchain_mode=langchain_mode,
253
  langchain_action=langchain_action,
254
+ langchain_agents=langchain_agents,
255
  prompt_dict=prompt_dict)
256
  return run_client(client, prompt, args, kwargs)
257
 
 
295
  def test_client_nochat_stream(prompt_type='human_bot'):
296
  return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
297
  stream_output=True, max_new_tokens=512,
298
+ langchain_mode='Disabled',
299
+ langchain_action=LangChainAction.QUERY.value,
300
+ langchain_agents=[])
301
 
302
 
303
+ def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
304
+ langchain_mode, langchain_action, langchain_agents):
305
  client = get_client(serialize=False)
306
 
307
  kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
308
  max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
309
+ langchain_action=langchain_action, langchain_agents=langchain_agents)
310
  return run_client_gen(client, prompt, args, kwargs)
311
 
312
 
enums.py CHANGED
@@ -31,6 +31,7 @@ class PromptType(Enum):
31
  mptinstruct = 25
32
  mptchat = 26
33
  falcon = 27
 
34
 
35
 
36
  class DocumentChoices(Enum):
@@ -71,6 +72,13 @@ class LangChainAction(Enum):
71
  SUMMARIZE_REFINE = "Summarize_refine"
72
 
73
 
 
 
 
 
 
 
 
74
  no_server_str = no_lora_str = no_model_str = '[None/Remove]'
75
 
76
  # from site-packages/langchain/llms/openai.py
 
31
  mptinstruct = 25
32
  mptchat = 26
33
  falcon = 27
34
+ guanaco = 28
35
 
36
 
37
  class DocumentChoices(Enum):
 
72
  SUMMARIZE_REFINE = "Summarize_refine"
73
 
74
 
75
+ class LangChainAgent(Enum):
76
+ """LangChain agents"""
77
+
78
+ SEARCH = "Search"
79
+ # CSV = "csv" # WIP
80
+
81
+
82
  no_server_str = no_lora_str = no_model_str = '[None/Remove]'
83
 
84
  # from site-packages/langchain/llms/openai.py
evaluate_params.py CHANGED
@@ -31,6 +31,7 @@ eval_func_param_names = ['instruction',
31
  'iinput_nochat',
32
  'langchain_mode',
33
  'langchain_action',
 
34
  'top_k_docs',
35
  'chunk',
36
  'chunk_size',
 
31
  'iinput_nochat',
32
  'langchain_mode',
33
  'langchain_action',
34
+ 'langchain_agents',
35
  'top_k_docs',
36
  'chunk',
37
  'chunk_size',
gen.py CHANGED
@@ -29,11 +29,11 @@ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is
29
 
30
  from evaluate_params import eval_func_param_names, no_default_param_names
31
  from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
32
- source_postfix, LangChainAction
33
  from loaders import get_loaders
34
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
35
  import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
36
- have_langchain
37
 
38
  start_faulthandler()
39
  import_matplotlib()
@@ -54,6 +54,8 @@ langchain_modes = [x.value for x in list(LangChainMode)]
54
 
55
  langchain_actions = [x.value for x in list(LangChainAction)]
56
 
 
 
57
  scratch_base_dir = '/tmp/'
58
 
59
 
@@ -134,7 +136,7 @@ def main(
134
  extra_lora_options: typing.List[str] = [],
135
  extra_server_options: typing.List[str] = [],
136
 
137
- score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
138
 
139
  eval_filename: str = None,
140
  eval_prompts_only_num: int = 0,
@@ -143,15 +145,18 @@ def main(
143
 
144
  langchain_mode: str = None,
145
  langchain_action: str = LangChainAction.QUERY.value,
 
146
  force_langchain_evaluate: bool = False,
147
  visible_langchain_modes: list = ['UserData', 'MyData'],
148
  # WIP:
149
  # visible_langchain_actions: list = langchain_actions.copy(),
150
  visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
 
151
  document_subset: str = DocumentChoices.Relevant.name,
152
  document_choice: list = [],
153
  user_path: str = None,
154
  detect_user_path_changes_every_query: bool = False,
 
155
  load_db_if_exists: bool = True,
156
  keep_sources_in_context: bool = False,
157
  db_type: str = 'chroma',
@@ -196,6 +201,8 @@ def main(
196
  Or Address can be "openai_chat" or "openai" for OpenAI API
197
  e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo
198
  e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
 
 
199
  :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
200
  :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
201
  :param model_lock: Lock models to specific combinations, for ease of use and extending to many models
@@ -271,18 +278,24 @@ def main(
271
  :param extra_model_options: extra models to show in list in gradio
272
  :param extra_lora_options: extra LORA to show in list in gradio
273
  :param extra_server_options: extra servers to show in list in gradio
274
- :param score_model: which model to score responses (None means no scoring)
 
 
 
275
  :param eval_filename: json file to use for evaluation, if None is sharegpt
276
  :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
277
  :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
278
  :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
279
  :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
 
280
  WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
281
  :param langchain_action: Mode langchain operations in on documents.
282
  Query: Make query of document(s)
283
  Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce
284
  Summarize_all: Summarize document(s) using entire document at once
285
  Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary
 
 
286
  :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
287
  :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
288
  If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
@@ -293,17 +306,18 @@ def main(
293
  But wiki_full is expensive and requires preparation
294
  To allow scratch space only live in session, add 'MyData' to list
295
  Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
296
- FIXME: Avoid 'All' for now, not implemented
297
  :param visible_langchain_actions: Which actions to allow
 
298
  :param document_subset: Default document choice when taking subset of collection
299
  :param document_choice: Chosen document(s) by internal name
 
300
  :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
301
  :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
302
  :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
303
  :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
304
  :param use_openai_model: Whether to use OpenAI model for use with vector db
305
  :param hf_embedding_model: Which HF embedding model to use for vector db
306
- Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-v1 if no GPUs
307
  Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
308
  Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
309
  We support automatically changing of embeddings for chroma, with a backup of db made if this is done
@@ -327,6 +341,7 @@ def main(
327
  captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
328
  captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
329
  Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
 
330
  :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
331
  parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
332
  Recommended if using larger caption model
@@ -394,6 +409,8 @@ def main(
394
  visible_langchain_modes += [langchain_mode]
395
 
396
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
 
 
397
 
398
  # if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler
399
  if LangChainMode.MY_DATA.value not in visible_langchain_modes:
@@ -413,7 +430,8 @@ def main(
413
  " set user_path and ensure allow_upload_to_user_data=True" % langchain_mode, flush=True)
414
  else:
415
  raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
416
- if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value, LangChainMode.CHAT_LLM.value]:
 
417
  raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
418
  if langchain_mode is None:
419
  # if not set yet, disable
@@ -474,7 +492,7 @@ def main(
474
  # HF accounted for later in get_max_max_new_tokens()
475
  save_dir = os.getenv('SAVE_DIR', save_dir)
476
  score_model = os.getenv('SCORE_MODEL', score_model)
477
- if score_model == 'None' or score_model is None:
478
  score_model = ''
479
  concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
480
  api_open = bool(int(os.getenv('API_OPEN', str(int(api_open)))))
@@ -482,6 +500,7 @@ def main(
482
 
483
  n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
484
  if n_gpus == 0:
 
485
  gpu_id = None
486
  load_8bit = False
487
  load_4bit = False
@@ -499,7 +518,11 @@ def main(
499
  if hf_embedding_model is None:
500
  # if no GPUs, use simpler embedding model to avoid cost in time
501
  hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
 
 
502
  else:
 
 
503
  if hf_embedding_model is None:
504
  # if still None, then set default
505
  hf_embedding_model = 'hkunlp/instructor-large'
@@ -967,11 +990,13 @@ def get_model(
967
  client = gr_client or hf_client
968
  # Don't return None, None for model, tokenizer so triggers
969
  return client, tokenizer, 'http'
970
- if isinstance(inference_server, str) and inference_server.startswith('openai'):
971
- assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
972
- # Don't return None, None for model, tokenizer so triggers
973
- # include small token cushion
974
- tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
 
 
975
  return inference_server, tokenizer, inference_server
976
  assert not inference_server, "Malformed inference_server=%s" % inference_server
977
  if base_model in non_hf_types:
@@ -1278,6 +1303,7 @@ def evaluate(
1278
  iinput_nochat,
1279
  langchain_mode,
1280
  langchain_action,
 
1281
  top_k_docs,
1282
  chunk,
1283
  chunk_size,
@@ -1298,6 +1324,7 @@ def evaluate(
1298
  raise_generate_gpu_exceptions=None,
1299
  chat_context=None,
1300
  lora_weights=None,
 
1301
  load_db_if_exists=True,
1302
  dbs=None,
1303
  user_path=None,
@@ -1452,6 +1479,8 @@ def evaluate(
1452
  # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
1453
  assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
1454
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
 
 
1455
  if langchain_mode in ['MyData'] and my_db_state is not None and len(my_db_state) > 0 and my_db_state[0] is not None:
1456
  db1 = my_db_state[0]
1457
  elif dbs is not None and langchain_mode in dbs:
@@ -1484,6 +1513,7 @@ def evaluate(
1484
  inference_server=inference_server,
1485
  stream_output=stream_output,
1486
  prompter=prompter,
 
1487
  load_db_if_exists=load_db_if_exists,
1488
  db=db1,
1489
  user_path=user_path,
@@ -1498,6 +1528,7 @@ def evaluate(
1498
  chunk_size=chunk_size,
1499
  langchain_mode=langchain_mode,
1500
  langchain_action=langchain_action,
 
1501
  document_subset=document_subset,
1502
  document_choice=document_choice,
1503
  db_type=db_type,
@@ -1526,6 +1557,7 @@ def evaluate(
1526
  inference_server=inference_server,
1527
  langchain_mode=langchain_mode,
1528
  langchain_action=langchain_action,
 
1529
  document_subset=document_subset,
1530
  document_choice=document_choice,
1531
  num_prompt_tokens=num_prompt_tokens,
@@ -1549,12 +1581,12 @@ def evaluate(
1549
  clear_torch_cache()
1550
  return
1551
 
1552
- if inference_server.startswith('openai') or inference_server.startswith('http'):
1553
- if inference_server.startswith('openai'):
1554
- import openai
1555
  where_from = "openai_client"
 
1556
 
1557
- openai.api_key = os.getenv("OPENAI_API_KEY")
1558
  terminate_response = prompter.terminate_response or []
1559
  stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
1560
  stop_sequences = [x for x in stop_sequences if x]
@@ -1567,7 +1599,7 @@ def evaluate(
1567
  n=num_return_sequences,
1568
  presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
1569
  )
1570
- if inference_server == 'openai':
1571
  response = openai.Completion.create(
1572
  model=base_model,
1573
  prompt=prompt,
@@ -1590,7 +1622,9 @@ def evaluate(
1590
  yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
1591
  sanitize_bot_response=sanitize_bot_response),
1592
  sources='')
1593
- elif inference_server == 'openai_chat':
 
 
1594
  response = openai.ChatCompletion.create(
1595
  model=base_model,
1596
  messages=[
@@ -1643,6 +1677,7 @@ def evaluate(
1643
  where_from = "gr_client"
1644
  client_langchain_mode = 'Disabled'
1645
  client_langchain_action = LangChainAction.QUERY.value
 
1646
  gen_server_kwargs = dict(temperature=temperature,
1647
  top_p=top_p,
1648
  top_k=top_k,
@@ -1695,6 +1730,7 @@ def evaluate(
1695
  iinput_nochat=gr_iinput, # only for chat=False
1696
  langchain_mode=client_langchain_mode,
1697
  langchain_action=client_langchain_action,
 
1698
  top_k_docs=top_k_docs,
1699
  chunk=chunk,
1700
  chunk_size=chunk_size,
@@ -2276,8 +2312,8 @@ y = np.random.randint(0, 1, 100)
2276
 
2277
  # move to correct position
2278
  for example in examples:
2279
- example += [chat, '', '', LangChainMode.DISABLED.value, LangChainAction.QUERY.value,
2280
- top_k_docs, chunk, chunk_size, [DocumentChoices.Relevant.name], []
2281
  ]
2282
  # adjust examples if non-chat mode
2283
  if not chat:
@@ -2383,14 +2419,14 @@ def check_locals(**kwargs):
2383
 
2384
 
2385
  def get_model_max_length(model_state):
2386
- if not isinstance(model_state['tokenizer'], (str, types.NoneType)):
2387
  return model_state['tokenizer'].model_max_length
2388
  else:
2389
  return 2048
2390
 
2391
 
2392
  def get_max_max_new_tokens(model_state, **kwargs):
2393
- if not isinstance(model_state['tokenizer'], (str, types.NoneType)):
2394
  max_max_new_tokens = model_state['tokenizer'].model_max_length
2395
  else:
2396
  max_max_new_tokens = None
 
29
 
30
  from evaluate_params import eval_func_param_names, no_default_param_names
31
  from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
32
+ source_postfix, LangChainAction, LangChainAgent
33
  from loaders import get_loaders
34
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
35
  import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
36
+ have_langchain, set_openai
37
 
38
  start_faulthandler()
39
  import_matplotlib()
 
54
 
55
  langchain_actions = [x.value for x in list(LangChainAction)]
56
 
57
+ langchain_agents_list = [x.value for x in list(LangChainAgent)]
58
+
59
  scratch_base_dir = '/tmp/'
60
 
61
 
 
136
  extra_lora_options: typing.List[str] = [],
137
  extra_server_options: typing.List[str] = [],
138
 
139
+ score_model: str = 'auto',
140
 
141
  eval_filename: str = None,
142
  eval_prompts_only_num: int = 0,
 
145
 
146
  langchain_mode: str = None,
147
  langchain_action: str = LangChainAction.QUERY.value,
148
+ langchain_agents: list = [],
149
  force_langchain_evaluate: bool = False,
150
  visible_langchain_modes: list = ['UserData', 'MyData'],
151
  # WIP:
152
  # visible_langchain_actions: list = langchain_actions.copy(),
153
  visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
154
+ visible_langchain_agents: list = langchain_agents_list.copy(),
155
  document_subset: str = DocumentChoices.Relevant.name,
156
  document_choice: list = [],
157
  user_path: str = None,
158
  detect_user_path_changes_every_query: bool = False,
159
+ use_llm_if_no_docs: bool = False,
160
  load_db_if_exists: bool = True,
161
  keep_sources_in_context: bool = False,
162
  db_type: str = 'chroma',
 
201
  Or Address can be "openai_chat" or "openai" for OpenAI API
202
  e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo
203
  e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
204
+ Or Address can be "vllm:IP:port" or "vllm:IP:port" for OpenAI-compliant vLLM endpoint
205
+ Note: vllm_chat not supported by vLLM project.
206
  :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
207
  :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
208
  :param model_lock: Lock models to specific combinations, for ease of use and extending to many models
 
278
  :param extra_model_options: extra models to show in list in gradio
279
  :param extra_lora_options: extra LORA to show in list in gradio
280
  :param extra_server_options: extra servers to show in list in gradio
281
+ :param score_model: which model to score responses
282
+ None: no response scoring
283
+ 'auto': auto mode, '' (no model) for CPU, 'OpenAssistant/reward-model-deberta-v3-large-v2' for GPU,
284
+ because on CPU takes too much compute just for scoring response
285
  :param eval_filename: json file to use for evaluation, if None is sharegpt
286
  :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
287
  :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
288
  :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
289
  :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
290
+ None: auto mode, check if langchain package exists, at least do ChatLLM if so, else Disabled
291
  WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
292
  :param langchain_action: Mode langchain operations in on documents.
293
  Query: Make query of document(s)
294
  Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce
295
  Summarize_all: Summarize document(s) using entire document at once
296
  Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary
297
+ :param langchain_agents: Which agents to use
298
+ 'search': Use Web Search as context for LLM response, e.g. SERP if have SERPAPI_API_KEY in env
299
  :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
300
  :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
301
  If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
 
306
  But wiki_full is expensive and requires preparation
307
  To allow scratch space only live in session, add 'MyData' to list
308
  Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
 
309
  :param visible_langchain_actions: Which actions to allow
310
+ :param visible_langchain_agents: Which agents to allow
311
  :param document_subset: Default document choice when taking subset of collection
312
  :param document_choice: Chosen document(s) by internal name
313
+ :param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData
314
  :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
315
  :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
316
  :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
317
  :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
318
  :param use_openai_model: Whether to use OpenAI model for use with vector db
319
  :param hf_embedding_model: Which HF embedding model to use for vector db
320
+ Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-v2 if no GPUs
321
  Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
322
  Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
323
  We support automatically changing of embeddings for chroma, with a backup of db made if this is done
 
341
  captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
342
  captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
343
  Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
344
+ Disabled for CPU since BLIP requires CUDA
345
  :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
346
  parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
347
  Recommended if using larger caption model
 
409
  visible_langchain_modes += [langchain_mode]
410
 
411
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
412
+ assert len(
413
+ set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
414
 
415
  # if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler
416
  if LangChainMode.MY_DATA.value not in visible_langchain_modes:
 
430
  " set user_path and ensure allow_upload_to_user_data=True" % langchain_mode, flush=True)
431
  else:
432
  raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
433
+ if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value,
434
+ LangChainMode.CHAT_LLM.value]:
435
  raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
436
  if langchain_mode is None:
437
  # if not set yet, disable
 
492
  # HF accounted for later in get_max_max_new_tokens()
493
  save_dir = os.getenv('SAVE_DIR', save_dir)
494
  score_model = os.getenv('SCORE_MODEL', score_model)
495
+ if str(score_model) == 'None':
496
  score_model = ''
497
  concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
498
  api_open = bool(int(os.getenv('API_OPEN', str(int(api_open)))))
 
500
 
501
  n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
502
  if n_gpus == 0:
503
+ enable_captions = False
504
  gpu_id = None
505
  load_8bit = False
506
  load_4bit = False
 
518
  if hf_embedding_model is None:
519
  # if no GPUs, use simpler embedding model to avoid cost in time
520
  hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
521
+ if score_model == 'auto':
522
+ score_model = ''
523
  else:
524
+ if score_model == 'auto':
525
+ score_model = 'OpenAssistant/reward-model-deberta-v3-large-v2'
526
  if hf_embedding_model is None:
527
  # if still None, then set default
528
  hf_embedding_model = 'hkunlp/instructor-large'
 
990
  client = gr_client or hf_client
991
  # Don't return None, None for model, tokenizer so triggers
992
  return client, tokenizer, 'http'
993
+ if isinstance(inference_server, str) and (
994
+ inference_server.startswith('openai') or inference_server.startswith('vllm')):
995
+ if inference_server.startswith('openai'):
996
+ assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
997
+ # Don't return None, None for model, tokenizer so triggers
998
+ # include small token cushion
999
+ tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
1000
  return inference_server, tokenizer, inference_server
1001
  assert not inference_server, "Malformed inference_server=%s" % inference_server
1002
  if base_model in non_hf_types:
 
1303
  iinput_nochat,
1304
  langchain_mode,
1305
  langchain_action,
1306
+ langchain_agents,
1307
  top_k_docs,
1308
  chunk,
1309
  chunk_size,
 
1324
  raise_generate_gpu_exceptions=None,
1325
  chat_context=None,
1326
  lora_weights=None,
1327
+ use_llm_if_no_docs=False,
1328
  load_db_if_exists=True,
1329
  dbs=None,
1330
  user_path=None,
 
1479
  # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
1480
  assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
1481
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
1482
+ assert len(
1483
+ set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
1484
  if langchain_mode in ['MyData'] and my_db_state is not None and len(my_db_state) > 0 and my_db_state[0] is not None:
1485
  db1 = my_db_state[0]
1486
  elif dbs is not None and langchain_mode in dbs:
 
1513
  inference_server=inference_server,
1514
  stream_output=stream_output,
1515
  prompter=prompter,
1516
+ use_llm_if_no_docs=use_llm_if_no_docs,
1517
  load_db_if_exists=load_db_if_exists,
1518
  db=db1,
1519
  user_path=user_path,
 
1528
  chunk_size=chunk_size,
1529
  langchain_mode=langchain_mode,
1530
  langchain_action=langchain_action,
1531
+ langchain_agents=langchain_agents,
1532
  document_subset=document_subset,
1533
  document_choice=document_choice,
1534
  db_type=db_type,
 
1557
  inference_server=inference_server,
1558
  langchain_mode=langchain_mode,
1559
  langchain_action=langchain_action,
1560
+ langchain_agents=langchain_agents,
1561
  document_subset=document_subset,
1562
  document_choice=document_choice,
1563
  num_prompt_tokens=num_prompt_tokens,
 
1581
  clear_torch_cache()
1582
  return
1583
 
1584
+ if inference_server.startswith('vllm') or inference_server.startswith('openai') or inference_server.startswith(
1585
+ 'http'):
1586
+ if inference_server.startswith('vllm') or inference_server.startswith('openai'):
1587
  where_from = "openai_client"
1588
+ openai, inf_type = set_openai(inference_server)
1589
 
 
1590
  terminate_response = prompter.terminate_response or []
1591
  stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
1592
  stop_sequences = [x for x in stop_sequences if x]
 
1599
  n=num_return_sequences,
1600
  presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
1601
  )
1602
+ if inf_type == 'vllm' or inference_server == 'openai':
1603
  response = openai.Completion.create(
1604
  model=base_model,
1605
  prompt=prompt,
 
1622
  yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
1623
  sanitize_bot_response=sanitize_bot_response),
1624
  sources='')
1625
+ elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
1626
+ if inf_type == 'vllm_chat':
1627
+ raise NotImplementedError('%s not supported by vLLM' % inf_type)
1628
  response = openai.ChatCompletion.create(
1629
  model=base_model,
1630
  messages=[
 
1677
  where_from = "gr_client"
1678
  client_langchain_mode = 'Disabled'
1679
  client_langchain_action = LangChainAction.QUERY.value
1680
+ client_langchain_agents = []
1681
  gen_server_kwargs = dict(temperature=temperature,
1682
  top_p=top_p,
1683
  top_k=top_k,
 
1730
  iinput_nochat=gr_iinput, # only for chat=False
1731
  langchain_mode=client_langchain_mode,
1732
  langchain_action=client_langchain_action,
1733
+ langchain_agents=client_langchain_agents,
1734
  top_k_docs=top_k_docs,
1735
  chunk=chunk,
1736
  chunk_size=chunk_size,
 
2312
 
2313
  # move to correct position
2314
  for example in examples:
2315
+ example += [chat, '', '', LangChainMode.DISABLED.value, LangChainAction.QUERY.value, [],
2316
+ top_k_docs, chunk, chunk_size, DocumentChoices.Relevant.name, []
2317
  ]
2318
  # adjust examples if non-chat mode
2319
  if not chat:
 
2419
 
2420
 
2421
  def get_model_max_length(model_state):
2422
+ if not isinstance(model_state['tokenizer'], (str, type(None))):
2423
  return model_state['tokenizer'].model_max_length
2424
  else:
2425
  return 2048
2426
 
2427
 
2428
  def get_max_max_new_tokens(model_state, **kwargs):
2429
+ if not isinstance(model_state['tokenizer'], (str, type(None))):
2430
  max_max_new_tokens = model_state['tokenizer'].model_max_length
2431
  else:
2432
  max_max_new_tokens = None
gpt_langchain.py CHANGED
@@ -21,6 +21,7 @@ import filelock
21
  from joblib import delayed
22
  from langchain.callbacks import streaming_stdout
23
  from langchain.embeddings import HuggingFaceInstructEmbeddings
 
24
  from tqdm import tqdm
25
 
26
  from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
@@ -30,7 +31,7 @@ from gen import get_model, SEED
30
  from prompter import non_hf_types, PromptType, Prompter
31
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
32
  get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
33
- have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf
34
  from utils_langchain import StreamingGradioCallbackHandler
35
 
36
  import_matplotlib()
@@ -276,15 +277,7 @@ from typing import Any, Dict, List, Optional, Set
276
 
277
  from pydantic import Extra, Field, root_validator
278
 
279
- from langchain.callbacks.manager import CallbackManagerForLLMRun
280
-
281
- """Wrapper around Huggingface text generation inference API."""
282
- from functools import partial
283
- from typing import Any, Dict, List, Optional
284
-
285
- from pydantic import Extra, Field, root_validator
286
-
287
- from langchain.callbacks.manager import CallbackManagerForLLMRun
288
  from langchain.llms.base import LLM
289
 
290
 
@@ -356,6 +349,7 @@ class GradioInference(LLM):
356
  gr_client = self.client
357
  client_langchain_mode = 'Disabled'
358
  client_langchain_action = LangChainAction.QUERY.value
 
359
  top_k_docs = 1
360
  chunk = True
361
  chunk_size = 512
@@ -385,6 +379,7 @@ class GradioInference(LLM):
385
  iinput_nochat='', # only for chat=False
386
  langchain_mode=client_langchain_mode,
387
  langchain_action=client_langchain_action,
 
388
  top_k_docs=top_k_docs,
389
  chunk=chunk,
390
  chunk_size=chunk_size,
@@ -566,6 +561,92 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
566
 
567
 
568
  from langchain.chat_models import ChatOpenAI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
 
570
 
571
  class H2OChatOpenAI(ChatOpenAI):
@@ -599,14 +680,26 @@ def get_llm(use_openai_model=False,
599
  sanitize_bot_response=False,
600
  verbose=False,
601
  ):
602
- if use_openai_model or inference_server in ['openai', 'openai_chat']:
603
  if use_openai_model and model_name is None:
604
  model_name = "gpt-3.5-turbo"
605
- if inference_server == 'openai':
606
- from langchain.llms import OpenAI
607
- cls = OpenAI
608
- else:
609
  cls = H2OChatOpenAI
 
 
 
 
 
 
 
 
 
 
 
 
610
  callbacks = [StreamingGradioCallbackHandler()]
611
  llm = cls(model_name=model_name,
612
  temperature=temperature if do_sample else 0,
@@ -616,11 +709,18 @@ def get_llm(use_openai_model=False,
616
  frequency_penalty=0,
617
  presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
618
  callbacks=callbacks if stream_output else None,
 
 
 
 
 
 
619
  )
620
  streamer = callbacks[0] if stream_output else None
621
  if inference_server in ['openai', 'openai_chat']:
622
  prompt_type = inference_server
623
  else:
 
624
  prompt_type = prompt_type or 'plain'
625
  elif inference_server:
626
  assert inference_server.startswith(
@@ -916,7 +1016,6 @@ def get_dai_docs(from_hf=False, get_pickle=True):
916
  return sources
917
 
918
 
919
-
920
  image_types = ["png", "jpg", "jpeg"]
921
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
922
  "md",
@@ -927,7 +1026,8 @@ non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
927
  ]
928
  # "msg", GPL3
929
 
930
- if have_libreoffice:
 
931
  non_image_types.extend(["docx", "doc", "xls", "xlsx"])
932
 
933
  file_types = non_image_types + image_types
@@ -936,9 +1036,11 @@ file_types = non_image_types + image_types
936
  def add_meta(docs1, file):
937
  file_extension = pathlib.Path(file).suffix
938
  hashid = hash_file(file)
 
939
  if not isinstance(docs1, (list, tuple, types.GeneratorType)):
940
  docs1 = [docs1]
941
- [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
 
942
 
943
 
944
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
@@ -1011,11 +1113,11 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
1011
  add_meta(docs1, file)
1012
  docs1 = clean_doc(docs1)
1013
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
1014
- elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
1015
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
1016
  add_meta(docs1, file)
1017
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1018
- elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and have_libreoffice:
1019
  docs1 = UnstructuredExcelLoader(file_path=file).load()
1020
  add_meta(docs1, file)
1021
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
@@ -1760,6 +1862,7 @@ def _run_qa_db(query=None,
1760
  cut_distanct=1.1,
1761
  sanitize_bot_response=False,
1762
  show_rank=False,
 
1763
  load_db_if_exists=False,
1764
  db=None,
1765
  do_sample=False,
@@ -1775,6 +1878,7 @@ def _run_qa_db(query=None,
1775
  num_return_sequences=1,
1776
  langchain_mode=None,
1777
  langchain_action=None,
 
1778
  document_subset=DocumentChoices.Relevant.name,
1779
  document_choice=[],
1780
  n_jobs=-1,
@@ -1857,20 +1961,21 @@ def _run_qa_db(query=None,
1857
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1858
  yield formatted_doc_chunks, ''
1859
  return
1860
- if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
1861
- LangChainAction.SUMMARIZE_ALL.value,
1862
- LangChainAction.SUMMARIZE_REFINE.value]:
1863
- ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.'
1864
- extra = ''
1865
- yield ret, extra
1866
- return
1867
- if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
1868
- LangChainMode.CHAT_LLM.value,
1869
- LangChainMode.LLM.value]:
1870
- ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
1871
- extra = ''
1872
- yield ret, extra
1873
- return
 
1874
 
1875
  if chain is None and model_name not in non_hf_types:
1876
  # here if no docs at all and not HF type
@@ -1948,6 +2053,7 @@ def get_chain(query=None,
1948
  db=None,
1949
  langchain_mode=None,
1950
  langchain_action=None,
 
1951
  document_subset=DocumentChoices.Relevant.name,
1952
  document_choice=[],
1953
  n_jobs=-1,
@@ -1961,6 +2067,7 @@ def get_chain(query=None,
1961
  auto_reduce_chunks=True,
1962
  max_chunks=100,
1963
  ):
 
1964
  # determine whether use of context out of docs is planned
1965
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1966
  if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
@@ -2092,8 +2199,8 @@ def get_chain(query=None,
2092
  for result in zip(db_documents, db_metadatas)]
2093
 
2094
  # order documents
2095
- doc_hashes = [x['doc_hash'] for x in db_metadatas]
2096
- doc_chunk_ids = [x['chunk_id'] for x in db_metadatas]
2097
  docs_with_score = [x for _, _, x in
2098
  sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
2099
  ]
@@ -2302,6 +2409,7 @@ def clean_doc(docs1):
2302
 
2303
  def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
2304
  if not chunk:
 
2305
  return sources
2306
  if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
2307
  # if just one document
@@ -2320,8 +2428,7 @@ def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
2320
  source_chunks = splitter.split_documents(sources)
2321
 
2322
  # currently in order, but when pull from db won't be, so mark order and document by hash
2323
- doc_hash = str(uuid.uuid4())[:10]
2324
- [x.metadata.update(dict(doc_hash=doc_hash, chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
2325
 
2326
  return source_chunks
2327
 
 
21
  from joblib import delayed
22
  from langchain.callbacks import streaming_stdout
23
  from langchain.embeddings import HuggingFaceInstructEmbeddings
24
+ from langchain.schema import LLMResult
25
  from tqdm import tqdm
26
 
27
  from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
 
31
  from prompter import non_hf_types, PromptType, Prompter
32
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
33
  get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
34
+ have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf, set_openai
35
  from utils_langchain import StreamingGradioCallbackHandler
36
 
37
  import_matplotlib()
 
277
 
278
  from pydantic import Extra, Field, root_validator
279
 
280
+ from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
 
 
 
 
 
 
 
 
281
  from langchain.llms.base import LLM
282
 
283
 
 
349
  gr_client = self.client
350
  client_langchain_mode = 'Disabled'
351
  client_langchain_action = LangChainAction.QUERY.value
352
+ client_langchain_agents = []
353
  top_k_docs = 1
354
  chunk = True
355
  chunk_size = 512
 
379
  iinput_nochat='', # only for chat=False
380
  langchain_mode=client_langchain_mode,
381
  langchain_action=client_langchain_action,
382
+ langchain_agents=client_langchain_agents,
383
  top_k_docs=top_k_docs,
384
  chunk=chunk,
385
  chunk_size=chunk_size,
 
561
 
562
 
563
  from langchain.chat_models import ChatOpenAI
564
+ from langchain.llms import OpenAI
565
+ from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
566
+ update_token_usage
567
+
568
+
569
+ class H2OOpenAI(OpenAI):
570
+ """
571
+ New class to handle vLLM's use of OpenAI, no vllm_chat supported, so only need here
572
+ Handles prompting that OpenAI doesn't need, stopping as well
573
+ """
574
+ stop_sequences: Any = None
575
+ sanitize_bot_response: bool = False
576
+ prompter: Any = None
577
+ tokenizer: Any = None
578
+
579
+ @classmethod
580
+ def all_required_field_names(cls) -> Set:
581
+ all_required_field_names = super(OpenAI, cls).all_required_field_names()
582
+ all_required_field_names.update(
583
+ {'top_p', 'frequency_penalty', 'presence_penalty', 'stop_sequences', 'sanitize_bot_response', 'prompter',
584
+ 'tokenizer'})
585
+ return all_required_field_names
586
+
587
+ def _generate(
588
+ self,
589
+ prompts: List[str],
590
+ stop: Optional[List[str]] = None,
591
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
592
+ **kwargs: Any,
593
+ ) -> LLMResult:
594
+ stop = self.stop_sequences if not stop else self.stop_sequences + stop
595
+
596
+ # HF inference server needs control over input tokens
597
+ assert self.tokenizer is not None
598
+ from h2oai_pipeline import H2OTextGenerationPipeline
599
+ for prompti, prompt in enumerate(prompts):
600
+ prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
601
+ # NOTE: OpenAI/vLLM server does not add prompting, so must do here
602
+ data_point = dict(context='', instruction=prompt, input='')
603
+ prompt = self.prompter.generate_prompt(data_point)
604
+ prompts[prompti] = prompt
605
+
606
+ params = self._invocation_params
607
+ params = {**params, **kwargs}
608
+ sub_prompts = self.get_sub_prompts(params, prompts, stop)
609
+ choices = []
610
+ token_usage: Dict[str, int] = {}
611
+ # Get the token usage from the response.
612
+ # Includes prompt, completion, and total tokens used.
613
+ _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
614
+ text = ''
615
+ for _prompts in sub_prompts:
616
+ if self.streaming:
617
+ text_with_prompt = ""
618
+ prompt = _prompts[0]
619
+ if len(_prompts) > 1:
620
+ raise ValueError("Cannot stream results with multiple prompts.")
621
+ params["stream"] = True
622
+ response = _streaming_response_template()
623
+ first = True
624
+ for stream_resp in completion_with_retry(
625
+ self, prompt=_prompts, **params
626
+ ):
627
+ if first:
628
+ stream_resp["choices"][0]["text"] = prompt + stream_resp["choices"][0]["text"]
629
+ first = False
630
+ text_chunk = stream_resp["choices"][0]["text"]
631
+ text_with_prompt += text_chunk
632
+ text = self.prompter.get_response(text_with_prompt, prompt=prompt,
633
+ sanitize_bot_response=self.sanitize_bot_response)
634
+ if run_manager:
635
+ run_manager.on_llm_new_token(
636
+ text_chunk,
637
+ verbose=self.verbose,
638
+ logprobs=stream_resp["choices"][0]["logprobs"],
639
+ )
640
+ _update_response(response, stream_resp)
641
+ choices.extend(response["choices"])
642
+ else:
643
+ response = completion_with_retry(self, prompt=_prompts, **params)
644
+ choices.extend(response["choices"])
645
+ if not self.streaming:
646
+ # Can't update token usage if streaming
647
+ update_token_usage(_keys, response, token_usage)
648
+ choices[0]['text'] = text
649
+ return self.create_llm_result(choices, prompts, token_usage)
650
 
651
 
652
  class H2OChatOpenAI(ChatOpenAI):
 
680
  sanitize_bot_response=False,
681
  verbose=False,
682
  ):
683
+ if use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'):
684
  if use_openai_model and model_name is None:
685
  model_name = "gpt-3.5-turbo"
686
+ openai, inf_type = set_openai(
687
+ inference_server) # FIXME: Will later import be ignored? I think so, so should be fine
688
+ kwargs_extra = {}
689
+ if inference_server == 'openai_chat' or inf_type == 'vllm_chat':
690
  cls = H2OChatOpenAI
691
+ else:
692
+ cls = H2OOpenAI
693
+ if inf_type == 'vllm':
694
+ terminate_response = prompter.terminate_response or []
695
+ stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
696
+ stop_sequences = [x for x in stop_sequences if x]
697
+ kwargs_extra = dict(stop_sequences=stop_sequences,
698
+ sanitize_bot_response=sanitize_bot_response,
699
+ prompter=prompter,
700
+ tokenizer=tokenizer,
701
+ client=None)
702
+
703
  callbacks = [StreamingGradioCallbackHandler()]
704
  llm = cls(model_name=model_name,
705
  temperature=temperature if do_sample else 0,
 
709
  frequency_penalty=0,
710
  presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
711
  callbacks=callbacks if stream_output else None,
712
+ openai_api_key=openai.api_key,
713
+ openai_api_base=openai.api_base,
714
+ logit_bias=None if inf_type =='vllm' else {},
715
+ max_retries=2,
716
+ streaming=stream_output,
717
+ **kwargs_extra
718
  )
719
  streamer = callbacks[0] if stream_output else None
720
  if inference_server in ['openai', 'openai_chat']:
721
  prompt_type = inference_server
722
  else:
723
+ # vllm goes here
724
  prompt_type = prompt_type or 'plain'
725
  elif inference_server:
726
  assert inference_server.startswith(
 
1016
  return sources
1017
 
1018
 
 
1019
  image_types = ["png", "jpg", "jpeg"]
1020
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
1021
  "md",
 
1026
  ]
1027
  # "msg", GPL3
1028
 
1029
+ if have_libreoffice or True:
1030
+ # or True so it tries to load, e.g. on MAC/Windows, even if don't have libreoffice since works without that
1031
  non_image_types.extend(["docx", "doc", "xls", "xlsx"])
1032
 
1033
  file_types = non_image_types + image_types
 
1036
  def add_meta(docs1, file):
1037
  file_extension = pathlib.Path(file).suffix
1038
  hashid = hash_file(file)
1039
+ doc_hash = str(uuid.uuid4())[:10]
1040
  if not isinstance(docs1, (list, tuple, types.GeneratorType)):
1041
  docs1 = [docs1]
1042
+ [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid, doc_hash=doc_hash)) for
1043
+ x in docs1]
1044
 
1045
 
1046
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
 
1113
  add_meta(docs1, file)
1114
  docs1 = clean_doc(docs1)
1115
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
1116
+ elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and (have_libreoffice or True):
1117
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
1118
  add_meta(docs1, file)
1119
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1120
+ elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and (have_libreoffice or True):
1121
  docs1 = UnstructuredExcelLoader(file_path=file).load()
1122
  add_meta(docs1, file)
1123
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
 
1862
  cut_distanct=1.1,
1863
  sanitize_bot_response=False,
1864
  show_rank=False,
1865
+ use_llm_if_no_docs=False,
1866
  load_db_if_exists=False,
1867
  db=None,
1868
  do_sample=False,
 
1878
  num_return_sequences=1,
1879
  langchain_mode=None,
1880
  langchain_action=None,
1881
+ langchain_agents=None,
1882
  document_subset=DocumentChoices.Relevant.name,
1883
  document_choice=[],
1884
  n_jobs=-1,
 
1961
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1962
  yield formatted_doc_chunks, ''
1963
  return
1964
+ if not use_llm_if_no_docs:
1965
+ if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
1966
+ LangChainAction.SUMMARIZE_ALL.value,
1967
+ LangChainAction.SUMMARIZE_REFINE.value]:
1968
+ ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.'
1969
+ extra = ''
1970
+ yield ret, extra
1971
+ return
1972
+ if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
1973
+ LangChainMode.CHAT_LLM.value,
1974
+ LangChainMode.LLM.value]:
1975
+ ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
1976
+ extra = ''
1977
+ yield ret, extra
1978
+ return
1979
 
1980
  if chain is None and model_name not in non_hf_types:
1981
  # here if no docs at all and not HF type
 
2053
  db=None,
2054
  langchain_mode=None,
2055
  langchain_action=None,
2056
+ langchain_agents=None,
2057
  document_subset=DocumentChoices.Relevant.name,
2058
  document_choice=[],
2059
  n_jobs=-1,
 
2067
  auto_reduce_chunks=True,
2068
  max_chunks=100,
2069
  ):
2070
+ assert langchain_agents is not None # should be at least []
2071
  # determine whether use of context out of docs is planned
2072
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
2073
  if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
 
2199
  for result in zip(db_documents, db_metadatas)]
2200
 
2201
  # order documents
2202
+ doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
2203
+ doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
2204
  docs_with_score = [x for _, _, x in
2205
  sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
2206
  ]
 
2409
 
2410
  def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
2411
  if not chunk:
2412
+ [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(sources)]
2413
  return sources
2414
  if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
2415
  # if just one document
 
2428
  source_chunks = splitter.split_documents(sources)
2429
 
2430
  # currently in order, but when pull from db won't be, so mark order and document by hash
2431
+ [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
 
2432
 
2433
  return source_chunks
2434
 
gradio_runner.py CHANGED
@@ -58,7 +58,7 @@ from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt
58
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
59
  ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip
60
  from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
61
- get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
62
  from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
63
 
64
  from apscheduler.schedulers.background import BackgroundScheduler
@@ -101,6 +101,7 @@ def go_gradio(**kwargs):
101
  db_type = kwargs['db_type']
102
  visible_langchain_modes = kwargs['visible_langchain_modes']
103
  visible_langchain_actions = kwargs['visible_langchain_actions']
 
104
  allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
105
  allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
106
  enable_sources_list = kwargs['enable_sources_list']
@@ -361,6 +362,14 @@ def go_gradio(**kwargs):
361
  value=allowed_actions[0] if len(allowed_actions) > 0 else None,
362
  label="Action",
363
  visible=True)
 
 
 
 
 
 
 
 
364
  col_tabs = gr.Column(elem_id="col_container", scale=10)
365
  with (col_tabs, gr.Tabs()):
366
  with gr.TabItem("Chat"):
@@ -469,6 +478,7 @@ def go_gradio(**kwargs):
469
  value=None,
470
  interactive=True,
471
  multiselect=False,
 
472
  )
473
  with gr.Column(scale=4):
474
  pass
@@ -1035,6 +1045,8 @@ def go_gradio(**kwargs):
1035
  user_kwargs['langchain_mode'] = 'Disabled'
1036
  if 'langchain_action' not in user_kwargs:
1037
  user_kwargs['langchain_action'] = LangChainAction.QUERY.value
 
 
1038
 
1039
  set1 = set(list(default_kwargs1.keys()))
1040
  set2 = set(eval_func_param_names)
@@ -1216,6 +1228,7 @@ def go_gradio(**kwargs):
1216
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1217
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1218
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
 
1219
  document_subset1 = args_list[eval_func_param_names.index('document_subset')]
1220
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1221
  if not prompt_type1:
@@ -1312,6 +1325,7 @@ def go_gradio(**kwargs):
1312
  args_list = args_list[:-3] # only keep rest needed for evaluate()
1313
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1314
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
 
1315
  document_subset1 = args_list[eval_func_param_names.index('document_subset')]
1316
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1317
  if not history:
 
58
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
59
  ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip
60
  from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
61
+ get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list
62
  from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
63
 
64
  from apscheduler.schedulers.background import BackgroundScheduler
 
101
  db_type = kwargs['db_type']
102
  visible_langchain_modes = kwargs['visible_langchain_modes']
103
  visible_langchain_actions = kwargs['visible_langchain_actions']
104
+ visible_langchain_agents = kwargs['visible_langchain_agents']
105
  allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
106
  allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
107
  enable_sources_list = kwargs['enable_sources_list']
 
362
  value=allowed_actions[0] if len(allowed_actions) > 0 else None,
363
  label="Action",
364
  visible=True)
365
+ allowed_agents = [x for x in langchain_agents_list if x in visible_langchain_agents]
366
+ langchain_agents = gr.Dropdown(
367
+ langchain_agents_list,
368
+ value=kwargs['langchain_agents'],
369
+ label="Agents",
370
+ multiselect=True,
371
+ interactive=True,
372
+ visible=False) # WIP
373
  col_tabs = gr.Column(elem_id="col_container", scale=10)
374
  with (col_tabs, gr.Tabs()):
375
  with gr.TabItem("Chat"):
 
478
  value=None,
479
  interactive=True,
480
  multiselect=False,
481
+ visible=True,
482
  )
483
  with gr.Column(scale=4):
484
  pass
 
1045
  user_kwargs['langchain_mode'] = 'Disabled'
1046
  if 'langchain_action' not in user_kwargs:
1047
  user_kwargs['langchain_action'] = LangChainAction.QUERY.value
1048
+ if 'langchain_agents' not in user_kwargs:
1049
+ user_kwargs['langchain_agents'] = []
1050
 
1051
  set1 = set(list(default_kwargs1.keys()))
1052
  set2 = set(eval_func_param_names)
 
1228
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1229
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1230
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1231
+ langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
1232
  document_subset1 = args_list[eval_func_param_names.index('document_subset')]
1233
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1234
  if not prompt_type1:
 
1325
  args_list = args_list[:-3] # only keep rest needed for evaluate()
1326
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1327
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1328
+ langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
1329
  document_subset1 = args_list[eval_func_param_names.index('document_subset')]
1330
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1331
  if not history:
prompter.py CHANGED
@@ -582,6 +582,20 @@ ASSISTANT:
582
  # if add space here, non-unique tokenization will often make LLM produce wrong output
583
  PreResponse = PreResponse
584
  # generates_leading_space = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
  else:
586
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
587
 
@@ -810,9 +824,20 @@ class Prompter(object):
810
  if oi > 0:
811
  # post fix outputs with seperator
812
  output += '\n'
 
813
  outputs[oi] = output
814
  # join all outputs, only one extra new line between outputs
815
  output = '\n'.join(outputs)
816
  if self.debug:
817
  print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
818
  return output
 
 
 
 
 
 
 
 
 
 
 
582
  # if add space here, non-unique tokenization will often make LLM produce wrong output
583
  PreResponse = PreResponse
584
  # generates_leading_space = True
585
+ elif prompt_type in [PromptType.guanaco.value, str(PromptType.guanaco.value),
586
+ PromptType.guanaco.name]:
587
+ # https://huggingface.co/TheBloke/guanaco-65B-GPTQ
588
+ promptA = promptB = "" if not (chat and reduced) else ''
589
+
590
+ PreInstruct = """### Human: """
591
+
592
+ PreInput = None
593
+
594
+ PreResponse = """### Assistant:"""
595
+ terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
596
+ chat_turn_sep = chat_sep = '\n'
597
+ humanstr = PreInstruct
598
+ botstr = PreResponse
599
  else:
600
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
601
 
 
824
  if oi > 0:
825
  # post fix outputs with seperator
826
  output += '\n'
827
+ output = self.fix_text(self.prompt_type, output)
828
  outputs[oi] = output
829
  # join all outputs, only one extra new line between outputs
830
  output = '\n'.join(outputs)
831
  if self.debug:
832
  print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
833
  return output
834
+
835
+ @staticmethod
836
+ def fix_text(prompt_type1, text1):
837
+ if prompt_type1 == 'human_bot':
838
+ # hack bug in vLLM with stopping, stops right, but doesn't return last token
839
+ hfix = '<human'
840
+ if text1.endswith(hfix):
841
+ text1 = text1[:-len(hfix)]
842
+ return text1
843
+
utils.py CHANGED
@@ -950,7 +950,6 @@ try:
950
  except (pkg_resources.DistributionNotFound, AssertionError):
951
  have_langchain = False
952
 
953
-
954
  import distutils.spawn
955
 
956
  have_tesseract = distutils.spawn.find_executable("tesseract")
@@ -985,3 +984,20 @@ except (pkg_resources.DistributionNotFound, AssertionError):
985
 
986
  # disable, hangs too often
987
  have_playwright = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
950
  except (pkg_resources.DistributionNotFound, AssertionError):
951
  have_langchain = False
952
 
 
953
  import distutils.spawn
954
 
955
  have_tesseract = distutils.spawn.find_executable("tesseract")
 
984
 
985
  # disable, hangs too often
986
  have_playwright = False
987
+
988
+
989
+ def set_openai(inference_server):
990
+ if inference_server.startswith('vllm'):
991
+ import openai_vllm
992
+ openai_vllm.api_key = "EMPTY"
993
+ inf_type = inference_server.split(':')[0]
994
+ ip_vllm = inference_server.split(':')[1]
995
+ port_vllm = inference_server.split(':')[2]
996
+ openai_vllm.api_base = f"http://{ip_vllm}:{port_vllm}/v1"
997
+ return openai_vllm, inf_type
998
+ else:
999
+ import openai
1000
+ openai.api_key = os.getenv("OPENAI_API_KEY")
1001
+ openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
1002
+ inf_type = inference_server
1003
+ return openai, inf_type