pseudotensor committed
Commit 1ec3d3a
1 Parent(s): 30e5d19

Update with h2oGPT hash 221daabcabfa7f54b732394c15934a347da01079

client_test.py CHANGED
@@ -48,6 +48,8 @@ import markdown  # pip install markdown
 import pytest
 from bs4 import BeautifulSoup  # pip install beautifulsoup4
 
+from enums import DocumentChoices
+
 debug = False
 
 os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
@@ -62,7 +64,10 @@ def get_client(serialize=True):
     return client
 
 
-def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_tokens=50, langchain_mode='Disabled'):
+def get_args(prompt, prompt_type, chat=False, stream_output=False,
+             max_new_tokens=50,
+             top_k_docs=3,
+             langchain_mode='Disabled'):
     from collections import OrderedDict
     kwargs = OrderedDict(instruction=prompt if chat else '',  # only for chat=True
                          iinput='',  # only for chat=True
@@ -71,6 +76,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
                          # but leave stream_output=False for simple input/output mode
                          stream_output=stream_output,
                          prompt_type=prompt_type,
+                         prompt_dict='',
                          temperature=0.1,
                          top_p=0.75,
                          top_k=40,
@@ -86,9 +92,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
                          instruction_nochat=prompt if not chat else '',
                          iinput_nochat='',  # only for chat=False
                          langchain_mode=langchain_mode,
-                         top_k_docs=4,
-                         document_choice=['All'],
+                         top_k_docs=top_k_docs,
+                         chunk=True,
+                         chunk_size=512,
+                         document_choice=[DocumentChoices.All_Relevant.name],
                          )
+    from generate import eval_func_param_names
+    assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
     if chat:
         # add chatbot output on end.  Assumes serialize=False
         kwargs.update(dict(chatbot=[]))
@@ -97,8 +107,8 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
 
 
 @pytest.mark.skip(reason="For manual use against some server, no server launched")
-def test_client_basic():
-    return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
+def test_client_basic(prompt_type='human_bot'):
+    return run_client_nochat(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
 
 
 def run_client_nochat(prompt, prompt_type, max_new_tokens):
@@ -112,15 +122,110 @@ def run_client_nochat(prompt, prompt_type, max_new_tokens):
     )
     print("Raw client result: %s" % res, flush=True)
     res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
+                    response=md_to_text(res))
+    print(res_dict)
+    return res_dict, client
+
+
+@pytest.mark.skip(reason="For manual use against some server, no server launched")
+def test_client_basic_api(prompt_type='human_bot'):
+    return run_client_nochat_api(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
+
+
+def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
+    kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
+
+    api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
+    client = get_client(serialize=True)
+    res = client.predict(
+        str(dict(kwargs)),
+        api_name=api_name,
+    )
+    print("Raw client result: %s" % res, flush=True)
+    res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
+                    response=md_to_text(ast.literal_eval(res)['response']),
+                    sources=ast.literal_eval(res)['sources'])
+    print(res_dict)
+    return res_dict, client
+
+
+@pytest.mark.skip(reason="For manual use against some server, no server launched")
+def test_client_basic_api_lean(prompt_type='human_bot'):
+    return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
+
+
+def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
+    kwargs = dict(instruction_nochat=prompt)
+
+    api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
+    client = get_client(serialize=True)
+    res = client.predict(
+        str(dict(kwargs)),
+        api_name=api_name,
+    )
+    print("Raw client result: %s" % res, flush=True)
+    res_dict = dict(prompt=kwargs['instruction_nochat'],
+                    response=md_to_text(ast.literal_eval(res)['response']),
+                    sources=ast.literal_eval(res)['sources'])
+    print(res_dict)
+    return res_dict, client
+
+
+@pytest.mark.skip(reason="For manual use against some server, no server launched")
+def test_client_basic_api_lean_morestuff(prompt_type='human_bot'):
+    return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
+
+
+def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512):
+    kwargs = dict(
+        instruction='',
+        iinput='',
+        context='',
+        stream_output=False,
+        prompt_type=prompt_type,
+        temperature=0.1,
+        top_p=0.75,
+        top_k=40,
+        num_beams=1,
+        max_new_tokens=256,
+        min_new_tokens=0,
+        early_stopping=False,
+        max_time=20,
+        repetition_penalty=1.0,
+        num_return_sequences=1,
+        do_sample=True,
+        chat=False,
+        instruction_nochat=prompt,
+        iinput_nochat='',
+        langchain_mode='Disabled',
+        top_k_docs=4,
+        document_choice=['All'],
+    )
+
+    api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
+    client = get_client(serialize=True)
+    res = client.predict(
+        str(dict(kwargs)),
+        api_name=api_name,
+    )
+    print("Raw client result: %s" % res, flush=True)
+    res_dict = dict(prompt=kwargs['instruction_nochat'],
                     response=md_to_text(ast.literal_eval(res)['response']),
                     sources=ast.literal_eval(res)['sources'])
     print(res_dict)
-    return res_dict
+    return res_dict, client
+
+
+@pytest.mark.skip(reason="For manual use against some server, no server launched")
+def test_client_chat(prompt_type='human_bot'):
+    return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
+                           langchain_mode='Disabled')
 
 
 @pytest.mark.skip(reason="For manual use against some server, no server launched")
-def test_client_chat():
-    return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50,
+def test_client_chat_stream(prompt_type='human_bot'):
+    return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
+                           stream_output=True, max_new_tokens=512,
                            langchain_mode='Disabled')
 
 
@@ -133,6 +238,7 @@ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchai
 
 
 def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
+    assert kwargs['chat'], "Chat mode only"
     res = client.predict(*tuple(args), api_name='/instruction')
     args[-1] += [res[-1]]
 
@@ -166,6 +272,46 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
     return res_dict, client
 
 
+@pytest.mark.skip(reason="For manual use against some server, no server launched")
+def test_client_nochat_stream(prompt_type='human_bot'):
+    return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
+                                 stream_output=True, max_new_tokens=512,
+                                 langchain_mode='Disabled')
+
+
+def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
+    client = get_client(serialize=False)
+
+    kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
+                            max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
+    return run_client_gen(client, prompt, args, kwargs)
+
+
+def run_client_gen(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
+    res_dict = kwargs
+    res_dict['prompt'] = prompt
+    if not kwargs['stream_output']:
+        res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
+        res_dict['response'] = res[0]
+        print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
+        return res_dict, client
+    else:
+        job = client.submit(str(dict(kwargs)), api_name='/submit_nochat_api')
+        while not job.done():
+            outputs_list = job.communicator.job.outputs
+            if outputs_list:
+                res = job.communicator.job.outputs[-1]
+                res_dict = ast.literal_eval(res)
+                print('Stream: %s' % res_dict['response'])
+            time.sleep(0.1)
+        res_list = job.outputs()
+        assert len(res_list) > 0, "No response, check server"
+        res = res_list[-1]
+        res_dict = ast.literal_eval(res)
+        print('Final: %s' % res_dict['response'])
+        return res_dict, client
+
+
 def md_to_text(md, do_md_to_text=True):
     if not do_md_to_text:
         return md
@@ -175,5 +321,16 @@ def md_to_text(md, do_md_to_text=True):
     return soup.get_text()
 
 
+def run_client_many(prompt_type='human_bot'):
+    ret1, _ = test_client_chat(prompt_type=prompt_type)
+    ret2, _ = test_client_chat_stream(prompt_type=prompt_type)
+    ret3, _ = test_client_nochat_stream(prompt_type=prompt_type)
+    ret4, _ = test_client_basic(prompt_type=prompt_type)
+    ret5, _ = test_client_basic_api(prompt_type=prompt_type)
+    ret6, _ = test_client_basic_api_lean(prompt_type=prompt_type)
+    ret7, _ = test_client_basic_api_lean_morestuff(prompt_type=prompt_type)
+    return ret1, ret2, ret3, ret4, ret5, ret6, ret7
+
+
 if __name__ == '__main__':
-    test_client_basic()
+    run_client_many()
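
Note: the new /submit_nochat_api endpoint exercised above takes a single string-encoded dict and returns a string-encoded dict with 'response' and 'sources' keys. A minimal standalone sketch of calling it with gradio_client, assuming a hypothetical local h2oGPT server at http://localhost:7860 (the URL is an assumption, not part of this commit):

    # Sketch only, mirroring run_client_nochat_api_lean() above.
    import ast
    from gradio_client import Client  # pip install gradio_client

    client = Client("http://localhost:7860")
    # Only instruction_nochat is required; other generation kwargs fall back to server defaults.
    payload = dict(instruction_nochat="Who are you?")
    res = client.predict(str(payload), api_name='/submit_nochat_api')
    res_dict = ast.literal_eval(res)  # the server returns a string-encoded dict
    print(res_dict['response'], res_dict['sources'])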
create_data.py CHANGED
@@ -567,7 +567,7 @@ def test_show_prompts():
     from prompter import generate_prompt
     for data_points in file_points:
         for data_point in data_points:
-            print(generate_prompt(data_point, 'plain', False, False)[0])
+            print(generate_prompt(data_point, 'plain', '', False, False, False)[0])
 
 
 def test_get_open_datasets():
@@ -1571,7 +1571,7 @@ def test_check_stats_data():
 
     llama_type = False
     tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
-    model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
+    model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
     local_files_only = False
     resume_download = True
     use_auth_token = False
enums.py ADDED
@@ -0,0 +1,84 @@
+from enum import Enum
+
+
+class PromptType(Enum):
+    custom = -1
+    plain = 0
+    instruct = 1
+    quality = 2
+    human_bot = 3
+    dai_faq = 4
+    summarize = 5
+    simple_instruct = 6
+    instruct_vicuna = 7
+    instruct_with_end = 8
+    human_bot_orig = 9
+    prompt_answer = 10
+    open_assistant = 11
+    wizard_lm = 12
+    wizard_mega = 13
+    instruct_vicuna2 = 14
+    instruct_vicuna3 = 15
+    wizard2 = 16
+    wizard3 = 17
+    instruct_simple = 18
+    wizard_vicuna = 19
+    openai = 20
+    openai_chat = 21
+    gptj = 22
+    prompt_answer_openllama = 23
+    vicuna11 = 24
+
+
+class DocumentChoices(Enum):
+    All_Relevant = 0
+    All_Relevant_Only_Sources = 1
+    Only_All_Sources = 2
+    Just_LLM = 3
+
+
+class LangChainMode(Enum):
+    """LangChain mode"""
+
+    DISABLED = "Disabled"
+    CHAT_LLM = "ChatLLM"
+    LLM = "LLM"
+    ALL = "All"
+    WIKI = "wiki"
+    WIKI_FULL = "wiki_full"
+    USER_DATA = "UserData"
+    MY_DATA = "MyData"
+    GITHUB_H2OGPT = "github h2oGPT"
+    H2O_DAI_DOCS = "DriverlessAI docs"
+
+
+no_server_str = no_lora_str = no_model_str = '[None/Remove]'
+
+
+# from site-packages/langchain/llms/openai.py, but needed since ChatOpenAI doesn't have this information
+model_token_mapping = {
+    "gpt-4": 8192,
+    "gpt-4-0314": 8192,
+    "gpt-4-32k": 32768,
+    "gpt-4-32k-0314": 32768,
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-16k": 16 * 1024,
+    "gpt-3.5-turbo-0301": 4096,
+    "text-ada-001": 2049,
+    "ada": 2049,
+    "text-babbage-001": 2040,
+    "babbage": 2049,
+    "text-curie-001": 2049,
+    "curie": 2049,
+    "davinci": 2049,
+    "text-davinci-003": 4097,
+    "text-davinci-002": 4097,
+    "code-davinci-002": 8001,
+    "code-davinci-001": 8001,
+    "code-cushman-002": 2048,
+    "code-cushman-001": 2048,
+}
+
+
+source_prefix = "Sources [Score | Link]:"
+source_postfix = "End Sources<p>"
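
Note: clients pass these enums to the server as plain strings (by .name or .value), as client_test.py above does with DocumentChoices.All_Relevant.name. A small sketch of the resulting payload strings:

    # Sketch: enum members serialize to the strings the API expects.
    from enums import DocumentChoices, LangChainMode, model_token_mapping

    print(DocumentChoices.All_Relevant.name)      # 'All_Relevant', used in document_choice=[...]
    print(LangChainMode.USER_DATA.value)          # 'UserData', used as langchain_mode
    print(model_token_mapping['gpt-3.5-turbo'])   # 4096-token context from the OpenAI mapping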
finetune.py CHANGED
@@ -5,8 +5,11 @@ from typing import List, Union
 import fire
 import numpy as np
 
+if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
+    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
 from loaders import get_loaders, get_tokenizer
-from prompter import generate_prompt, prompt_types
+from prompter import generate_prompt, prompt_types, PromptType
 from utils import get_githash, copy_code
 import torch
 
@@ -104,7 +107,6 @@ def train(
     save_total_limit: int = 3,
     add_eos_token: bool = False,
 ):
-
     if llama_flash_attn:
         # Need to call this before importing transformers.
         from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
@@ -129,10 +131,12 @@
     if not output_dir:
         output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
         if os.path.exists(output_dir) and not resume_from_checkpoint:
-            raise FileExistsError(f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.")
+            raise FileExistsError(
+                f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.")
     else:
         if os.path.exists(output_dir) and not resume_from_checkpoint:
-            raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
+            raise FileExistsError(
+                f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
     device_map = "auto"
 
     if save_code:
@@ -181,7 +185,7 @@
     log("num_gpus: %d" % gpus)
     log("max mem: %s" % max_memory)
 
-    model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
+    model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
 
     model = model_loader.from_pretrained(
         base_model,
@@ -398,7 +402,8 @@
     if train_data_mix_in:
         train_data = concatenate_datasets([train_data, train_data_mix_in])
     log("Tokenizing %s training rows" % train_data.num_rows)
-    train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() // torch.cuda.device_count())
+    train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun,
+                                          num_proc=os.cpu_count() // torch.cuda.device_count())
     if drop_truncations:
         log("avoid keeping truncated cases to avoid contaminating model with truncation cases.  Original size: %s" % train_data.num_rows)
         prune_long_sequences_func = partial(prune_long_sequences, cutoff_len=cutoff_len)
@@ -413,7 +418,8 @@
 
     if valid_data:
         log("Tokenizing %s validation rows" % valid_data.num_rows)
-        valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() // torch.cuda.device_count())
+        valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun,
+                                              num_proc=os.cpu_count() // torch.cuda.device_count())
         val_set_size = len(valid_data)
     else:
         val_set_size = 0
@@ -468,7 +474,7 @@
     elif save_steps > eval_steps:
         # save steps must be round multiple of eval_steps
         save_steps0 = save_steps
-        save_steps = max(1, (save_steps//eval_steps)) * eval_steps
+        save_steps = max(1, (save_steps // eval_steps)) * eval_steps
         if save_steps0 != save_steps:
             log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
 
@@ -478,21 +484,21 @@
         label_ids = eval_preds.label_ids
         predictions = eval_preds.predictions
 
-        #inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
-        #decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
-        #decoded_inputs = [pred.strip() for pred in decoded_inputs]
+        # inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
+        # decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
+        # decoded_inputs = [pred.strip() for pred in decoded_inputs]
 
         label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
         # tokenizer behavior like generate time
         decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
-                                               clean_up_tokenization_spaces=True)
+                                                clean_up_tokenization_spaces=True)
         decoded_labels = [pred.strip() for pred in decoded_labels]
 
         predictions = np.argmax(predictions, -1)
         predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
         # tokenizer behavior like generate time
         decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
-                                                    clean_up_tokenization_spaces=True)
+                                                     clean_up_tokenization_spaces=True)
         decoded_predictions = [pred.strip() for pred in decoded_predictions]
 
         result = {}
@@ -541,8 +547,8 @@
             load_best_model_at_end=True if val_set_size > 0 else False,
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,
-            #fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
-            #fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
+            # fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
+            # fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
             report_to='tensorboard' if not neptune_run else 'neptune',
         ),
         data_collator=transformers.DataCollatorForSeq2Seq(
@@ -553,13 +559,6 @@
     )
     model.config.use_cache = False
 
-    old_state_dict = model.state_dict
-    from peft import get_peft_model_state_dict
-
-    model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
-    ).__get__(model, type(model))
-
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)
         # WIP (not generally replacing layers until pytorch 2.1)
@@ -616,10 +615,12 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F
     assert prompt_type is not None
     assert cutoff_len is not None
     assert tokenizer is not None
-    full_prompt, _, _, _ = generate_prompt(data_point, prompt_type, False, False)
+    prompt_dict = ''  # only for custom prompt_type
+    assert prompt_type != PromptType.custom.name, "custom not setup for finetune"
+    full_prompt, _, _, _, _ = generate_prompt(data_point, prompt_type, prompt_dict, False, False, False)
     tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
     if not train_on_inputs:
-        user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
+        user_prompt, _, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, prompt_dict, False, False, False)
         tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
         user_prompt_len = len(tokenized_user_prompt["input_ids"])
         if add_eos_token:
@@ -638,7 +639,7 @@ def test_debug():
     fire.Fire(train)
 
 
-if __name__ == "__main__":
+def entrypoint_main():
     CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
     CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
     log(f"""
@@ -665,6 +666,11 @@ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank
 
     if os.environ.get("LOCAL_RANK") is None:
         # then not using torchrun, so can't do distributed, ensure CVD set
-        assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
+        assert os.environ.get(
+            "CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
 
     fire.Fire(train)
+
+
+if __name__ == "__main__":
+    entrypoint_main()
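
Note: generate_prompt now takes a prompt_dict argument (used only for the custom prompt type) and returns five values instead of four; callers here and in create_data.py unpack only the first. A hedged sketch of a direct call with the updated signature (the data_point fields follow the repo's instruction/input/output convention; everything else is as shown in the diff above):

    # Sketch only: return values beyond the prompt itself are ignored, as in the calls above.
    from prompter import generate_prompt

    data_point = dict(instruction='Summarize the text.', input='h2oGPT is an open-source LLM.', output='')
    full_prompt, _, _, _, _ = generate_prompt(data_point, 'plain', '', False, False, False)
    print(full_prompt)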
generate.py CHANGED
The diff for this file is too large to render. See raw diff
 
gpt4all_llm.py CHANGED
@@ -1,23 +1,13 @@
 import inspect
 import os
-import sys
+from functools import partial
 from typing import Dict, Any, Optional, List
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from pydantic import root_validator
 from langchain.llms import gpt4all
 from dotenv import dotenv_values
 
-
-class FakeTokenizer:
-
-    def encode(self, x, *args, **kwargs):
-        return dict(input_ids=[x])
-
-    def decode(self, x, *args, **kwargs):
-        return x
-
-    def __call__(self, x, *args, **kwargs):
-        return self.encode(x, *args, **kwargs)
+from utils import FakeTokenizer
 
 
 def get_model_tokenizer_gpt4all(base_model, **kwargs):
@@ -73,9 +63,9 @@ class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
     pass
 
 
-def get_model_kwargs(env_kwargs, default_kwargs, cls):
+def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
     # default from class
-    model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items()}
+    model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items() if k not in exclude_list}
     # from our defaults
     model_kwargs.update(default_kwargs)
     # from user defaults
@@ -93,10 +83,14 @@ def get_llm_gpt4all(model_name,
                     repetition_penalty=1.0,
                     top_k=40,
                     top_p=0.7,
-                    verbose=False):
+                    streaming=False,
+                    callbacks=None,
+                    prompter=None,
+                    verbose=False,
+                    ):
+    assert prompter is not None
     env_gpt4all_file = ".env_gpt4all"
     env_kwargs = dotenv_values(env_gpt4all_file)
-    callbacks = [H2OStreamingStdOutCallbackHandler()]
    n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
    default_kwargs = dict(context_erase=0.5,
                          n_batch=1,
@@ -113,21 +107,23 @@
     if model_name == 'llama':
         cls = H2OLlamaCpp
         model_path = env_kwargs.pop('model_path_llama') if model is None else model
-        model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
-        model_kwargs.update(dict(model_path=model_path, callbacks=callbacks))
+        model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
+        model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming, prompter=prompter))
         llm = cls(**model_kwargs)
         llm.client.verbose = verbose
     elif model_name == 'gpt4all_llama':
         cls = H2OGPT4All
         model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
-        model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
-        model_kwargs.update(dict(model=model_path, backend='llama', callbacks=callbacks))
+        model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
+        model_kwargs.update(
+            dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming, prompter=prompter))
         llm = cls(**model_kwargs)
     elif model_name == 'gptj':
         cls = H2OGPT4All
         model_path = env_kwargs.pop('model_path_gptj') if model is None else model
-        model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
-        model_kwargs.update(dict(model=model_path, backend='gptj', callbacks=callbacks))
+        model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
+        model_kwargs.update(
+            dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming, prompter=prompter))
         llm = cls(**model_kwargs)
     else:
         raise RuntimeError("No such model_name %s" % model_name)
@@ -136,6 +132,7 @@
 
 class H2OGPT4All(gpt4all.GPT4All):
     model: Any
+    prompter: Any
     """Path to the pre-trained GPT4All model file."""
 
     @root_validator()
@@ -155,9 +152,16 @@ class H2OGPT4All(gpt4all.GPT4All):
                     model_type=values["backend"],
                     allow_download=False,
                 )
+                if values["n_threads"] is not None:
+                    # set n_threads
+                    values["client"].model.set_thread_count(values["n_threads"])
             else:
                 values["client"] = values["model"]
-            values["backend"] = values["client"].model.model_type
+            try:
+                values["backend"] = values["client"].model_type
+            except AttributeError:
+                # The below is for compatibility with GPT4All Python bindings <= 0.2.3.
+                values["backend"] = values["client"].model.model_type
 
         except ImportError:
             raise ValueError(
@@ -171,12 +175,19 @@
             prompt: str,
             stop: Optional[List[str]] = None,
             run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs,
     ) -> str:
         # Roughly 4 chars per token if natural language
         prompt = prompt[-self.n_ctx * 4:]
+
+        # use instruct prompting
+        data_point = dict(context='', instruction=prompt, input='')
+        prompt = self.prompter.generate_prompt(data_point)
+
         verbose = False
         if verbose:
             print("_call prompt: %s" % prompt, flush=True)
+        # FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
         return super()._call(prompt, stop=stop, run_manager=run_manager)
 
 
@@ -185,6 +196,7 @@ from langchain.llms import LlamaCpp
 
 class H2OLlamaCpp(LlamaCpp):
     model_path: Any
+    prompter: Any
     """Path to the pre-trained GPT4All model file."""
 
     @root_validator()
@@ -236,9 +248,12 @@ class H2OLlamaCpp(LlamaCpp):
             prompt: str,
             stop: Optional[List[str]] = None,
             run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs,
     ) -> str:
         verbose = False
         # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
+        # still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
+        prompt = prompt[-self.n_ctx * 4:]
         prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
         num_prompt_tokens = len(prompt_tokens)
         if num_prompt_tokens > self.n_ctx:
@@ -250,6 +265,33 @@
             prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
             num_prompt_tokens2 = len(prompt_tokens2)
             print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
+
+        # use instruct prompting
+        data_point = dict(context='', instruction=prompt, input='')
+        prompt = self.prompter.generate_prompt(data_point)
+
         if verbose:
             print("_call prompt: %s" % prompt, flush=True)
-        return super()._call(prompt, stop=stop, run_manager=run_manager)
+
+        if self.streaming:
+            text_callback = None
+            if run_manager:
+                text_callback = partial(
+                    run_manager.on_llm_new_token, verbose=self.verbose
+                )
+            # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
+            if text_callback:
+                text_callback(prompt)
+            text = ""
+            for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
+                text_chunk = token["choices"][0]["text"]
+                # self.stream already calls text_callback
+                # if text_callback:
+                #     text_callback(text_chunk)
+                text += text_chunk
+            return text
+        else:
+            params = self._get_parameters(stop)
+            params = {**params, **kwargs}
+            result = self.client(prompt=prompt, **params)
+            return result["choices"][0]["text"]
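
Note: with streaming=True, the H2OLlamaCpp._call path above forwards each generated chunk through run_manager.on_llm_new_token, so any LangChain callback handler passed via callbacks sees tokens as they arrive. A sketch of such a handler (the handler class is illustrative and not part of this commit; get_llm_gpt4all also requires a prompter instance, whose construction is omitted here):

    # Sketch: print tokens as the llama.cpp backend emits them.
    from langchain.callbacks.base import BaseCallbackHandler

    class PrintTokenHandler(BaseCallbackHandler):
        def on_llm_new_token(self, token: str, **kwargs) -> None:
            print(token, end='', flush=True)

    # Hypothetical usage:
    # llm = get_llm_gpt4all('llama', model=None, streaming=True,
    #                       callbacks=[PrintTokenHandler()], prompter=prompter, verbose=True)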
gpt_langchain.py CHANGED
@@ -1,27 +1,34 @@
 
1
  import glob
2
  import inspect
3
  import os
4
  import pathlib
5
  import pickle
6
- import queue
7
  import shutil
8
  import subprocess
9
- import sys
10
  import tempfile
 
11
  import traceback
 
12
  import uuid
13
  import zipfile
14
  from collections import defaultdict
15
  from datetime import datetime
16
  from functools import reduce
17
  from operator import concat
 
18
 
19
- from joblib import Parallel, delayed
 
 
20
  from tqdm import tqdm
21
 
22
- from prompter import non_hf_types
 
 
23
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
24
- get_device, ProgressParallel, remove, hash_file
 
25
 
26
  import_matplotlib()
27
 
@@ -36,19 +43,22 @@ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
36
  from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
37
  UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
38
  EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
39
- UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
40
- from langchain.text_splitter import RecursiveCharacterTextSplitter
41
  from langchain.chains.question_answering import load_qa_chain
42
  from langchain.docstore.document import Document
43
- from langchain import PromptTemplate
44
  from langchain.vectorstores import Chroma
45
 
46
 
47
- def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directory="db_dir", langchain_mode='notset',
 
 
48
  collection_name=None,
49
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
50
  if not sources:
51
  return None
 
52
  # get embedding model
53
  embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
54
  assert collection_name is not None or langchain_mode != 'notset'
@@ -59,29 +69,41 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo
59
  if db_type == 'faiss':
60
  from langchain.vectorstores import FAISS
61
  db = FAISS.from_documents(sources, embedding)
62
-
63
  elif db_type == 'weaviate':
64
  import weaviate
65
  from weaviate.embedded import EmbeddedOptions
66
  from langchain.vectorstores import Weaviate
67
 
68
- # TODO: add support for connecting via docker compose
69
- client = weaviate.Client(
70
- embedded_options=EmbeddedOptions()
71
- )
 
 
72
  index_name = collection_name.capitalize()
73
  db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
74
  index_name=index_name)
75
-
76
  elif db_type == 'chroma':
77
  assert persist_directory is not None
78
  os.makedirs(persist_directory, exist_ok=True)
79
- db = Chroma.from_documents(documents=sources,
80
- embedding=embedding,
81
- persist_directory=persist_directory,
82
- collection_name=collection_name,
83
- anonymized_telemetry=False)
84
- db.persist()
 
 
 
 
 
 
 
 
 
 
 
 
85
  else:
86
  raise RuntimeError("No such db_type=%s" % db_type)
87
 
@@ -104,7 +126,10 @@ def _get_unique_sources_in_weaviate(db):
104
 
105
  def add_to_db(db, sources, db_type='faiss',
106
  avoid_dup_by_file=False,
107
- avoid_dup_by_content=True):
 
 
 
108
  num_new_sources = len(sources)
109
  if not sources:
110
  return db, num_new_sources, []
@@ -120,7 +145,7 @@ def add_to_db(db, sources, db_type='faiss',
120
  return db, num_new_sources, []
121
  db.add_documents(documents=sources)
122
  elif db_type == 'chroma':
123
- collection = db.get()
124
  # files we already have:
125
  metadata_files = set([x['source'] for x in collection['metadatas']])
126
  if avoid_dup_by_file:
@@ -135,11 +160,15 @@ def add_to_db(db, sources, db_type='faiss',
135
  [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
136
  # avoid sources with same hash
137
  sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
 
 
 
138
  # get new file names that match existing file names. delete existing files we are overridding
139
  dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
140
  print("Removing %s duplicate files from db because ingesting those as new documents" % len(
141
  dup_metadata_files), flush=True)
142
- client_collection = db._client.get_collection(name=db._collection.name)
 
143
  for dup_file in dup_metadata_files:
144
  dup_file_meta = dict(source=dup_file)
145
  try:
@@ -151,6 +180,8 @@ def add_to_db(db, sources, db_type='faiss',
151
  return db, num_new_sources, []
152
  db.add_documents(documents=sources)
153
  db.persist()
 
 
154
  else:
155
  raise RuntimeError("No such db_type=%s" % db_type)
156
 
@@ -165,10 +196,13 @@ def create_or_update_db(db_type, persist_directory, collection_name,
165
  import weaviate
166
  from weaviate.embedded import EmbeddedOptions
167
 
168
- # TODO: add support for connecting via docker compose
169
- client = weaviate.Client(
170
- embedded_options=EmbeddedOptions()
171
- )
 
 
 
172
  index_name = collection_name.replace(' ', '_').capitalize()
173
  if client.schema.exists(index_name) and not add_if_exists:
174
  client.schema.delete_class(index_name)
@@ -205,14 +239,20 @@ def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformer
205
  if use_openai_embedding:
206
  assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
207
  from langchain.embeddings import OpenAIEmbeddings
208
- embedding = OpenAIEmbeddings()
209
  else:
210
  # to ensure can fork without deadlock
211
  from langchain.embeddings import HuggingFaceEmbeddings
212
 
213
  device, torch_dtype, context_class = get_device_dtype()
214
  model_kwargs = dict(device=device)
215
- embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
 
 
 
 
 
 
216
  return embedding
217
 
218
 
@@ -226,63 +266,481 @@ def get_answer_from_sources(chain, sources, question):
226
  )["output_text"]
227
 
228
 
229
- def get_llm(use_openai_model=False, model_name=None, model=None,
230
- tokenizer=None, stream_output=False,
231
- max_new_tokens=256,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  temperature=0.1,
233
- repetition_penalty=1.0,
234
  top_k=40,
235
  top_p=0.7,
 
 
 
 
 
 
 
236
  prompt_type=None,
 
237
  prompter=None,
 
238
  verbose=False,
239
  ):
240
- if use_openai_model:
241
- from langchain.llms import OpenAI
242
- llm = OpenAI(temperature=0)
243
- model_name = 'openai'
244
- streamer = None
245
- prompt_type = 'plain'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  elif model_name in non_hf_types:
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  from gpt4all_llm import get_llm_gpt4all
248
  llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
249
  temperature=temperature,
250
  repetition_penalty=repetition_penalty,
251
  top_k=top_k,
252
  top_p=top_p,
 
253
  verbose=verbose,
 
 
254
  )
255
- streamer = None
256
- prompt_type = 'plain'
257
  else:
258
- from transformers import AutoTokenizer, AutoModelForCausalLM
259
-
260
  if model is None:
261
  # only used if didn't pass model in
262
- assert model_name is None
263
  assert tokenizer is None
264
  prompt_type = 'human_bot'
265
- model_name = 'h2oai/h2ogpt-oasst1-512-12b'
266
- # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
267
- # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
268
- tokenizer = AutoTokenizer.from_pretrained(model_name)
269
- device, torch_dtype, context_class = get_device_dtype()
270
-
271
- with context_class(device):
272
- load_8bit = True
273
- # FIXME: for now not to spread across hetero GPUs
274
- # device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
275
- device_map = {"": 0} if device == 'cuda' else "auto"
276
- model = AutoModelForCausalLM.from_pretrained(model_name,
277
- device_map=device_map,
278
- torch_dtype=torch_dtype,
279
- load_in_8bit=load_8bit)
280
 
281
  max_max_tokens = tokenizer.model_max_length
282
- gen_kwargs = dict(max_new_tokens=max_new_tokens,
 
 
 
 
 
 
 
 
 
 
283
  return_full_text=True,
284
- early_stopping=False,
285
- handle_long_generation='hole')
286
 
287
  if stream_output:
288
  skip_prompt = False
@@ -297,10 +755,12 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
297
  pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
298
  prompter=prompter,
299
  prompt_type=prompt_type,
300
- sanitize_bot_response=True,
 
301
  chat=False, stream_output=stream_output,
302
  tokenizer=tokenizer,
303
- max_input_tokens=max_max_tokens - max_new_tokens,
 
304
  **gen_kwargs)
305
  # pipe.task = "text-generation"
306
  # below makes it listen only to our prompt removal,
@@ -345,7 +805,7 @@ def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
345
  data = json.load(open(filename, "rt"))
346
  page_content = list(data["query"]["pages"].values())[0]["extract"]
347
  if take_head is not None and text_limit is not None:
348
- page_content = page_content[:text_limit] if take_head else page_content[:-text_limit]
349
  title_url = str(title).replace(' ', '_')
350
  return Document(
351
  page_content=page_content,
@@ -467,6 +927,21 @@ try:
467
  except (pkg_resources.DistributionNotFound, AssertionError):
468
  have_pymupdf = False
469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  image_types = ["png", "jpg", "jpeg"]
471
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
472
  "md", "html",
@@ -484,12 +959,13 @@ file_types = non_image_types + image_types
484
  def add_meta(docs1, file):
485
  file_extension = pathlib.Path(file).suffix
486
  hashid = hash_file(file)
487
- if not isinstance(docs1, list):
488
  docs1 = [docs1]
489
  [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
490
 
491
 
492
- def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, chunk=True, chunk_size=512,
 
493
  is_url=False, is_txt=False,
494
  enable_captions=True,
495
  captions_model=None,
@@ -525,9 +1001,25 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
525
  else:
526
  docs1 = []
527
  else:
 
 
528
  docs1 = UnstructuredURLLoader(urls=[file]).load()
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  [x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1]
530
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
 
531
  elif is_txt:
532
  base_path = "user_paste"
533
  source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
@@ -536,44 +1028,49 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
536
  f.write(file)
537
  metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
538
  doc1 = Document(page_content=file, metadata=metadata)
 
539
  elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
540
  docs1 = UnstructuredHTMLLoader(file_path=file).load()
541
  add_meta(docs1, file)
542
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
 
543
  elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
544
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
545
  add_meta(docs1, file)
546
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
547
  elif file.lower().endswith('.odt'):
548
  docs1 = UnstructuredODTLoader(file_path=file).load()
549
  add_meta(docs1, file)
550
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
551
  elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
552
  docs1 = UnstructuredPowerPointLoader(file_path=file).load()
553
  add_meta(docs1, file)
554
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
 
555
  elif file.lower().endswith('.txt'):
556
  # use UnstructuredFileLoader ?
557
  docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
558
  # makes just one, but big one
559
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
 
560
  add_meta(doc1, file)
561
  elif file.lower().endswith('.rtf'):
562
  docs1 = UnstructuredRTFLoader(file).load()
563
  add_meta(docs1, file)
564
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
565
  elif file.lower().endswith('.md'):
566
  docs1 = UnstructuredMarkdownLoader(file).load()
567
  add_meta(docs1, file)
568
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
 
569
  elif file.lower().endswith('.enex'):
570
  docs1 = EverNoteLoader(file).load()
571
  add_meta(doc1, file)
572
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
573
  elif file.lower().endswith('.epub'):
574
  docs1 = UnstructuredEPubLoader(file).load()
575
  add_meta(docs1, file)
576
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
577
  elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'):
578
  docs1 = []
579
  if have_tesseract and enable_ocr:
@@ -603,7 +1100,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
603
  doci.metadata['source'] = doci.metadata['image_path']
604
  doci.metadata['hash'] = hash_file(doci.metadata['source'])
605
  if docs1:
606
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
607
  elif file.lower().endswith('.msg'):
608
  raise RuntimeError("Not supported, GPL3 license")
609
  # docs1 = OutlookMessageLoader(file).load()
@@ -612,14 +1109,14 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
612
  try:
613
  docs1 = UnstructuredEmailLoader(file).load()
614
  add_meta(docs1, file)
615
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
616
  except ValueError as e:
617
  if 'text/html content not found in email' in str(e):
618
  # e.g. plain/text dict key exists, but not
619
  # doc1 = TextLoader(file, encoding="utf8").load()
620
  docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
621
  add_meta(docs1, file)
622
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
623
  else:
624
  raise
625
  # elif file.lower().endswith('.gcsdir'):
@@ -630,6 +1127,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
630
  with open(file, "r") as f:
631
  doc1 = Document(page_content=f.read(), metadata={"source": file})
632
  add_meta(doc1, file)
 
633
  elif file.lower().endswith('.pdf'):
634
  env_gpt4all_file = ".env_gpt4all"
635
  from dotenv import dotenv_values
@@ -638,11 +1136,19 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
638
  if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
639
  # GPL, only use if installed
640
  from langchain.document_loaders import PyMuPDFLoader
641
- doc1 = PyMuPDFLoader(file).load_and_split()
 
 
 
 
 
642
  else:
643
  # open-source fallback
644
- doc1 = PyPDFLoader(file).load_and_split()
 
 
645
  # Some PDFs return nothing or junk from PDFMinerLoader
 
646
  add_meta(doc1, file)
647
  elif file.lower().endswith('.csv'):
648
  doc1 = CSVLoader(file).load()
@@ -650,6 +1156,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
650
  elif file.lower().endswith('.py'):
651
  doc1 = PythonLoader(file).load()
652
  add_meta(doc1, file)
 
653
  elif file.lower().endswith('.toml'):
654
  doc1 = TomlLoader(file).load()
655
  add_meta(doc1, file)
@@ -657,7 +1164,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
657
  with open(file, "r") as f:
658
  docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
659
  add_meta(docs1, file)
660
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
661
  elif file.lower().endswith('.zip'):
662
  with zipfile.ZipFile(file, 'r') as zip_ref:
663
  # don't put into temporary path, since want to keep references to docs inside zip
@@ -672,12 +1179,12 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
672
  # if list of length one, don't trust and chunk it
673
  if not isinstance(doc1, list):
674
  if chunk:
675
- docs = chunk_sources([doc1], chunk_size=chunk_size)
676
  else:
677
  docs = [doc1]
678
  elif isinstance(doc1, list) and len(doc1) == 1:
679
  if chunk:
680
- docs = chunk_sources(doc1, chunk_size=chunk_size)
681
  else:
682
  docs = doc1
683
  else:
@@ -687,7 +1194,8 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
687
  return docs
688
 
689
 
690
- def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True, chunk=True, chunk_size=512,
 
691
  is_url=False, is_txt=False,
692
  enable_captions=True,
693
  captions_model=None,
@@ -739,15 +1247,16 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
739
  existing_files=[],
740
  existing_hash_ids={},
741
  ):
 
742
  globs_image_types = []
743
  globs_non_image_types = []
744
  if not path_or_paths and not url and not text:
745
  return []
746
  elif url:
747
- globs_non_image_types = [url]
748
  elif text:
749
- globs_non_image_types = [text]
750
- elif isinstance(path_or_paths, str):
751
  # single path, only consume allowed files
752
  path = path_or_paths
753
  # Below globs should match patterns in file_to_doc()
@@ -756,8 +1265,11 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
756
  [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
757
  for ftype in non_image_types]
758
  else:
 
 
759
  # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
760
- assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(path_or_paths)
 
761
  # reform out of allowed types
762
  globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
763
  # could do below:
@@ -861,12 +1373,12 @@ def prep_langchain(persist_directory,
861
 
862
  if db_dir_exists and user_path is None:
863
  print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
864
- db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
865
  hf_embedding_model)
866
  else:
867
  if db_dir_exists and user_path is not None:
868
  print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
869
- persist_directory, user_path), flush=True)
870
  elif not db_dir_exists:
871
  print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
872
  db = None
@@ -912,24 +1424,78 @@ class FakeConsumer(object):
912
  posthog.Consumer = FakeConsumer
913
 
914
 
915
- def get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
916
- hf_embedding_model):
  if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
918
  os.path.join(persist_directory, 'index')):
919
- print("DO Loading db: %s" % langchain_mode, flush=True)
920
- embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
921
- from chromadb.config import Settings
922
- client_settings = Settings(anonymized_telemetry=False,
923
- chroma_db_impl="duckdb+parquet",
924
- persist_directory=persist_directory)
925
- db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
926
- collection_name=langchain_mode.replace(' ', '_'),
927
- client_settings=client_settings)
928
- print("DONE Loading db: %s" % langchain_mode, flush=True)
929
  return db
930
  return None
931
 
932
 
933
  def make_db(**langchain_kwargs):
934
  func_names = list(inspect.signature(_make_db).parameters)
935
  missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
@@ -945,9 +1511,33 @@ def make_db(**langchain_kwargs):
945
  return _make_db(**langchain_kwargs)
946
 
947
 
948
  def _make_db(use_openai_embedding=False,
949
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
950
- first_para=False, text_limit=None, chunk=False, chunk_size=1024,
 
951
  langchain_mode=None,
952
  user_path=None,
953
  db_type='faiss',
@@ -955,19 +1545,13 @@ def _make_db(use_openai_embedding=False,
955
  db=None,
956
  n_jobs=-1,
957
  verbose=False):
958
- persist_directory = 'db_dir_%s' % langchain_mode # single place, no special names for each case
959
- if not db and load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
960
- os.path.join(persist_directory, 'index')):
961
- assert langchain_mode not in ['MyData'], "Should not load MyData db this way"
962
- print("Loading existing db", flush=True)
963
- embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
964
- from chromadb.config import Settings
965
- client_settings = Settings(anonymized_telemetry=False,
966
- chroma_db_impl="duckdb+parquet",
967
- persist_directory=persist_directory)
968
- db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
969
- collection_name=langchain_mode.replace(' ', '_'),
970
- client_settings=client_settings)
971
  sources = []
972
  if not db and langchain_mode not in ['MyData'] or \
973
  user_path is not None and \
@@ -992,24 +1576,24 @@ def _make_db(use_openai_embedding=False,
992
  sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
993
  print("Got new wiki", flush=True)
994
  if chunk:
995
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
996
  print("Chunked new wiki", flush=True)
997
  sources.extend(sources1)
998
  if langchain_mode in ['wiki', 'All', "'All'"]:
999
  sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
1000
  if chunk:
1001
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
1002
  sources.extend(sources1)
1003
  if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
1004
  # sources = get_github_docs("dagster-io", "dagster")
1005
  sources1 = get_github_docs("h2oai", "h2ogpt")
1006
  # FIXME: always chunk for now
1007
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
1008
  sources.extend(sources1)
1009
  if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
1010
  sources1 = get_dai_docs(from_hf=True)
1011
  if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
1012
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
1013
  sources.extend(sources1)
1014
  if langchain_mode in ['All', 'UserData']:
1015
  if user_path:
@@ -1023,6 +1607,8 @@ def _make_db(use_openai_embedding=False,
1023
  existing_files = []
1024
  existing_hash_ids = []
1025
  # chunk internally for speed over multiple docs
 
 
1026
  sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
1027
  existing_files=existing_files, existing_hash_ids=existing_hash_ids)
1028
  new_metadata_sources = set([x.metadata['source'] for x in sources1])
@@ -1066,7 +1652,9 @@ def _make_db(use_openai_embedding=False,
1066
  new_sources_metadata = [x.metadata for x in sources]
1067
  elif user_path is not None and langchain_mode in ['UserData']:
1068
  print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
1069
- db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type)
 
 
1070
  print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
1071
  else:
1072
  new_sources_metadata = [x.metadata for x in sources]
@@ -1074,63 +1662,140 @@ def _make_db(use_openai_embedding=False,
1074
  return db, len(new_sources_metadata), new_sources_metadata
1075
 
1076
 
1077
  def get_existing_files(db):
1078
- collection = db.get()
1079
- metadata_sources = set([x['source'] for x in collection['metadatas']])
1080
  return metadata_sources
1081
 
1082
 
1083
  def get_existing_hash_ids(db):
1084
- collection = db.get()
1085
  # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
1086
- metadata_hash_ids = {x['source']: x.get('hashid') for x in collection['metadatas']}
1087
  return metadata_hash_ids
1088
 
1089
 
1090
- source_prefix = "Sources [Score | Link]:"
1091
- source_postfix = "End Sources<p>"
1092
-
1093
-
1094
  def run_qa_db(**kwargs):
1095
  func_names = list(inspect.signature(_run_qa_db).parameters)
1096
  # hard-coded defaults
1097
  kwargs['answer_with_sources'] = True
1098
- kwargs['sanitize_bot_response'] = True
1099
  kwargs['show_rank'] = False
1100
  missing_kwargs = [x for x in func_names if x not in kwargs]
1101
  assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
1102
  # only keep actual used
1103
  kwargs = {k: v for k, v in kwargs.items() if k in func_names}
1104
- return _run_qa_db(**kwargs)
 
 
 
1105
 
1106
 
1107
  def _run_qa_db(query=None,
1108
  use_openai_model=False, use_openai_embedding=False,
1109
- first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
1110
  user_path=None,
1111
  detect_user_path_changes_every_query=False,
1112
  db_type='faiss',
1113
- model_name=None, model=None, tokenizer=None,
1114
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1115
  stream_output=False,
1116
  prompter=None,
1117
  prompt_type=None,
 
1118
  answer_with_sources=True,
1119
  cut_distanct=1.1,
1120
- sanitize_bot_response=True,
1121
  show_rank=False,
1122
  load_db_if_exists=False,
1123
  db=None,
1124
- max_new_tokens=256,
1125
  temperature=0.1,
1126
- repetition_penalty=1.0,
1127
  top_k=40,
1128
  top_p=0.7,
1129
  langchain_mode=None,
1130
- document_choice=['All'],
1131
  n_jobs=-1,
1132
  verbose=False,
1133
- cli=False):
1134
  """
1135
 
1136
  :param query:
@@ -1149,39 +1814,63 @@ def _run_qa_db(query=None,
1149
  :param answer_with_sources
1150
  :return:
1151
  """
 
 
1152
  assert query is not None
1153
  assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
1154
  if prompter is not None:
1155
  prompt_type = prompter.prompt_type
 
1156
  if model is not None:
1157
  assert prompt_type is not None
 
 
 
 
 
1158
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1159
- model=model, tokenizer=tokenizer,
 
 
1160
  stream_output=stream_output,
1161
- max_new_tokens=max_new_tokens,
1162
  temperature=temperature,
1163
- repetition_penalty=repetition_penalty,
1164
  top_k=top_k,
1165
  top_p=top_p,
1166
  prompt_type=prompt_type,
 
1167
  prompter=prompter,
 
1168
  verbose=verbose,
1169
  )
1170
 
1171
- if model_name in non_hf_types:
1172
- # FIXME: for now, streams to stdout/stderr currently
1173
- stream_output = False
1174
-
1175
  use_context = False
1176
  scores = []
1177
  chain = None
1178
 
 
1179
  func_names = list(inspect.signature(get_similarity_chain).parameters)
1180
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1181
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1182
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
1183
  docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
1184
- if len(document_choice) > 0 and document_choice[0] == 'Only':
1185
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1186
  yield formatted_doc_chunks, ''
1187
  return
@@ -1189,43 +1878,49 @@ def _run_qa_db(query=None,
1189
  # can only return if HF type
1190
  return
1191
 
1192
- if stream_output:
1193
- answer = None
1194
- assert streamer is not None
1195
- import queue
1196
- bucket = queue.Queue()
1197
- thread = EThread(target=chain, streamer=streamer, bucket=bucket)
1198
- thread.start()
1199
- outputs = ""
1200
- prompt = None # FIXME
1201
- try:
1202
- for new_text in streamer:
1203
- # print("new_text: %s" % new_text, flush=True)
1204
- if bucket.qsize() > 0 or thread.exc:
1205
- thread.join()
1206
- outputs += new_text
1207
- if prompter: # and False: # FIXME: pipeline can already use prompter
1208
- output1 = prompter.get_response(outputs, prompt=prompt,
1209
- sanitize_bot_response=sanitize_bot_response)
1210
- yield output1, ''
1211
- else:
1212
- yield outputs, ''
1213
- except BaseException:
1214
- # if any exception, raise that exception if was from thread, first
1215
- if thread.exc:
1216
- raise thread.exc
1217
- raise
1218
- finally:
1219
- # in case no exception and didn't join with thread yet, then join
1220
- if not thread.exc:
1221
- answer = thread.join()
1222
- # in case raise StopIteration or broke queue loop in streamer, but still have exception
1223
- if thread.exc:
1224
- raise thread.exc
1225
- # FIXME: answer is not string outputs from streamer. How to get actual final output?
1226
- # answer = outputs
1227
- else:
1228
- answer = chain()
1229
 
1230
  if not use_context:
1231
  ret = answer['output_text']
@@ -1239,22 +1934,31 @@ def _run_qa_db(query=None,
1239
 
1240
  def get_similarity_chain(query=None,
1241
  use_openai_model=False, use_openai_embedding=False,
1242
- first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
1243
  user_path=None,
1244
  detect_user_path_changes_every_query=False,
1245
  db_type='faiss',
1246
  model_name=None,
 
1247
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1248
  prompt_type=None,
 
1249
  cut_distanct=1.1,
1250
  load_db_if_exists=False,
1251
  db=None,
1252
  langchain_mode=None,
1253
- document_choice=['All'],
1254
  n_jobs=-1,
1255
  # beyond run_db_query:
1256
  llm=None,
 
1257
  verbose=False,
1258
  ):
1259
  # determine whether use of context out of docs is planned
1260
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
@@ -1266,10 +1970,14 @@ def get_similarity_chain(query=None,
1266
  use_context = True
1267
 
1268
  # https://github.com/hwchase17/langchain/issues/1946
1269
- # FIXME: Seems to way to get size of chroma db to limit k to avoid
1270
  # Chroma collection MyData contains fewer than 4 elements.
1271
  # type logger error
1272
- k_db = 1000 if db_type == 'chroma' else k # k=100 works ok too for
 
 
 
 
1273
 
1274
  # FIXME: For All just go over all dbs instead of a separate db for All
1275
  if not detect_user_path_changes_every_query and db is not None:
@@ -1279,7 +1987,8 @@ def get_similarity_chain(query=None,
1279
  user_path = None
1280
  db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
1281
  hf_embedding_model=hf_embedding_model,
1282
- first_para=first_para, text_limit=text_limit, chunk=chunk,
 
1283
  chunk_size=chunk_size,
1284
  langchain_mode=langchain_mode,
1285
  user_path=user_path,
@@ -1289,37 +1998,133 @@ def get_similarity_chain(query=None,
1289
  n_jobs=n_jobs,
1290
  verbose=verbose)
1291
 
1292
  if db and use_context:
1293
- if isinstance(document_choice, str):
1294
- # support string as well
1295
- document_choice = [document_choice]
1296
- if not isinstance(db, Chroma) or \
1297
- len(document_choice) == 0 or \
1298
- len(document_choice) <= 1 and document_choice[0] == 'All':
1299
- # treat empty list as All for now, not 'None'
1300
- filter_kwargs = {}
1301
- elif len(document_choice) > 0 and document_choice[0] == 'Only':
1302
- # Only means All docs, but only will return sources, not LLM response
1303
  filter_kwargs = {}
1304
  else:
 
1305
  if len(document_choice) >= 2:
1306
  or_filter = [{"source": {"$eq": x}} for x in document_choice]
1307
  filter_kwargs = dict(filter={"$or": or_filter})
1308
- elif len(document_choice) > 0:
 
1309
  one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
1310
  filter_kwargs = dict(filter=one_filter)
1311
  else:
 
1312
  filter_kwargs = {}
1313
- if len(document_choice) == 1 and document_choice[0] == 'None':
1314
- k_db = 1
1315
- k = 0
1316
- docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:k]
1317
- # cut off so no high distance docs/sources considered
1318
- docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
1319
- scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
1320
- if len(scores) > 0 and verbose:
1321
- print("Distance: min: %s max: %s mean: %s median: %s" %
1322
- (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
1323
  else:
1324
  docs = []
1325
  scores = []
@@ -1328,7 +2133,7 @@ def get_similarity_chain(query=None,
1328
  # if HF type and have no docs, can bail out
1329
  return docs, None, [], False
1330
 
1331
- if len(document_choice) > 0 and document_choice[0] == 'Only':
1332
  # no LLM use
1333
  return docs, None, [], False
1334
 
@@ -1348,19 +2153,11 @@ def get_similarity_chain(query=None,
1348
  if len(docs) == 0:
1349
  # avoid context == in prompt then
1350
  use_context = False
 
1351
 
1352
- if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1353
  # instruct-like, rather than few-shot prompt_type='plain' as default
1354
  # but then sources confuse the model with how inserted among rest of text, so avoid
1355
- prefix = ""
1356
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
1357
- template = """%s{context}{question}""" % prefix
1358
- else:
1359
- template = """%s
1360
- ==
1361
- {context}
1362
- ==
1363
- {question}""" % prefix
1364
  prompt = PromptTemplate(
1365
  # input_variables=["summaries", "question"],
1366
  input_variables=["context", "question"],
@@ -1420,15 +2217,32 @@ def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, ve
1420
  return ret, extra
1421
 
1422
 
1423
- def chunk_sources(sources, chunk_size=1024):
1424
- source_chunks = []
1425
- # Below for known separator
1426
- # splitter = CharacterTextSplitter(separator=" ", chunk_size=chunk_size, chunk_overlap=0)
1427
- splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
1428
- for source in sources:
1429
- # print(source.metadata['source'], flush=True)
1430
- for chunky in splitter.split_text(source.page_content):
1431
- source_chunks.append(Document(page_content=chunky, metadata=source.metadata))
1432
  return source_chunks
1433
 
1434
 
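The replacement for chunk_sources is not shown in this hunk, but the new call sites elsewhere in this diff pass chunk= and language= arguments (e.g. Language.PYTHON, Language.MARKDOWN), so it presumably becomes a language-aware splitter. A minimal sketch of that idea, with the exact argument handling assumed rather than taken from the commit:

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language

def chunk_sources_sketch(sources, chunk=True, chunk_size=512, language=None):
    # pass-through when chunking is disabled, matching the chunk= flag at the call sites
    if not chunk:
        return sources
    if language is not None:
        # language-aware separators (functions, headings, etc.) instead of plain character splits
        splitter = RecursiveCharacterTextSplitter.from_language(language=language,
                                                                chunk_size=chunk_size,
                                                                chunk_overlap=0)
    else:
        splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
    return [Document(page_content=chunky, metadata=source.metadata)
            for source in sources
            for chunky in splitter.split_text(source.page_content)]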
@@ -1439,6 +2253,8 @@ def get_db_from_hf(dest=".", db_dir='db_dir_DriverlessAI_docs.zip'):
1439
  path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset')
1440
  import zipfile
1441
  with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
 
 
1442
  zip_ref.extractall(dest)
1443
  return path_to_zip_file
1444
 
@@ -1467,5 +2283,28 @@ def get_some_dbs_from_hf(dest='.', db_zips=None):
1467
  assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
1468
 
1469
 
1470
  if __name__ == '__main__':
1471
  pass
 
1
+ import ast
2
  import glob
3
  import inspect
4
  import os
5
  import pathlib
6
  import pickle
 
7
  import shutil
8
  import subprocess
 
9
  import tempfile
10
+ import time
11
  import traceback
12
+ import types
13
  import uuid
14
  import zipfile
15
  from collections import defaultdict
16
  from datetime import datetime
17
  from functools import reduce
18
  from operator import concat
19
+ import filelock
20
 
21
+ from joblib import delayed
22
+ from langchain.callbacks import streaming_stdout
23
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
24
  from tqdm import tqdm
25
 
26
+ from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix
27
+ from generate import gen_hyper, get_model, SEED
28
+ from prompter import non_hf_types, PromptType, Prompter
29
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
30
+ get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
31
+ from utils_langchain import StreamingGradioCallbackHandler
32
 
33
  import_matplotlib()
34
 
 
43
  from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
44
  UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
45
  EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
46
+ UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader
47
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
48
  from langchain.chains.question_answering import load_qa_chain
49
  from langchain.docstore.document import Document
50
+ from langchain import PromptTemplate, HuggingFaceTextGenInference
51
  from langchain.vectorstores import Chroma
52
 
53
 
54
+ def get_db(sources, use_openai_embedding=False, db_type='faiss',
55
+ persist_directory="db_dir", load_db_if_exists=True,
56
+ langchain_mode='notset',
57
  collection_name=None,
58
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
59
  if not sources:
60
  return None
61
+
62
  # get embedding model
63
  embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
64
  assert collection_name is not None or langchain_mode != 'notset'
 
69
  if db_type == 'faiss':
70
  from langchain.vectorstores import FAISS
71
  db = FAISS.from_documents(sources, embedding)
 
72
  elif db_type == 'weaviate':
73
  import weaviate
74
  from weaviate.embedded import EmbeddedOptions
75
  from langchain.vectorstores import Weaviate
76
 
77
+ if os.getenv('WEAVIATE_URL', None):
78
+ client = _create_local_weaviate_client()
79
+ else:
80
+ client = weaviate.Client(
81
+ embedded_options=EmbeddedOptions()
82
+ )
83
  index_name = collection_name.capitalize()
84
  db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
85
  index_name=index_name)
 
86
  elif db_type == 'chroma':
87
  assert persist_directory is not None
88
  os.makedirs(persist_directory, exist_ok=True)
89
+
90
+ # see if already actually have persistent db, and deal with possible changes in embedding
91
+ db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
92
+ hf_embedding_model, verbose=False)
93
+ if db is None:
94
+ db = Chroma.from_documents(documents=sources,
95
+ embedding=embedding,
96
+ persist_directory=persist_directory,
97
+ collection_name=collection_name,
98
+ anonymized_telemetry=False)
99
+ db.persist()
100
+ clear_embedding(db)
101
+ save_embed(db, use_openai_embedding, hf_embedding_model)
102
+ else:
103
+ # then just add
104
+ db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
105
+ use_openai_embedding=use_openai_embedding,
106
+ hf_embedding_model=hf_embedding_model)
107
  else:
108
  raise RuntimeError("No such db_type=%s" % db_type)
109
 
 
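A rough usage sketch for the new get_db signature above; the document, mode, and model names below are illustrative, not from the commit:

from langchain.docstore.document import Document

docs = [Document(page_content="h2oGPT supports local document Q&A.",
                 metadata={"source": "example.txt"})]
# 'faiss' keeps everything in memory; 'chroma' would also use persist_directory/load_db_if_exists
db = get_db(docs,
            use_openai_embedding=False,
            db_type='faiss',
            langchain_mode='UserData',
            collection_name='UserData',
            hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2")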
126
 
127
  def add_to_db(db, sources, db_type='faiss',
128
  avoid_dup_by_file=False,
129
+ avoid_dup_by_content=True,
130
+ use_openai_embedding=False,
131
+ hf_embedding_model=None):
132
+ assert hf_embedding_model is not None
133
  num_new_sources = len(sources)
134
  if not sources:
135
  return db, num_new_sources, []
 
145
  return db, num_new_sources, []
146
  db.add_documents(documents=sources)
147
  elif db_type == 'chroma':
148
+ collection = get_documents(db)
149
  # files we already have:
150
  metadata_files = set([x['source'] for x in collection['metadatas']])
151
  if avoid_dup_by_file:
 
160
  [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
161
  # avoid sources with same hash
162
  sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
163
+ num_nohash = len([x for x in sources if not x.metadata.get('hashid')])
164
+ print("Found %s new sources (%d have no hash in original source,"
165
+ " so have to reprocess for migration to sources with hash)" % (len(sources), num_nohash), flush=True)
166
  # get new file names that match existing file names. delete existing files we are overridding
167
  dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
168
  print("Removing %s duplicate files from db because ingesting those as new documents" % len(
169
  dup_metadata_files), flush=True)
170
+ client_collection = db._client.get_collection(name=db._collection.name,
171
+ embedding_function=db._collection._embedding_function)
172
  for dup_file in dup_metadata_files:
173
  dup_file_meta = dict(source=dup_file)
174
  try:
 
180
  return db, num_new_sources, []
181
  db.add_documents(documents=sources)
182
  db.persist()
183
+ clear_embedding(db)
184
+ save_embed(db, use_openai_embedding, hf_embedding_model)
185
  else:
186
  raise RuntimeError("No such db_type=%s" % db_type)
187
 
 
196
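A minimal sketch of calling the updated add_to_db, which now requires hf_embedding_model and de-duplicates chroma sources by hashid; db and new_docs are assumed to already exist:

# db: an existing faiss/weaviate/chroma store, new_docs: list of Document objects
db, num_new, new_meta = add_to_db(db, new_docs,
                                  db_type='chroma',
                                  avoid_dup_by_file=False,
                                  avoid_dup_by_content=True,
                                  use_openai_embedding=False,
                                  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2")
print("added %d new sources" % num_new)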
  import weaviate
197
  from weaviate.embedded import EmbeddedOptions
198
 
199
+ if os.getenv('WEAVIATE_URL', None):
200
+ client = _create_local_weaviate_client()
201
+ else:
202
+ client = weaviate.Client(
203
+ embedded_options=EmbeddedOptions()
204
+ )
205
+
206
  index_name = collection_name.replace(' ', '_').capitalize()
207
  if client.schema.exists(index_name) and not add_if_exists:
208
  client.schema.delete_class(index_name)
 
239
  if use_openai_embedding:
240
  assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
241
  from langchain.embeddings import OpenAIEmbeddings
242
+ embedding = OpenAIEmbeddings(disallowed_special=())
243
  else:
244
  # to ensure can fork without deadlock
245
  from langchain.embeddings import HuggingFaceEmbeddings
246
 
247
  device, torch_dtype, context_class = get_device_dtype()
248
  model_kwargs = dict(device=device)
249
+ if 'instructor' in hf_embedding_model:
250
+ encode_kwargs = {'normalize_embeddings': True}
251
+ embedding = HuggingFaceInstructEmbeddings(model_name=hf_embedding_model,
252
+ model_kwargs=model_kwargs,
253
+ encode_kwargs=encode_kwargs)
254
+ else:
255
+ embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
256
  return embedding
257
 
258
 
 
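To illustrate the embedding selection added above: a model name containing 'instructor' is routed to HuggingFaceInstructEmbeddings with normalized embeddings, everything else to plain HuggingFaceEmbeddings. The model names below are only examples:

# instructor-style model -> HuggingFaceInstructEmbeddings(normalize_embeddings=True)
emb_instructor = get_embedding(False, hf_embedding_model="hkunlp/instructor-large")
# ordinary sentence-transformers model -> HuggingFaceEmbeddings
emb_minilm = get_embedding(False, hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2")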
266
  )["output_text"]
267
 
268
 
269
+ """Wrapper around Huggingface text generation inference API."""
270
+ from functools import partial
271
+ from typing import Any, Dict, List, Optional, Set
272
+
273
+ from pydantic import Extra, Field, root_validator
274
+
275
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
276
+
277
+ """Wrapper around Huggingface text generation inference API."""
278
+ from functools import partial
279
+ from typing import Any, Dict, List, Optional
280
+
281
+ from pydantic import Extra, Field, root_validator
282
+
283
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
284
+ from langchain.llms.base import LLM
285
+
286
+
287
+ class GradioInference(LLM):
288
+ """
289
+ Gradio generation inference API.
290
+ """
291
+ inference_server_url: str = ""
292
+
293
+ temperature: float = 0.8
294
+ top_p: Optional[float] = 0.95
295
+ top_k: Optional[int] = None
296
+ num_beams: Optional[int] = 1
297
+ max_new_tokens: int = 512
298
+ min_new_tokens: int = 1
299
+ early_stopping: bool = False
300
+ max_time: int = 180
301
+ repetition_penalty: Optional[float] = None
302
+ num_return_sequences: Optional[int] = 1
303
+ do_sample: bool = False
304
+ chat_client: bool = False
305
+
306
+ return_full_text: bool = True
307
+ stream: bool = False
308
+ sanitize_bot_response: bool = False
309
+
310
+ prompter: Any = None
311
+ client: Any = None
312
+
313
+ class Config:
314
+ """Configuration for this pydantic object."""
315
+
316
+ extra = Extra.forbid
317
+
318
+ @root_validator()
319
+ def validate_environment(cls, values: Dict) -> Dict:
320
+ """Validate that python package exists in environment."""
321
+
322
+ try:
323
+ if values['client'] is None:
324
+ import gradio_client
325
+ values["client"] = gradio_client.Client(
326
+ values["inference_server_url"]
327
+ )
328
+ except ImportError:
329
+ raise ImportError(
330
+ "Could not import gradio_client python package. "
331
+ "Please install it with `pip install gradio_client`."
332
+ )
333
+ return values
334
+
335
+ @property
336
+ def _llm_type(self) -> str:
337
+ """Return type of llm."""
338
+ return "gradio_inference"
339
+
340
+ def _call(
341
+ self,
342
+ prompt: str,
343
+ stop: Optional[List[str]] = None,
344
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
345
+ **kwargs: Any,
346
+ ) -> str:
347
+ # NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection,
348
+ # so server should get prompt_type or '', not plain
349
+ # This is good, so gradio server can also handle stopping.py conditions
350
+ # this is different than TGI server that uses prompter to inject prompt_type prompting
351
+ stream_output = self.stream
352
+ gr_client = self.client
353
+ client_langchain_mode = 'Disabled'
354
+ top_k_docs = 1
355
+ chunk = True
356
+ chunk_size = 512
357
+ client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
358
+ iinput='', # only for chat=True
359
+ context='',
360
+ # streaming output is supported, loops over and outputs each generation in streaming mode
361
+ # but leave stream_output=False for simple input/output mode
362
+ stream_output=stream_output,
363
+ prompt_type=self.prompter.prompt_type,
364
+ prompt_dict='',
365
+
366
+ temperature=self.temperature,
367
+ top_p=self.top_p,
368
+ top_k=self.top_k,
369
+ num_beams=self.num_beams,
370
+ max_new_tokens=self.max_new_tokens,
371
+ min_new_tokens=self.min_new_tokens,
372
+ early_stopping=self.early_stopping,
373
+ max_time=self.max_time,
374
+ repetition_penalty=self.repetition_penalty,
375
+ num_return_sequences=self.num_return_sequences,
376
+ do_sample=self.do_sample,
377
+ chat=self.chat_client,
378
+
379
+ instruction_nochat=prompt if not self.chat_client else '',
380
+ iinput_nochat='', # only for chat=False
381
+ langchain_mode=client_langchain_mode,
382
+ top_k_docs=top_k_docs,
383
+ chunk=chunk,
384
+ chunk_size=chunk_size,
385
+ document_choice=[DocumentChoices.All_Relevant.name],
386
+ )
387
+ api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
388
+ if not stream_output:
389
+ res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name)
390
+ res_dict = ast.literal_eval(res)
391
+ text = res_dict['response']
392
+ return self.prompter.get_response(prompt + text, prompt=prompt,
393
+ sanitize_bot_response=self.sanitize_bot_response)
394
+ else:
395
+ text_callback = None
396
+ if run_manager:
397
+ text_callback = partial(
398
+ run_manager.on_llm_new_token, verbose=self.verbose
399
+ )
400
+
401
+ job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name)
402
+ text0 = ''
403
+ while not job.done():
404
+ outputs_list = job.communicator.job.outputs
405
+ if outputs_list:
406
+ res = job.communicator.job.outputs[-1]
407
+ res_dict = ast.literal_eval(res)
408
+ text = res_dict['response']
409
+ text = self.prompter.get_response(prompt + text, prompt=prompt,
410
+ sanitize_bot_response=self.sanitize_bot_response)
411
+ # FIXME: derive chunk from full for now
412
+ text_chunk = text[len(text0):]
413
+ # save old
414
+ text0 = text
415
+
416
+ if text_callback:
417
+ text_callback(text_chunk)
418
+
419
+ time.sleep(0.01)
420
+
421
+ # ensure get last output to avoid race
422
+ res_all = job.outputs()
423
+ if len(res_all) > 0:
424
+ res = res_all[-1]
425
+ res_dict = ast.literal_eval(res)
426
+ text = res_dict['response']
427
+ # FIXME: derive chunk from full for now
428
+ else:
429
+ # go with old if failure
430
+ text = text0
431
+ text_chunk = text[len(text0):]
432
+ if text_callback:
433
+ text_callback(text_chunk)
434
+ return self.prompter.get_response(prompt + text, prompt=prompt,
435
+ sanitize_bot_response=self.sanitize_bot_response)
436
+
437
+
438
+ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
439
+ max_new_tokens: int = 512
440
+ do_sample: bool = False
441
+ top_k: Optional[int] = None
442
+ top_p: Optional[float] = 0.95
443
+ typical_p: Optional[float] = 0.95
444
+ temperature: float = 0.8
445
+ repetition_penalty: Optional[float] = None
446
+ return_full_text: bool = False
447
+ stop_sequences: List[str] = Field(default_factory=list)
448
+ seed: Optional[int] = None
449
+ inference_server_url: str = ""
450
+ timeout: int = 300
451
+ headers: dict = None
452
+ stream: bool = False
453
+ sanitize_bot_response: bool = False
454
+ prompter: Any = None
455
+ tokenizer: Any = None
456
+ client: Any = None
457
+
458
+ @root_validator()
459
+ def validate_environment(cls, values: Dict) -> Dict:
460
+ """Validate that python package exists in environment."""
461
+
462
+ try:
463
+ if values['client'] is None:
464
+ import text_generation
465
+
466
+ values["client"] = text_generation.Client(
467
+ values["inference_server_url"],
468
+ timeout=values["timeout"],
469
+ headers=values["headers"],
470
+ )
471
+ except ImportError:
472
+ raise ImportError(
473
+ "Could not import text_generation python package. "
474
+ "Please install it with `pip install text_generation`."
475
+ )
476
+ return values
477
+
478
+ def _call(
479
+ self,
480
+ prompt: str,
481
+ stop: Optional[List[str]] = None,
482
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
483
+ **kwargs: Any,
484
+ ) -> str:
485
+ if stop is None:
486
+ stop = self.stop_sequences
487
+ else:
488
+ stop += self.stop_sequences
489
+
490
+ # HF inference server needs control over input tokens
491
+ assert self.tokenizer is not None
492
+ from h2oai_pipeline import H2OTextGenerationPipeline
493
+ prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
494
+
495
+ # NOTE: TGI server does not add prompting, so must do here
496
+ data_point = dict(context='', instruction=prompt, input='')
497
+ prompt = self.prompter.generate_prompt(data_point)
498
+
499
+ gen_server_kwargs = dict(do_sample=self.do_sample,
500
+ stop_sequences=stop,
501
+ max_new_tokens=self.max_new_tokens,
502
+ top_k=self.top_k,
503
+ top_p=self.top_p,
504
+ typical_p=self.typical_p,
505
+ temperature=self.temperature,
506
+ repetition_penalty=self.repetition_penalty,
507
+ return_full_text=self.return_full_text,
508
+ seed=self.seed,
509
+ )
510
+ gen_server_kwargs.update(kwargs)
511
+
512
+ # lower bound because client is re-used if multi-threading
513
+ self.client.timeout = max(300, self.timeout)
514
+
515
+ if not self.stream:
516
+ res = self.client.generate(
517
+ prompt,
518
+ **gen_server_kwargs,
519
+ )
520
+ if self.return_full_text:
521
+ gen_text = res.generated_text[len(prompt):]
522
+ else:
523
+ gen_text = res.generated_text
524
+ # remove stop sequences from the end of the generated text
525
+ for stop_seq in stop:
526
+ if stop_seq in gen_text:
527
+ gen_text = gen_text[:gen_text.index(stop_seq)]
528
+ text = prompt + gen_text
529
+ text = self.prompter.get_response(text, prompt=prompt,
530
+ sanitize_bot_response=self.sanitize_bot_response)
531
+ else:
532
+ text_callback = None
533
+ if run_manager:
534
+ text_callback = partial(
535
+ run_manager.on_llm_new_token, verbose=self.verbose
536
+ )
537
+ # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
538
+ if text_callback:
539
+ text_callback(prompt)
540
+ text = ""
541
+ # Note: Streaming ignores return_full_text=True
542
+ for response in self.client.generate_stream(prompt, **gen_server_kwargs):
543
+ text_chunk = response.token.text
544
+ text += text_chunk
545
+ text = self.prompter.get_response(prompt + text, prompt=prompt,
546
+ sanitize_bot_response=self.sanitize_bot_response)
547
+ # stream part
548
+ is_stop = False
549
+ for stop_seq in stop:
550
+ if stop_seq in response.token.text:
551
+ is_stop = True
552
+ break
553
+ if is_stop:
554
+ break
555
+ if not response.token.special:
556
+ if text_callback:
557
+ text_callback(response.token.text)
558
+ return text
559
+
560
+
561
+ from langchain.chat_models import ChatOpenAI
562
+
563
+
564
+ class H2OChatOpenAI(ChatOpenAI):
565
+ @classmethod
566
+ def all_required_field_names(cls) -> Set:
567
+ all_required_field_names = super(ChatOpenAI, cls).all_required_field_names()
568
+ all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty'})
569
+ return all_required_field_names
570
+
571
+
572
+ def get_llm(use_openai_model=False,
573
+ model_name=None,
574
+ model=None,
575
+ tokenizer=None,
576
+ inference_server=None,
577
+ stream_output=False,
578
+ do_sample=False,
579
  temperature=0.1,
 
580
  top_k=40,
581
  top_p=0.7,
582
+ num_beams=1,
583
+ max_new_tokens=256,
584
+ min_new_tokens=1,
585
+ early_stopping=False,
586
+ max_time=180,
587
+ repetition_penalty=1.0,
588
+ num_return_sequences=1,
589
  prompt_type=None,
590
+ prompt_dict=None,
591
  prompter=None,
592
+ sanitize_bot_response=False,
593
  verbose=False,
594
  ):
595
+ if use_openai_model or inference_server in ['openai', 'openai_chat']:
596
+ if use_openai_model and model_name is None:
597
+ model_name = "gpt-3.5-turbo"
598
+ if inference_server == 'openai':
599
+ from langchain.llms import OpenAI
600
+ cls = OpenAI
601
+ else:
602
+ cls = H2OChatOpenAI
603
+ callbacks = [StreamingGradioCallbackHandler()]
604
+ llm = cls(model_name=model_name,
605
+ temperature=temperature if do_sample else 0,
606
+ # FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py
607
+ max_tokens=max_new_tokens,
608
+ top_p=top_p if do_sample else 1,
609
+ frequency_penalty=0,
610
+ presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
611
+ callbacks=callbacks if stream_output else None,
612
+ )
613
+ streamer = callbacks[0] if stream_output else None
614
+ if inference_server in ['openai', 'openai_chat']:
615
+ prompt_type = inference_server
616
+ else:
617
+ prompt_type = prompt_type or 'plain'
618
+ elif inference_server:
619
+ assert inference_server.startswith(
620
+ 'http'), "Malformed inference_server=%s. Did you add http:// in front?" % inference_server
621
+
622
+ from gradio_utils.grclient import GradioClient
623
+ from text_generation import Client as HFClient
624
+ if isinstance(model, GradioClient):
625
+ gr_client = model
626
+ hf_client = None
627
+ else:
628
+ gr_client = None
629
+ hf_client = model
630
+ assert isinstance(hf_client, HFClient)
631
+
632
+ inference_server, headers = get_hf_server(inference_server)
633
+
634
+ # quick sanity check to avoid long timeouts, just see if can reach server
635
+ requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
636
+
637
+ callbacks = [StreamingGradioCallbackHandler()]
638
+ assert prompter is not None
639
+ stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
640
+
641
+ if gr_client:
642
+ chat_client = False
643
+ llm = GradioInference(
644
+ inference_server_url=inference_server,
645
+ return_full_text=True,
646
+
647
+ temperature=temperature,
648
+ top_p=top_p,
649
+ top_k=top_k,
650
+ num_beams=num_beams,
651
+ max_new_tokens=max_new_tokens,
652
+ min_new_tokens=min_new_tokens,
653
+ early_stopping=early_stopping,
654
+ max_time=max_time,
655
+ repetition_penalty=repetition_penalty,
656
+ num_return_sequences=num_return_sequences,
657
+ do_sample=do_sample,
658
+ chat_client=chat_client,
659
+
660
+ callbacks=callbacks if stream_output else None,
661
+ stream=stream_output,
662
+ prompter=prompter,
663
+ client=gr_client,
664
+ sanitize_bot_response=sanitize_bot_response,
665
+ )
666
+ elif hf_client:
667
+ llm = H2OHuggingFaceTextGenInference(
668
+ inference_server_url=inference_server,
669
+ do_sample=do_sample,
670
+ max_new_tokens=max_new_tokens,
671
+ repetition_penalty=repetition_penalty,
672
+ return_full_text=True,
673
+ seed=SEED,
674
+
675
+ stop_sequences=stop_sequences,
676
+ temperature=temperature,
677
+ top_k=top_k,
678
+ top_p=top_p,
679
+ # typical_p=top_p,
680
+ callbacks=callbacks if stream_output else None,
681
+ stream=stream_output,
682
+ prompter=prompter,
683
+ tokenizer=tokenizer,
684
+ client=hf_client,
685
+ timeout=max_time,
686
+ sanitize_bot_response=sanitize_bot_response,
687
+ )
688
+ else:
689
+ raise RuntimeError("No defined client")
690
+ streamer = callbacks[0] if stream_output else None
691
  elif model_name in non_hf_types:
692
+ if model_name == 'llama':
693
+ callbacks = [StreamingGradioCallbackHandler()]
694
+ streamer = callbacks[0] if stream_output else None
695
+ else:
696
+ # stream_output = False
697
+ # doesn't stream properly as generator, but at least
698
+ callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
699
+ streamer = None
700
+ if prompter:
701
+ prompt_type = prompter.prompt_type
702
+ else:
703
+ prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=False, stream_output=stream_output)
704
+ pass # assume inputted prompt_type is correct
705
  from gpt4all_llm import get_llm_gpt4all
706
  llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
707
  temperature=temperature,
708
  repetition_penalty=repetition_penalty,
709
  top_k=top_k,
710
  top_p=top_p,
711
+ callbacks=callbacks,
712
  verbose=verbose,
713
+ streaming=stream_output,
714
+ prompter=prompter,
715
  )
 
 
716
  else:
 
 
717
  if model is None:
718
  # only used if didn't pass model in
 
719
  assert tokenizer is None
720
  prompt_type = 'human_bot'
721
+ if model_name is None:
722
+ model_name = 'h2oai/h2ogpt-oasst1-512-12b'
723
+ # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
724
+ # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
725
+ inference_server = ''
726
+ model, tokenizer, device = get_model(load_8bit=True, base_model=model_name,
727
+ inference_server=inference_server, gpu_id=0)
728
 
729
  max_max_tokens = tokenizer.model_max_length
730
+ gen_kwargs = dict(do_sample=do_sample,
731
+ temperature=temperature,
732
+ top_k=top_k,
733
+ top_p=top_p,
734
+ num_beams=num_beams,
735
+ max_new_tokens=max_new_tokens,
736
+ min_new_tokens=min_new_tokens,
737
+ early_stopping=early_stopping,
738
+ max_time=max_time,
739
+ repetition_penalty=repetition_penalty,
740
+ num_return_sequences=num_return_sequences,
741
  return_full_text=True,
742
+ handle_long_generation=None)
743
+ assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
744
 
745
  if stream_output:
746
  skip_prompt = False
 
755
  pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
756
  prompter=prompter,
757
  prompt_type=prompt_type,
758
+ prompt_dict=prompt_dict,
759
+ sanitize_bot_response=sanitize_bot_response,
760
  chat=False, stream_output=stream_output,
761
  tokenizer=tokenizer,
762
+ # leave some room for 1 paragraph, even if min_new_tokens=0
763
+ max_input_tokens=max_max_tokens - max(min_new_tokens, 256),
764
  **gen_kwargs)
765
  # pipe.task = "text-generation"
766
  # below makes it listen only to our prompt removal,
 
805
  data = json.load(open(filename, "rt"))
806
  page_content = list(data["query"]["pages"].values())[0]["extract"]
807
  if take_head is not None and text_limit is not None:
808
+ page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
809
  title_url = str(title).replace(' ', '_')
810
  return Document(
811
  page_content=page_content,
 
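The one-line take_head change above simply chooses which end of the wiki extract to keep within text_limit, e.g.:

page_content = "abcdefghij"
text_limit = 4
page_content[:text_limit]   # 'abcd' when take_head is True (keep the start)
page_content[-text_limit:]  # 'ghij' when take_head is False (keep the tail)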
927
  except (pkg_resources.DistributionNotFound, AssertionError):
928
  have_pymupdf = False
929
 
930
+ try:
931
+ assert pkg_resources.get_distribution('selenium') is not None
932
+ have_selenium = True
933
+ except (pkg_resources.DistributionNotFound, AssertionError):
934
+ have_selenium = False
935
+
936
+ try:
937
+ assert pkg_resources.get_distribution('playwright') is not None
938
+ have_playwright = True
939
+ except (pkg_resources.DistributionNotFound, AssertionError):
940
+ have_playwright = False
941
+
942
+ # disable, hangs too often
943
+ have_playwright = False
944
+
945
  image_types = ["png", "jpg", "jpeg"]
946
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
947
  "md", "html",
 
959
  def add_meta(docs1, file):
960
  file_extension = pathlib.Path(file).suffix
961
  hashid = hash_file(file)
962
+ if not isinstance(docs1, (list, tuple, types.GeneratorType)):
963
  docs1 = [docs1]
964
  [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
965
 
966
 
967
+ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
968
+ chunk=True, chunk_size=512,
969
  is_url=False, is_txt=False,
970
  enable_captions=True,
971
  captions_model=None,
 
1001
  else:
1002
  docs1 = []
1003
  else:
1004
+ if not (file.startswith("http://") or file.startswith("file://") or file.startswith("https://")):
1005
+ file = 'http://' + file
1006
  docs1 = UnstructuredURLLoader(urls=[file]).load()
1007
+ if len(docs1) == 0 and have_playwright:
1008
+ # then something went wrong, try another loader:
1009
+ from langchain.document_loaders import PlaywrightURLLoader
1010
+ docs1 = PlaywrightURLLoader(urls=[file]).load()
1011
+ if len(docs1) == 0 and have_selenium:
1012
+ # then something went wrong, try another loader:
1013
+ # but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException: Message: unknown error: cannot find Chrome binary
1014
+ from langchain.document_loaders import SeleniumURLLoader
1015
+ from selenium.common.exceptions import WebDriverException
1016
+ try:
1017
+ docs1 = SeleniumURLLoader(urls=[file]).load()
1018
+ except WebDriverException as e:
1019
+ print("No web driver: %s" % str(e), flush=True)
1020
  [x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1]
1021
+ docs1 = clean_doc(docs1)
1022
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1023
  elif is_txt:
1024
  base_path = "user_paste"
1025
  source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
 
1028
  f.write(file)
1029
  metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
1030
  doc1 = Document(page_content=file, metadata=metadata)
1031
+ doc1 = clean_doc(doc1)
1032
  elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
1033
  docs1 = UnstructuredHTMLLoader(file_path=file).load()
1034
  add_meta(docs1, file)
1035
+ docs1 = clean_doc(docs1)
1036
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
1037
  elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
1038
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
1039
  add_meta(docs1, file)
1040
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1041
  elif file.lower().endswith('.odt'):
1042
  docs1 = UnstructuredODTLoader(file_path=file).load()
1043
  add_meta(docs1, file)
1044
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1045
  elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
1046
  docs1 = UnstructuredPowerPointLoader(file_path=file).load()
1047
  add_meta(docs1, file)
1048
+ docs1 = clean_doc(docs1)
1049
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1050
  elif file.lower().endswith('.txt'):
1051
  # use UnstructuredFileLoader ?
1052
  docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
1053
  # makes just one, but big one
1054
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1055
+ doc1 = clean_doc(doc1)
1056
  add_meta(doc1, file)
1057
  elif file.lower().endswith('.rtf'):
1058
  docs1 = UnstructuredRTFLoader(file).load()
1059
  add_meta(docs1, file)
1060
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1061
  elif file.lower().endswith('.md'):
1062
  docs1 = UnstructuredMarkdownLoader(file).load()
1063
  add_meta(docs1, file)
1064
+ docs1 = clean_doc(docs1)
1065
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.MARKDOWN)
1066
  elif file.lower().endswith('.enex'):
1067
  docs1 = EverNoteLoader(file).load()
1068
  add_meta(doc1, file)
1069
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1070
  elif file.lower().endswith('.epub'):
1071
  docs1 = UnstructuredEPubLoader(file).load()
1072
  add_meta(docs1, file)
1073
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1074
  elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'):
1075
  docs1 = []
1076
  if have_tesseract and enable_ocr:
 
1100
  doci.metadata['source'] = doci.metadata['image_path']
1101
  doci.metadata['hash'] = hash_file(doci.metadata['source'])
1102
  if docs1:
1103
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1104
  elif file.lower().endswith('.msg'):
1105
  raise RuntimeError("Not supported, GPL3 license")
1106
  # docs1 = OutlookMessageLoader(file).load()
 
1109
  try:
1110
  docs1 = UnstructuredEmailLoader(file).load()
1111
  add_meta(docs1, file)
1112
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1113
  except ValueError as e:
1114
  if 'text/html content not found in email' in str(e):
1115
  # e.g. plain/text dict key exists, but not
1116
  # doc1 = TextLoader(file, encoding="utf8").load()
1117
  docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
1118
  add_meta(docs1, file)
1119
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1120
  else:
1121
  raise
1122
  # elif file.lower().endswith('.gcsdir'):
 
1127
  with open(file, "r") as f:
1128
  doc1 = Document(page_content=f.read(), metadata={"source": file})
1129
  add_meta(doc1, file)
1130
+ doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.RST)
1131
  elif file.lower().endswith('.pdf'):
1132
  env_gpt4all_file = ".env_gpt4all"
1133
  from dotenv import dotenv_values
 
1136
  if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
1137
  # GPL, only use if installed
1138
  from langchain.document_loaders import PyMuPDFLoader
1139
+ # load() still chunks by pages, but every page has title at start to help
1140
+ doc1 = PyMuPDFLoader(file).load()
1141
+ doc1 = clean_doc(doc1)
1142
+ elif pdf_class_name == 'UnstructuredPDFLoader':
1143
+ doc1 = UnstructuredPDFLoader(file).load()
1144
+ # seems to not need cleaning in most cases
1145
  else:
1146
  # open-source fallback
1147
+ # load() still chunks by pages, but every page has title at start to help
1148
+ doc1 = PyPDFLoader(file).load()
1149
+ doc1 = clean_doc(doc1)
1150
  # Some PDFs return nothing or junk from PDFMinerLoader
1151
+ doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
1152
  add_meta(doc1, file)
1153
  elif file.lower().endswith('.csv'):
1154
  doc1 = CSVLoader(file).load()
 
1156
  elif file.lower().endswith('.py'):
1157
  doc1 = PythonLoader(file).load()
1158
  add_meta(doc1, file)
1159
+ doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.PYTHON)
1160
  elif file.lower().endswith('.toml'):
1161
  doc1 = TomlLoader(file).load()
1162
  add_meta(doc1, file)
 
1164
  with open(file, "r") as f:
1165
  docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
1166
  add_meta(docs1, file)
1167
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1168
  elif file.lower().endswith('.zip'):
1169
  with zipfile.ZipFile(file, 'r') as zip_ref:
1170
  # don't put into temporary path, since want to keep references to docs inside zip
 
1179
  # if list of length one, don't trust and chunk it
1180
  if not isinstance(doc1, list):
1181
  if chunk:
1182
+ docs = chunk_sources([doc1], chunk=chunk, chunk_size=chunk_size)
1183
  else:
1184
  docs = [doc1]
1185
  elif isinstance(doc1, list) and len(doc1) == 1:
1186
  if chunk:
1187
+ docs = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
1188
  else:
1189
  docs = doc1
1190
  else:
 
1194
  return docs
1195
 
1196
 
1197
+ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True,
1198
+ chunk=True, chunk_size=512,
1199
  is_url=False, is_txt=False,
1200
  enable_captions=True,
1201
  captions_model=None,
 
1247
  existing_files=[],
1248
  existing_hash_ids={},
1249
  ):
1250
+ # path_or_paths could be str, list, tuple, generator
1251
  globs_image_types = []
1252
  globs_non_image_types = []
1253
  if not path_or_paths and not url and not text:
1254
  return []
1255
  elif url:
1256
+ globs_non_image_types = url if isinstance(url, (list, tuple, types.GeneratorType)) else [url]
1257
  elif text:
1258
+ globs_non_image_types = text if isinstance(text, (list, tuple, types.GeneratorType)) else [text]
1259
+ elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths):
1260
  # single path, only consume allowed files
1261
  path = path_or_paths
1262
  # Below globs should match patterns in file_to_doc()
 
1265
  [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
1266
  for ftype in non_image_types]
1267
  else:
1268
+ if isinstance(path_or_paths, str) and (os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths)):
1269
+ path_or_paths = [path_or_paths]
1270
  # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
1271
+ assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), "Wrong type for path_or_paths: %s" % type(
1272
+ path_or_paths)
1273
  # reform out of allowed types
1274
  globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
1275
  # could do below:
 
1373
 
1374
  if db_dir_exists and user_path is None:
1375
  print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
1376
+ db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
1377
  hf_embedding_model)
1378
  else:
1379
  if db_dir_exists and user_path is not None:
1380
  print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
1381
+ persist_directory, user_path), flush=True)
1382
  elif not db_dir_exists:
1383
  print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
1384
  db = None
 
1424
  posthog.Consumer = FakeConsumer
1425
 
1426
 
1427
+ def check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model, langchain_mode):
1428
+ changed_db = False
1429
+ if load_embed(db) != (use_openai_embedding, hf_embedding_model):
1430
+ print("Detected new embedding, updating db: %s" % langchain_mode, flush=True)
1431
+ # handle embedding changes
1432
+ db_get = get_documents(db)
1433
+ sources = [Document(page_content=result[0], metadata=result[1] or {})
1434
+ for result in zip(db_get['documents'], db_get['metadatas'])]
1435
+ # delete index, has to be redone
1436
+ persist_directory = db._persist_directory
1437
+ shutil.move(persist_directory, persist_directory + "_" + str(uuid.uuid4()) + ".bak")
1438
+ db_type = 'chroma'
1439
+ load_db_if_exists = False
1440
+ db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
1441
+ persist_directory=persist_directory, load_db_if_exists=load_db_if_exists,
1442
+ langchain_mode=langchain_mode,
1443
+ collection_name=None,
1444
+ hf_embedding_model=hf_embedding_model)
1445
+ if False:
1446
+ # below doesn't work if db already in memory, so have to switch to new db as above
1447
+ # upsert does new embedding, but if index already in memory, complains about size mismatch etc.
1448
+ client_collection = db._client.get_collection(name=db._collection.name,
1449
+ embedding_function=db._collection._embedding_function)
1450
+ client_collection.upsert(ids=db_get['ids'], metadatas=db_get['metadatas'], documents=db_get['documents'])
1451
+ changed_db = True
1452
+ print("Done updating db for new embedding: %s" % langchain_mode, flush=True)
1453
+
1454
+ return db, changed_db
1455
+
1456
+
1457
+ def get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
1458
+ hf_embedding_model, verbose=False, check_embedding=True):
1459
  if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
1460
  os.path.join(persist_directory, 'index')):
1461
+ if db is None:
1462
+ if verbose:
1463
+ print("DO Loading db: %s" % langchain_mode, flush=True)
1464
+ embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
1465
+ from chromadb.config import Settings
1466
+ client_settings = Settings(anonymized_telemetry=False,
1467
+ chroma_db_impl="duckdb+parquet",
1468
+ persist_directory=persist_directory)
1469
+ db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
1470
+ collection_name=langchain_mode.replace(' ', '_'),
1471
+ client_settings=client_settings)
1472
+ if verbose:
1473
+ print("DONE Loading db: %s" % langchain_mode, flush=True)
1474
+ else:
1475
+ if verbose:
1476
+ print("USING already-loaded db: %s" % langchain_mode, flush=True)
1477
+ if check_embedding:
1478
+ db_trial, changed_db = check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model,
1479
+ langchain_mode)
1480
+ if changed_db:
1481
+ db = db_trial
1482
+ # only call persist if really changed db, else takes too long for large db
1483
+ if db is not None:
1484
+ db.persist()
1485
+ clear_embedding(db)
1486
+ save_embed(db, use_openai_embedding, hf_embedding_model)
1487
  return db
1488
  return None
1489
 
1490
 
1491
+ def clear_embedding(db):
1492
+ if db is None:
1493
+ return
1494
+ # don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed
1495
+ db._embedding_function.client.cpu()
1496
+ clear_torch_cache()
1497
+
1498
+
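Taken together, the additions above make the flow roughly: load or reuse the chroma db, compare the embedding recorded on disk with the one now requested, and rebuild the collection if they differ. A condensed sketch of that control flow; the mode and model names are illustrative:

persist_directory = 'db_dir_UserData'  # i.e. get_persist_directory('UserData'), defined further down
db = get_existing_db(None, persist_directory,
                     load_db_if_exists=True, db_type='chroma',
                     use_openai_embedding=False,
                     langchain_mode='UserData',
                     hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                     verbose=True, check_embedding=True)
# if the embedding saved by save_embed() differs from the requested one,
# check_update_chroma_embedding() has already moved the old index aside and re-embedded the documents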
1499
  def make_db(**langchain_kwargs):
1500
  func_names = list(inspect.signature(_make_db).parameters)
1501
  missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
 
1511
  return _make_db(**langchain_kwargs)
1512
 
1513
 
1514
+ def save_embed(db, use_openai_embedding, hf_embedding_model):
1515
+ if db is not None:
1516
+ embed_info_file = os.path.join(db._persist_directory, 'embed_info')
1517
+ with open(embed_info_file, 'wb') as f:
1518
+ pickle.dump((use_openai_embedding, hf_embedding_model), f)
1519
+ return use_openai_embedding, hf_embedding_model
1520
+
1521
+
1522
+ def load_embed(db):
1523
+ embed_info_file = os.path.join(db._persist_directory, 'embed_info')
1524
+ if os.path.isfile(embed_info_file):
1525
+ with open(embed_info_file, 'rb') as f:
1526
+ use_openai_embedding, hf_embedding_model = pickle.load(f)
1527
+ else:
1528
+ # migration, assume defaults
1529
+ use_openai_embedding, hf_embedding_model = False, "sentence-transformers/all-MiniLM-L6-v2"
1530
+ return use_openai_embedding, hf_embedding_model
1531
+
1532
+
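# Illustrative sketch (values are hypothetical): save_embed/load_embed round-trip the embedding
# settings beside the Chroma persist directory, so a later launch with different settings makes
# load_embed(db) != (use_openai_embedding, hf_embedding_model) and triggers the re-embedding path above.
#   save_embed(db, False, "sentence-transformers/all-MiniLM-L6-v2")   # writes <persist_directory>/embed_info
#   assert load_embed(db) == (False, "sentence-transformers/all-MiniLM-L6-v2")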
1533
+ def get_persist_directory(langchain_mode):
1534
+ return 'db_dir_%s' % langchain_mode # single place, no special names for each case
1535
+
1536
+
1537
  def _make_db(use_openai_embedding=False,
1538
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1539
+ first_para=False, text_limit=None,
1540
+ chunk=True, chunk_size=512,
1541
  langchain_mode=None,
1542
  user_path=None,
1543
  db_type='faiss',
 
1545
  db=None,
1546
  n_jobs=-1,
1547
  verbose=False):
1548
+ persist_directory = get_persist_directory(langchain_mode)
1549
+ # see if can get persistent chroma db
1550
+ db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
1551
+ hf_embedding_model, verbose=verbose)
1552
+ if db_trial is not None:
1553
+ db = db_trial
1554
+
 
 
 
 
 
 
1555
  sources = []
1556
  if not db and langchain_mode not in ['MyData'] or \
1557
  user_path is not None and \
 
1576
  sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
1577
  print("Got new wiki", flush=True)
1578
  if chunk:
1579
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1580
  print("Chunked new wiki", flush=True)
1581
  sources.extend(sources1)
1582
  if langchain_mode in ['wiki', 'All', "'All'"]:
1583
  sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
1584
  if chunk:
1585
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1586
  sources.extend(sources1)
1587
  if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
1588
  # sources = get_github_docs("dagster-io", "dagster")
1589
  sources1 = get_github_docs("h2oai", "h2ogpt")
1590
  # FIXME: always chunk for now
1591
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1592
  sources.extend(sources1)
1593
  if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
1594
  sources1 = get_dai_docs(from_hf=True)
1595
  if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
1596
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1597
  sources.extend(sources1)
1598
  if langchain_mode in ['All', 'UserData']:
1599
  if user_path:
 
1607
  existing_files = []
1608
  existing_hash_ids = []
1609
  # chunk internally for speed over multiple docs
1610
+ # FIXME: If the db first had old Hash=None and embeddings are switched,
1611
+ # then it re-embeds, then hits here and reloads so it has hashes, and then re-embeds again.
1612
  sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
1613
  existing_files=existing_files, existing_hash_ids=existing_hash_ids)
1614
  new_metadata_sources = set([x.metadata['source'] for x in sources1])
 
1652
  new_sources_metadata = [x.metadata for x in sources]
1653
  elif user_path is not None and langchain_mode in ['UserData']:
1654
  print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
1655
+ db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
1656
+ use_openai_embedding=use_openai_embedding,
1657
+ hf_embedding_model=hf_embedding_model)
1658
  print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
1659
  else:
1660
  new_sources_metadata = [x.metadata for x in sources]
 
1662
  return db, len(new_sources_metadata), new_sources_metadata
1663
 
1664
 
1665
+ def get_metadatas(db):
1666
+ from langchain.vectorstores import FAISS
1667
+ if isinstance(db, FAISS):
1668
+ metadatas = [v.metadata for k, v in db.docstore._dict.items()]
1669
+ elif isinstance(db, Chroma):
1670
+ metadatas = get_documents(db)['metadatas']
1671
+ else:
1672
+ # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
1673
+ # seems no way to get all metadata, so need to avoid this approach for weaviate
1674
+ metadatas = [x.metadata for x in db.similarity_search("", k=10000)]
1675
+ return metadatas
1676
+
1677
+
1678
+ def get_documents(db):
1679
+ if hasattr(db, '_persist_directory'):
1680
+ name_path = os.path.basename(db._persist_directory)
1681
+ base_path = 'locks'
1682
+ makedirs(base_path)
1683
+ with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)):
1684
+ # get segfaults and other errors when multiple threads access this
1685
+ return _get_documents(db)
1686
+ else:
1687
+ return _get_documents(db)
1688
+
1689
+
1690
+ def _get_documents(db):
1691
+ from langchain.vectorstores import FAISS
1692
+ if isinstance(db, FAISS):
1693
+ documents = [v for k, v in db.docstore._dict.items()]
1694
+ elif isinstance(db, Chroma):
1695
+ documents = db.get()
1696
+ else:
1697
+ # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
1698
+ # seems no way to get all metadata, so need to avoid this approach for weaviate
1699
+ documents = [x for x in db.similarity_search("", k=10000)]
1700
+ return documents
1701
+
1702
+
1703
+ def get_docs_and_meta(db, top_k_docs, filter_kwargs={}):
1704
+ if hasattr(db, '_persist_directory'):
1705
+ name_path = os.path.basename(db._persist_directory)
1706
+ base_path = 'locks'
1707
+ makedirs(base_path)
1708
+ with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)):
1709
+ return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
1710
+ else:
1711
+ return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
1712
+
1713
+
1714
+ def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}):
1715
+ from langchain.vectorstores import FAISS
1716
+ if isinstance(db, Chroma):
1717
+ db_get = db._collection.get(where=filter_kwargs.get('filter'))
1718
+ db_metadatas = db_get['metadatas']
1719
+ db_documents = db_get['documents']
1720
+ elif isinstance(db, FAISS):
1721
+ import itertools
1722
+ db_metadatas = get_metadatas(db)
1723
+ # FIXME: FAISS has no filter
1724
+ # slice dict first
1725
+ db_documents = list(dict(itertools.islice(db.docstore._dict.items(), top_k_docs)).values())
1726
+ else:
1727
+ db_metadatas = get_metadatas(db)
1728
+ db_documents = get_documents(db)
1729
+ return db_documents, db_metadatas
1730
+
1731
+
1732
  def get_existing_files(db):
1733
+ metadatas = get_metadatas(db)
1734
+ metadata_sources = set([x['source'] for x in metadatas])
1735
  return metadata_sources
1736
 
1737
 
1738
  def get_existing_hash_ids(db):
1739
+ metadatas = get_metadatas(db)
1740
  # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
1741
+ metadata_hash_ids = {x['source']: x.get('hashid') for x in metadatas}
1742
  return metadata_hash_ids
1743
 
1744
 
 
 
 
 
1745
  def run_qa_db(**kwargs):
1746
  func_names = list(inspect.signature(_run_qa_db).parameters)
1747
  # hard-coded defaults
1748
  kwargs['answer_with_sources'] = True
 
1749
  kwargs['show_rank'] = False
1750
  missing_kwargs = [x for x in func_names if x not in kwargs]
1751
  assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
1752
  # only keep actual used
1753
  kwargs = {k: v for k, v in kwargs.items() if k in func_names}
1754
+ try:
1755
+ return _run_qa_db(**kwargs)
1756
+ finally:
1757
+ clear_torch_cache()
1758
 
1759
 
1760
  def _run_qa_db(query=None,
1761
  use_openai_model=False, use_openai_embedding=False,
1762
+ first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1763
  user_path=None,
1764
  detect_user_path_changes_every_query=False,
1765
  db_type='faiss',
1766
+ model_name=None, model=None, tokenizer=None, inference_server=None,
1767
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1768
  stream_output=False,
1769
  prompter=None,
1770
  prompt_type=None,
1771
+ prompt_dict=None,
1772
  answer_with_sources=True,
1773
  cut_distanct=1.1,
1774
+ sanitize_bot_response=False,
1775
  show_rank=False,
1776
  load_db_if_exists=False,
1777
  db=None,
1778
+ do_sample=False,
1779
  temperature=0.1,
 
1780
  top_k=40,
1781
  top_p=0.7,
1782
+ num_beams=1,
1783
+ max_new_tokens=256,
1784
+ min_new_tokens=1,
1785
+ early_stopping=False,
1786
+ max_time=180,
1787
+ repetition_penalty=1.0,
1788
+ num_return_sequences=1,
1789
  langchain_mode=None,
1790
+ document_choice=[DocumentChoices.All_Relevant.name],
1791
  n_jobs=-1,
1792
  verbose=False,
1793
+ cli=False,
1794
+ reverse_docs=True,
1795
+ lora_weights='',
1796
+ auto_reduce_chunks=True,
1797
+ max_chunks=100,
1798
+ ):
1799
  """
1800
 
1801
  :param query:
 
1814
  :param answer_with_sources
1815
  :return:
1816
  """
1817
+ if model is not None:
1818
+ assert model_name is not None # require so can make decisions
1819
  assert query is not None
1820
  assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
1821
  if prompter is not None:
1822
  prompt_type = prompter.prompt_type
1823
+ prompt_dict = prompter.prompt_dict
1824
  if model is not None:
1825
  assert prompt_type is not None
1826
+ if prompt_type == PromptType.custom.name:
1827
+ assert prompt_dict is not None # should at least be {} or ''
1828
+ else:
1829
+ prompt_dict = ''
1830
+ assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
1831
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1832
+ model=model,
1833
+ tokenizer=tokenizer,
1834
+ inference_server=inference_server,
1835
  stream_output=stream_output,
1836
+ do_sample=do_sample,
1837
  temperature=temperature,
 
1838
  top_k=top_k,
1839
  top_p=top_p,
1840
+ num_beams=num_beams,
1841
+ max_new_tokens=max_new_tokens,
1842
+ min_new_tokens=min_new_tokens,
1843
+ early_stopping=early_stopping,
1844
+ max_time=max_time,
1845
+ repetition_penalty=repetition_penalty,
1846
+ num_return_sequences=num_return_sequences,
1847
  prompt_type=prompt_type,
1848
+ prompt_dict=prompt_dict,
1849
  prompter=prompter,
1850
+ sanitize_bot_response=sanitize_bot_response,
1851
  verbose=verbose,
1852
  )
1853
 
 
 
 
 
1854
  use_context = False
1855
  scores = []
1856
  chain = None
1857
 
1858
+ if isinstance(document_choice, str):
1859
+ # support string as well
1860
+ document_choice = [document_choice]
1861
+ # get first DocumentChoices as command to use, ignore others
1862
+ doc_choices_set = set([x.name for x in list(DocumentChoices)])
1863
+ cmd = [x for x in document_choice if x in doc_choices_set]
1864
+ cmd = None if len(cmd) == 0 else cmd[0]
1865
+ # now have cmd, filter out for only docs
1866
+ document_choice = [x for x in document_choice if x not in doc_choices_set]
1867
+
1868
  func_names = list(inspect.signature(get_similarity_chain).parameters)
1869
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1870
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1871
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
1872
  docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
1873
+ if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
1874
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1875
  yield formatted_doc_chunks, ''
1876
  return
 
1878
  # can only return if HF type
1879
  return
1880
 
1881
+ # context handling similar to that used in evaluate()
1882
+ import torch
1883
+ device, torch_dtype, context_class = get_device_dtype()
1884
+ with torch.no_grad():
1885
+ have_lora_weights = lora_weights not in [no_lora_str, '', None]
1886
+ context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
1887
+ with context_class_cast(device):
1888
+ if stream_output and streamer:
1889
+ answer = None
1890
+ import queue
1891
+ bucket = queue.Queue()
1892
+ thread = EThread(target=chain, streamer=streamer, bucket=bucket)
1893
+ thread.start()
1894
+ outputs = ""
1895
+ prompt = None # FIXME
1896
+ try:
1897
+ for new_text in streamer:
1898
+ # print("new_text: %s" % new_text, flush=True)
1899
+ if bucket.qsize() > 0 or thread.exc:
1900
+ thread.join()
1901
+ outputs += new_text
1902
+ if prompter: # and False: # FIXME: pipeline can already use prompter
1903
+ output1 = prompter.get_response(outputs, prompt=prompt,
1904
+ sanitize_bot_response=sanitize_bot_response)
1905
+ yield output1, ''
1906
+ else:
1907
+ yield outputs, ''
1908
+ except BaseException:
1909
+ # if any exception, raise that exception if was from thread, first
1910
+ if thread.exc:
1911
+ raise thread.exc
1912
+ raise
1913
+ finally:
1914
+ # in case no exception and didn't join with thread yet, then join
1915
+ if not thread.exc:
1916
+ answer = thread.join()
1917
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
1918
+ if thread.exc:
1919
+ raise thread.exc
1920
+ # FIXME: answer is not string outputs from streamer. How to get actual final output?
1921
+ # answer = outputs
1922
+ else:
1923
+ answer = chain()
1924
 
1925
  if not use_context:
1926
  ret = answer['output_text']
 
1934
 
1935
  def get_similarity_chain(query=None,
1936
  use_openai_model=False, use_openai_embedding=False,
1937
+ first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1938
  user_path=None,
1939
  detect_user_path_changes_every_query=False,
1940
  db_type='faiss',
1941
  model_name=None,
1942
+ inference_server='',
1943
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1944
  prompt_type=None,
1945
+ prompt_dict=None,
1946
  cut_distanct=1.1,
1947
  load_db_if_exists=False,
1948
  db=None,
1949
  langchain_mode=None,
1950
+ document_choice=[DocumentChoices.All_Relevant.name],
1951
  n_jobs=-1,
1952
  # beyond run_db_query:
1953
  llm=None,
1954
+ tokenizer=None,
1955
  verbose=False,
1956
+ cmd=None,
1957
+ reverse_docs=True,
1958
+
1959
+ # local
1960
+ auto_reduce_chunks=True,
1961
+ max_chunks=100,
1962
  ):
1963
  # determine whether use of context out of docs is planned
1964
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
 
1970
  use_context = True
1971
 
1972
  # https://github.com/hwchase17/langchain/issues/1946
1973
+ # FIXME: Seems there is no way to get the size of the chroma db to limit top_k_docs to avoid
1974
  # Chroma collection MyData contains fewer than 4 elements.
1975
  # type logger error
1976
+ if top_k_docs == -1:
1977
+ k_db = 1000 if db_type == 'chroma' else 100
1978
+ else:
1979
+ # top_k_docs=100 works ok too
1980
+ k_db = 1000 if db_type == 'chroma' else top_k_docs
1981
 
1982
  # FIXME: For All just go over all dbs instead of a separate db for All
1983
  if not detect_user_path_changes_every_query and db is not None:
 
1987
  user_path = None
1988
  db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
1989
  hf_embedding_model=hf_embedding_model,
1990
+ first_para=first_para, text_limit=text_limit,
1991
+ chunk=chunk,
1992
  chunk_size=chunk_size,
1993
  langchain_mode=langchain_mode,
1994
  user_path=user_path,
 
1998
  n_jobs=n_jobs,
1999
  verbose=verbose)
2000
 
2001
+ if 'falcon' in model_name:
2002
+ extra = "According to only the information in the document sources provided within the context above, "
2003
+ prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
2004
+ elif inference_server in ['openai', 'openai_chat']:
2005
+ extra = "According to (primarily) the information in the document sources provided within context above, "
2006
+ prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
2007
+ else:
2008
+ extra = ""
2009
+ prefix = ""
2010
+ if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
2011
+ template_if_no_docs = template = """%s{context}{question}""" % prefix
2012
+ else:
2013
+ template = """%s
2014
+ \"\"\"
2015
+ {context}
2016
+ \"\"\"
2017
+ %s{question}""" % (prefix, extra)
2018
+ template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
2019
+ if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
2020
+ use_template = True
2021
+ else:
2022
+ use_template = False
2023
+
2024
  if db and use_context:
2025
+ if not isinstance(db, Chroma):
2026
+ # only chroma supports filtering
 
 
 
 
 
 
 
 
2027
  filter_kwargs = {}
2028
  else:
2029
+ # if here then some cmd + documents selected or just documents selected
2030
  if len(document_choice) >= 2:
2031
  or_filter = [{"source": {"$eq": x}} for x in document_choice]
2032
  filter_kwargs = dict(filter={"$or": or_filter})
2033
+ elif len(document_choice) == 1:
2034
+ # degenerate UX bug in chroma
2035
  one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
2036
  filter_kwargs = dict(filter=one_filter)
2037
  else:
2038
+ # shouldn't reach
2039
  filter_kwargs = {}
2040
+ if cmd == DocumentChoices.Just_LLM.name:
2041
+ docs = []
2042
+ scores = []
2043
+ elif cmd == DocumentChoices.Only_All_Sources.name:
2044
+ db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2045
+ # similar to langchain's chroma's _results_to_docs_and_scores
2046
+ docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
2047
+ for result in zip(db_documents, db_metadatas)][:top_k_docs]
2048
+ docs = [x[0] for x in docs_with_score]
2049
+ scores = [x[1] for x in docs_with_score]
2050
+ else:
2051
+ if top_k_docs == -1 or auto_reduce_chunks:
2052
+ # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2053
+ top_k_docs_tokenize = 100
2054
+ base_path = 'locks'
2055
+ makedirs(base_path)
2056
+ if hasattr(db, '_persist_directory'):
2057
+ name_path = "sim_%s.lock" % os.path.basename(db._persist_directory)
2058
+ else:
2059
+ name_path = "sim.lock"
2060
+ with filelock.FileLock(os.path.join(base_path, name_path)):
2061
+ docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[
2062
+ :top_k_docs_tokenize]
2063
+ if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
2064
+ # more accurate
2065
+ tokens = [len(llm.pipeline.tokenizer(x[0].page_content)['input_ids']) for x in docs_with_score]
2066
+ template_tokens = len(llm.pipeline.tokenizer(template)['input_ids'])
2067
+ elif inference_server in ['openai', 'openai_chat'] or use_openai_model or db_type in ['faiss',
2068
+ 'weaviate']:
2069
+ # use tiktoken for faiss since embedding is called differently
2070
+ tokens = [llm.get_num_tokens(x[0].page_content) for x in docs_with_score]
2071
+ template_tokens = llm.get_num_tokens(template)
2072
+ elif isinstance(tokenizer, FakeTokenizer):
2073
+ tokens = [tokenizer.num_tokens_from_string(x[0].page_content) for x in docs_with_score]
2074
+ template_tokens = tokenizer.num_tokens_from_string(template)
2075
+ else:
2076
+ # in case model is not our pipeline with HF tokenizer
2077
+ tokens = [db._embedding_function.client.tokenize([x[0].page_content])['input_ids'].shape[1] for x in
2078
+ docs_with_score]
2079
+ template_tokens = db._embedding_function.client.tokenize([template])['input_ids'].shape[1]
2080
+ tokens_cumsum = np.cumsum(tokens)
2081
+ if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'max_input_tokens'):
2082
+ max_input_tokens = llm.pipeline.max_input_tokens
2083
+ elif inference_server in ['openai']:
2084
+ max_tokens = llm.modelname_to_contextsize(model_name)
2085
+ # leave some room for 1 paragraph, even if min_new_tokens=0
2086
+ max_input_tokens = max_tokens - 256
2087
+ elif inference_server in ['openai_chat']:
2088
+ max_tokens = model_token_mapping[model_name]
2089
+ # leave some room for 1 paragraph, even if min_new_tokens=0
2090
+ max_input_tokens = max_tokens - 256
2091
+ elif isinstance(tokenizer, FakeTokenizer):
2092
+ max_input_tokens = tokenizer.model_max_length - 256
2093
+ else:
2094
+ # leave some room for 1 paragraph, even if min_new_tokens=0
2095
+ max_input_tokens = 2048 - 256
2096
+ max_input_tokens -= template_tokens
2097
+ # FIXME: Doesn't account for query, == context, or new lines between contexts
2098
+ where_res = np.where(tokens_cumsum < max_input_tokens)[0]
2099
+ if where_res.shape[0] == 0:
2100
+ # then no chunk can fit, still do first one
2101
+ top_k_docs_trial = 1
2102
+ else:
2103
+ top_k_docs_trial = 1 + where_res[-1]
2104
+ if 0 < top_k_docs_trial < max_chunks:
2105
+ # avoid craziness
2106
+ if top_k_docs == -1:
2107
+ top_k_docs = top_k_docs_trial
2108
+ else:
2109
+ top_k_docs = min(top_k_docs, top_k_docs_trial)
2110
+ if top_k_docs == -1:
2111
+ # if here, means 0 and just do best with 1 doc
2112
+ print("Unexpected large chunks and can't add to context, will add 1 anyways", flush=True)
2113
+ top_k_docs = 1
2114
+ docs_with_score = docs_with_score[:top_k_docs]
2115
+ else:
2116
+ docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2117
+ # put the most relevant chunks closest to the question,
2118
+ # so if truncation occurs it is the "oldest" / "farthest from response" text that gets truncated
2119
+ # BUT: for small models (e.g. pythia 6.9b), if it sees some h2oGPT-related content first, it can latch onto that and ignore the rest
2120
+ if reverse_docs:
2121
+ docs_with_score.reverse()
2122
+ # cut off so no high distance docs/sources considered
2123
+ docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
2124
+ scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
2125
+ if len(scores) > 0 and verbose:
2126
+ print("Distance: min: %s max: %s mean: %s median: %s" %
2127
+ (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
2128
  else:
2129
  docs = []
2130
  scores = []
 
2133
  # if HF type and have no docs, can bail out
2134
  return docs, None, [], False
2135
 
2136
+ if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
2137
  # no LLM use
2138
  return docs, None, [], False
2139
 
 
2153
  if len(docs) == 0:
2154
  # avoid context == in prompt then
2155
  use_context = False
2156
+ template = template_if_no_docs
2157
 
2158
+ if use_template:
2159
  # instruct-like, rather than few-shot prompt_type='plain' as default
2160
  # but then sources confuse the model with how inserted among rest of text, so avoid
 
 
 
 
 
 
 
 
 
2161
  prompt = PromptTemplate(
2162
  # input_variables=["summaries", "question"],
2163
  input_variables=["context", "question"],
 
2217
  return ret, extra
2218
 
2219
 
2220
+ def clean_doc(docs1):
2221
+ if not isinstance(docs1, (list, tuple, types.GeneratorType)):
2222
+ docs1 = [docs1]
2223
+ for doci, doc in enumerate(docs1):
2224
+ docs1[doci].page_content = '\n'.join([x.strip() for x in doc.page_content.split("\n") if x.strip()])
2225
+ return docs1
2226
+
2227
+
2228
+ def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
2229
+ if not chunk:
2230
+ return sources
2231
+ if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
2232
+ # if just one document
2233
+ sources = [sources]
2234
+ if language and False:
2235
+ # Bug in langchain: keep_separator=True not working
2236
+ # https://github.com/hwchase17/langchain/issues/2836
2237
+ # so avoid this for now
2238
+ keep_separator = True
2239
+ separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
2240
+ else:
2241
+ separators = ["\n\n", "\n", " ", ""]
2242
+ keep_separator = False
2243
+ splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
2244
+ separators=separators)
2245
+ source_chunks = splitter.split_documents(sources)
2246
  return source_chunks
2247
 
2248
 
 
2253
  path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset')
2254
  import zipfile
2255
  with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
2256
+ persist_directory = os.path.dirname(zip_ref.namelist()[0])
2257
+ remove(persist_directory)
2258
  zip_ref.extractall(dest)
2259
  return path_to_zip_file
2260
 
 
2283
  assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
2284
 
2285
 
2286
+ def _create_local_weaviate_client():
2287
+ WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080")
2288
+ WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME')
2289
+ WEAVIATE_PASSWORD = os.getenv('WEAVIATE_PASSWORD')
2290
+ WEAVIATE_SCOPE = os.getenv('WEAVIATE_SCOPE', "offline_access")
2291
+
2292
+ resource_owner_config = None
2293
+ try:
2294
+ import weaviate
2295
+ if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None:
2296
+ resource_owner_config = weaviate.AuthClientPassword(
2297
+ username=WEAVIATE_USERNAME,
2298
+ password=WEAVIATE_PASSWORD,
2299
+ scope=WEAVIATE_SCOPE
2300
+ )
2301
+
2302
+ client = weaviate.Client(WEAVIATE_URL, auth_client_secret=resource_owner_config)
2303
+ return client
2304
+ except Exception as e:
2305
+ print(f"Failed to create Weaviate client: {e}")
2306
+ return None
2307
+
2308
+
2309
  if __name__ == '__main__':
2310
  pass
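For reference, a minimal sketch (hypothetical numbers) of the auto_reduce_chunks token budgeting in get_similarity_chain above: retrieved chunks are kept in relevance order until their cumulative token count would exceed the room left after the template and the reserved answer tokens.

import numpy as np

tokens = [300, 250, 400, 350]          # token counts of retrieved chunks, most relevant first
template_tokens = 68                   # tokens used by the prompt template
max_input_tokens = 1024 - 256          # context length minus room for the answer
max_input_tokens -= template_tokens    # 700 tokens left for document chunks
tokens_cumsum = np.cumsum(tokens)      # [300, 550, 950, 1300]
where_res = np.where(tokens_cumsum < max_input_tokens)[0]
top_k_docs = 1 if where_res.shape[0] == 0 else 1 + where_res[-1]   # -> 2 chunks fit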
gradio_runner.py CHANGED
The diff for this file is too large to render. See raw diff
 
gradio_themes.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
  from typing import Iterable
4
 
5
  from gradio.themes.soft import Soft
6
- from gradio.themes import Color
7
  from gradio.themes.utils import colors, sizes, fonts
8
 
9
  h2o_yellow = Color(
@@ -36,6 +36,42 @@ h2o_gray = Color(
36
  )
37
 
38
 
 
 
39
  class H2oTheme(Soft):
40
  def __init__(
41
  self,
@@ -158,19 +194,23 @@ h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/
158
  '11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
159
 
160
 
161
- def get_h2o_title(title):
162
- return f"""<div style="display:flex; justify-content:center; margin-bottom:30px;">
 
 
 
 
163
  <div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
164
  <h1 style="line-height:60px">{title}</h1>
165
  </div>
166
  <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
167
- <img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png></img>
168
  </div>
169
  """
170
 
171
 
172
- def get_simple_title(title):
173
- return f"""<h1 align="center"> {title}</h1>"""
174
 
175
 
176
  def get_dark_js():
 
3
  from typing import Iterable
4
 
5
  from gradio.themes.soft import Soft
6
+ from gradio.themes import Color, Size
7
  from gradio.themes.utils import colors, sizes, fonts
8
 
9
  h2o_yellow = Color(
 
36
  )
37
 
38
 
39
+ text_xsm = Size(
40
+ name="text_xsm",
41
+ xxs="4px",
42
+ xs="5px",
43
+ sm="6px",
44
+ md="7px",
45
+ lg="8px",
46
+ xl="10px",
47
+ xxl="12px",
48
+ )
49
+
50
+
51
+ spacing_xsm = Size(
52
+ name="spacing_xsm",
53
+ xxs="1px",
54
+ xs="1px",
55
+ sm="1px",
56
+ md="2px",
57
+ lg="3px",
58
+ xl="5px",
59
+ xxl="7px",
60
+ )
61
+
62
+
63
+ radius_xsm = Size(
64
+ name="radius_xsm",
65
+ xxs="1px",
66
+ xs="1px",
67
+ sm="1px",
68
+ md="2px",
69
+ lg="3px",
70
+ xl="5px",
71
+ xxl="7px",
72
+ )
73
+
74
+
75
  class H2oTheme(Soft):
76
  def __init__(
77
  self,
 
194
  '11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
195
 
196
 
197
+ def get_h2o_title(title, description):
198
+ # NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc.
199
+ return f"""<div style="float:left; justify-content:left; height: 80px; width: 195px; margin-top:0px">
200
+ {description}
201
+ </div>
202
+ <div style="display:flex; justify-content:center; margin-bottom:30px; margin-right:330px;">
203
  <div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
204
  <h1 style="line-height:60px">{title}</h1>
205
  </div>
206
  <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
207
+ <img src="https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png">
208
  </div>
209
  """
210
 
211
 
212
+ def get_simple_title(title, description):
213
+ return f"""{description}<h1 align="center"> {title}</h1>"""
214
 
215
 
216
  def get_dark_js():
gradio_utils/__pycache__/css.cpython-310.pyc ADDED
Binary file (1.53 kB). View file
 
gradio_utils/__pycache__/grclient.cpython-310.pyc ADDED
Binary file (2.69 kB). View file
 
gradio_utils/__pycache__/prompt_form.cpython-310.pyc ADDED
Binary file (3.59 kB). View file
 
gradio_utils/css.py ADDED
@@ -0,0 +1,53 @@
1
+ def get_css(kwargs) -> str:
2
+ if kwargs['h2ocolors']:
3
+ css_code = """footer {visibility: hidden;}
4
+ body{background:linear-gradient(#f5f5f5,#e5e5e5);}
5
+ body.dark{background:linear-gradient(#000000,#0d0d0d);}
6
+ """
7
+ else:
8
+ css_code = """footer {visibility: hidden}"""
9
+
10
+ css_code += make_css_base()
11
+ return css_code
12
+
13
+
14
+ def make_css_base() -> str:
15
+ return """
16
+ @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
17
+
18
+ body.dark{#warning {background-color: #555555};}
19
+
20
+ #small_btn {
21
+ margin: 0.6em 0em 0.55em 0;
22
+ max-width: 20em;
23
+ min-width: 5em !important;
24
+ height: 5em;
25
+ font-size: 14px !important;
26
+ }
27
+
28
+ #prompt-form {
29
+ border: 1px solid var(--primary-500) !important;
30
+ }
31
+
32
+ #prompt-form.block {
33
+ border-radius: var(--block-radius) !important;
34
+ }
35
+
36
+ #prompt-form textarea {
37
+ border: 1px solid rgb(209, 213, 219);
38
+ }
39
+
40
+ #prompt-form label > div {
41
+ margin-top: 4px;
42
+ }
43
+
44
+ button.primary:hover {
45
+ background-color: var(--primary-600) !important;
46
+ transition: .2s;
47
+ }
48
+
49
+ #prompt-form-area {
50
+ margin-bottom: 2.5rem;
51
+ }
52
+ .chatsmall chatbot {font-size: 10px !important}
53
+ """
gradio_utils/grclient.py ADDED
@@ -0,0 +1,82 @@
1
+ import traceback
2
+ from typing import Callable
3
+ import os
4
+
5
+ from gradio_client.client import Job
6
+
7
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
8
+
9
+ from gradio_client import Client
10
+
11
+
12
+ class GradioClient(Client):
13
+ """
14
+ Subclass of gradio Client
15
+ Handles automatically refreshing the client if the gradio server is detected to have changed
16
+ """
17
+
18
+ def __init__(self, *args, **kwargs):
19
+ self.args = args
20
+ self.kwargs = kwargs
21
+ super().__init__(*args, **kwargs)
22
+ self.server_hash = self.get_server_hash()
23
+
24
+ def get_server_hash(self):
25
+ """
26
+ Get server hash using super without any refresh action triggered
27
+ Returns: git hash of gradio server
28
+ """
29
+ return super().submit(api_name='/system_hash').result()
30
+
31
+ def refresh_client_if_should(self):
32
+ # get current hash in order to update api_name -> fn_index map in case gradio server changed
33
+ # FIXME: Could add cli api as hash
34
+ server_hash = self.get_server_hash()
35
+ if self.server_hash != server_hash:
36
+ self.refresh_client()
37
+ self.server_hash = server_hash
38
+ else:
39
+ self.reset_session()
40
+
41
+ def refresh_client(self):
42
+ """
43
+ Ensure every client call is independent
44
+ Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
45
+ Returns:
46
+ """
47
+ # need session hash to be new every time, to avoid "generator already executing"
48
+ self.reset_session()
49
+
50
+ client = Client(*self.args, **self.kwargs)
51
+ for k, v in client.__dict__.items():
52
+ setattr(self, k, v)
53
+
54
+ def submit(
55
+ self,
56
+ *args,
57
+ api_name: str | None = None,
58
+ fn_index: int | None = None,
59
+ result_callbacks: Callable | list[Callable] | None = None,
60
+ ) -> Job:
61
+ # Note predict calls submit
62
+ try:
63
+ self.refresh_client_if_should()
64
+ job = super().submit(*args, api_name=api_name, fn_index=fn_index)
65
+ except Exception as e:
66
+ print("Hit e=%s" % str(e), flush=True)
67
+ # force reconfig in case only that
68
+ self.refresh_client()
69
+ job = super().submit(*args, api_name=api_name, fn_index=fn_index)
70
+
71
+ # see if immediately failed
72
+ e = job.future._exception
73
+ if e is not None:
74
+ print("GR job failed: %s %s" % (str(e), ''.join(traceback.format_tb(e.__traceback__))), flush=True)
75
+ # force reconfig in case only that
76
+ self.refresh_client()
77
+ job = super().submit(*args, api_name=api_name, fn_index=fn_index)
78
+ e2 = job.future._exception
79
+ if e2 is not None:
80
+ print("GR job failed again: %s\n%s" % (str(e2), ''.join(traceback.format_tb(e2.__traceback__))), flush=True)
81
+
82
+ return job
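A minimal usage sketch of GradioClient (the URL, prompt, and api_name below are placeholders, not defined by this file):

from gradio_utils.grclient import GradioClient

client = GradioClient("http://localhost:7860")                    # hypothetical h2oGPT gradio server
job = client.submit("Who are you?", api_name="/submit_nochat")    # api_name depends on the server's exposed endpoints
print(job.result())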
gradio_utils/prompt_form.py ADDED
@@ -0,0 +1,118 @@
1
+ import os
2
+ import math
3
+
4
+ import gradio as gr
5
+
6
+
7
+ def make_chatbots(output_label0, output_label0_model2, **kwargs):
8
+ text_outputs = []
9
+ chat_kwargs = []
10
+ for model_state_lock in kwargs['model_states']:
11
+ if os.environ.get('DEBUG_MODEL_LOCK'):
12
+ model_name = model_state_lock["base_model"] + " : " + model_state_lock["inference_server"]
13
+ else:
14
+ model_name = model_state_lock["base_model"]
15
+ output_label = f'h2oGPT [{model_name}]'
16
+ min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160
17
+ chat_kwargs.append(dict(label=output_label, visible=kwargs['model_lock'], elem_classes='chatsmall',
18
+ height=kwargs['height'] or 400, min_width=min_width))
19
+
20
+ if kwargs['model_lock_columns'] == -1:
21
+ kwargs['model_lock_columns'] = len(kwargs['model_states'])
22
+ if kwargs['model_lock_columns'] is None:
23
+ kwargs['model_lock_columns'] = 3
24
+
25
+ ncols = kwargs['model_lock_columns']
26
+ if kwargs['model_states'] == 0:
27
+ nrows = 0
28
+ else:
29
+ nrows = math.ceil(len(kwargs['model_states']) / kwargs['model_lock_columns'])
30
+
31
+ if kwargs['model_lock_columns'] == 0:
32
+ # not using model_lock
33
+ pass
34
+ elif nrows <= 1:
35
+ with gr.Row():
36
+ for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
37
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
38
+ elif nrows == kwargs['model_states']:
39
+ with gr.Row():
40
+ for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
41
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
42
+ elif nrows == 2:
43
+ with gr.Row():
44
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
45
+ if mii >= len(kwargs['model_states']) / 2:
46
+ continue
47
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
48
+ with gr.Row():
49
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
50
+ if mii < len(kwargs['model_states']) / 2:
51
+ continue
52
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
53
+ elif nrows == 3:
54
+ with gr.Row():
55
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
56
+ if mii >= 1 * len(kwargs['model_states']) / 3:
57
+ continue
58
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
59
+ with gr.Row():
60
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
61
+ if mii < 1 * len(kwargs['model_states']) / 3 or mii >= 2 * len(kwargs['model_states']) / 3:
62
+ continue
63
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
64
+ with gr.Row():
65
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
66
+ if mii < 2 * len(kwargs['model_states']) / 3:
67
+ continue
68
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
69
+ elif nrows >= 4:
70
+ with gr.Row():
71
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
72
+ if mii >= 1 * len(kwargs['model_states']) / 4:
73
+ continue
74
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
75
+ with gr.Row():
76
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
77
+ if mii < 1 * len(kwargs['model_states']) / 4 or mii >= 2 * len(kwargs['model_states']) / 4:
78
+ continue
79
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
80
+ with gr.Row():
81
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
82
+ if mii < 2 * len(kwargs['model_states']) / 4 or mii >= 3 * len(kwargs['model_states']) / 4:
83
+ continue
84
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
85
+ with gr.Row():
86
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
87
+ if mii < 3 * len(kwargs['model_states']) / 4:
88
+ continue
89
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
90
+
91
+ with gr.Row():
92
+ text_output = gr.Chatbot(label=output_label0, visible=not kwargs['model_lock'], height=kwargs['height'] or 400)
93
+ text_output2 = gr.Chatbot(label=output_label0_model2,
94
+ visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
95
+ return text_output, text_output2, text_outputs
96
+
97
+
98
+ def make_prompt_form(kwargs):
99
+ if kwargs['input_lines'] > 1:
100
+ instruction_label = "Shift-Enter to Submit, Enter for more lines"
101
+ else:
102
+ instruction_label = "Enter to Submit, Shift-Enter for more lines"
103
+
104
+ with gr.Row():  # elem_id='prompt-form-area'
105
+ with gr.Column(scale=50):
106
+ instruction = gr.Textbox(
107
+ lines=kwargs['input_lines'],
108
+ label='Ask anything',
109
+ placeholder=instruction_label,
110
+ info=None,
111
+ elem_id='prompt-form',
112
+ container=True,
113
+ )
114
+ with gr.Row():
115
+ submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm')
116
+ stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm')
117
+
118
+ return instruction, submit, stop_btn
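The chatbot grid in make_chatbots is filled row by row; a small sketch of the layout arithmetic (counts are hypothetical):

import math

n_models = 5
ncols = 3                              # model_lock_columns
nrows = math.ceil(n_models / ncols)    # 2 rows
# nrows == 2 branch: row 1 keeps indices mii <  n_models / 2  -> 0, 1, 2
#                    row 2 keeps indices mii >= n_models / 2  -> 3, 4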
h2oai_pipeline.py CHANGED
@@ -1,14 +1,17 @@
 
 
1
  from transformers import TextGenerationPipeline
2
  from transformers.pipelines.text_generation import ReturnType
3
 
4
  from stopping import get_stopping
5
- from prompter import Prompter
6
 
7
 
8
  class H2OTextGenerationPipeline(TextGenerationPipeline):
9
  def __init__(self, *args, debug=False, chat=False, stream_output=False,
10
- sanitize_bot_response=True,
11
- use_prompter=True, prompter=None, prompt_type=None,
 
12
  max_input_tokens=2048 - 256, **kwargs):
13
  """
14
  HF-like pipeline, but handle instruction prompting and stopping (for some models)
@@ -21,6 +24,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
21
  :param prompter: prompter, can pass if have already
22
  :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
23
  If use_prompter, then will make prompter and use it.
 
24
  :param max_input_tokens:
25
  :param kwargs:
26
  """
@@ -28,12 +32,14 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
28
  self.prompt_text = None
29
  self.use_prompter = use_prompter
30
  self.prompt_type = prompt_type
 
31
  self.prompter = prompter
32
  if self.use_prompter:
33
  if self.prompter is not None:
34
  assert self.prompter.prompt_type is not None
35
  else:
36
- self.prompter = Prompter(self.prompt_type, debug=debug, chat=chat, stream_output=stream_output)
 
37
  self.human = self.prompter.humanstr
38
  self.bot = self.prompter.botstr
39
  self.can_stop = True
@@ -45,14 +51,75 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
45
  self.sanitize_bot_response = sanitize_bot_response
46
  self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
47
 
 
 
48
  def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
 
 
49
  data_point = dict(context='', instruction=prompt_text, input='')
50
  if self.prompter is not None:
51
  prompt_text = self.prompter.generate_prompt(data_point)
52
  self.prompt_text = prompt_text
53
  if handle_long_generation is None:
54
  # forces truncation of inputs to avoid critical failure
55
- handle_long_generation = 'hole'
56
  return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
57
  **generate_kwargs)
58
 
@@ -65,7 +132,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
65
  outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
66
  sanitize_bot_response=self.sanitize_bot_response)
67
  elif self.bot and self.human:
68
- outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
69
  else:
70
  outputs = rec['generated_text']
71
  rec['generated_text'] = outputs
@@ -73,8 +140,10 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
73
 
74
  def _forward(self, model_inputs, **generate_kwargs):
75
  if self.can_stop:
76
- stopping_criteria = get_stopping(self.prompt_type, self.tokenizer, self.device, human=self.human,
77
- bot=self.bot)
 
 
78
  generate_kwargs['stopping_criteria'] = stopping_criteria
79
  # return super()._forward(model_inputs, **generate_kwargs)
80
  return self.__forward(model_inputs, **generate_kwargs)
 
1
+ import os
2
+
3
  from transformers import TextGenerationPipeline
4
  from transformers.pipelines.text_generation import ReturnType
5
 
6
  from stopping import get_stopping
7
+ from prompter import Prompter, PromptType
8
 
9
 
10
  class H2OTextGenerationPipeline(TextGenerationPipeline):
11
  def __init__(self, *args, debug=False, chat=False, stream_output=False,
12
+ sanitize_bot_response=False,
13
+ use_prompter=True, prompter=None,
14
+ prompt_type=None, prompt_dict=None,
15
  max_input_tokens=2048 - 256, **kwargs):
16
  """
17
  HF-like pipeline, but handle instruction prompting and stopping (for some models)
 
24
  :param prompter: prompter, can pass if have already
25
  :param prompt_type: prompt_type, e.g. human_bot. See the prompt_type to model mapping in prompter.py.
26
  If use_prompter, then will make prompter and use it.
27
+ :param prompt_dict: dict from get_prompt(..., return_dict=True) for prompt_type=custom
28
  :param max_input_tokens:
29
  :param kwargs:
30
  """
 
32
  self.prompt_text = None
33
  self.use_prompter = use_prompter
34
  self.prompt_type = prompt_type
35
+ self.prompt_dict = prompt_dict
36
  self.prompter = prompter
37
  if self.use_prompter:
38
  if self.prompter is not None:
39
  assert self.prompter.prompt_type is not None
40
  else:
41
+ self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat,
42
+ stream_output=stream_output)
43
  self.human = self.prompter.humanstr
44
  self.bot = self.prompter.botstr
45
  self.can_stop = True
 
51
  self.sanitize_bot_response = sanitize_bot_response
52
  self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
53
 
54
+ @staticmethod
55
+ def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
56
+ verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
57
+
58
+ if hasattr(tokenizer, 'model_max_length'):
59
+ # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
60
+ model_max_length = tokenizer.model_max_length
61
+ if max_prompt_length is not None:
62
+ model_max_length = min(model_max_length, max_prompt_length)
63
+ # cut at some upper likely limit to avoid excessive tokenization etc
64
+ # upper bound of 10 chars/token, e.g. special chars sometimes are long
65
+ if len(prompt_text) > model_max_length * 10:
66
+ len0 = len(prompt_text)
67
+ prompt_text = prompt_text[-model_max_length * 10:]
68
+ if verbose:
69
+ print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
70
+ else:
71
+ # unknown
72
+ model_max_length = None
73
+
74
+ num_prompt_tokens = None
75
+ if model_max_length is not None:
76
+ # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
77
+ # For https://github.com/h2oai/h2ogpt/issues/192
78
+ for trial in range(0, 3):
79
+ prompt_tokens = tokenizer(prompt_text)['input_ids']
80
+ num_prompt_tokens = len(prompt_tokens)
81
+ if num_prompt_tokens > model_max_length:
82
+ # conservative by using int()
83
+ chars_per_token = int(len(prompt_text) / num_prompt_tokens)
84
+ # keep tail, where question is if using langchain
85
+ prompt_text = prompt_text[-model_max_length * chars_per_token:]
86
+ if verbose:
87
+ print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
88
+ num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
89
+ else:
90
+ if verbose:
91
+ print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
92
+ break
93
+
94
+ # Why the block below is False: don't limit max_new_tokens further, just rely upon stopping to reach the model's limit
95
+ if False:
96
+ # if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
97
+ #
98
+ assert num_prompt_tokens is not None
99
+ if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
100
+ # then give room for prompt
101
+ fudge = 20
102
+ else:
103
+ fudge = 0
104
+ max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
105
+ model_max_length - (num_prompt_tokens + fudge)))
106
+ if max_new_tokens < generate_kwargs['max_new_tokens']:
107
+ if verbose:
108
+ print("Reduced max_new_tokens from %s -> %s" % (
109
+ generate_kwargs['max_new_tokens'], max_new_tokens))
110
+ generate_kwargs['max_new_tokens'] = max_new_tokens
111
+ return prompt_text, num_prompt_tokens
112
+
113
  def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
114
+ prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
115
+
116
  data_point = dict(context='', instruction=prompt_text, input='')
117
  if self.prompter is not None:
118
  prompt_text = self.prompter.generate_prompt(data_point)
119
  self.prompt_text = prompt_text
120
  if handle_long_generation is None:
121
  # forces truncation of inputs to avoid critical failure
122
+ handle_long_generation = None # disable with new approaches
123
  return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
124
  **generate_kwargs)
125
 
 
132
  outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
133
  sanitize_bot_response=self.sanitize_bot_response)
134
  elif self.bot and self.human:
135
+ outputs = rec['generated_text'].split(self.bot)[1].split(self.human)[0]
136
  else:
137
  outputs = rec['generated_text']
138
  rec['generated_text'] = outputs
 
140
 
141
  def _forward(self, model_inputs, **generate_kwargs):
142
  if self.can_stop:
143
+ stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
144
+ self.tokenizer, self.device,
145
+ human=self.human, bot=self.bot,
146
+ model_max_length=self.tokenizer.model_max_length)
147
  generate_kwargs['stopping_criteria'] = stopping_criteria
148
  # return super()._forward(model_inputs, **generate_kwargs)
149
  return self.__forward(model_inputs, **generate_kwargs)
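A worked example (hypothetical numbers) of the character-based truncation in limit_prompt above:

prompt_text = "x" * 12000                 # hypothetical prompt: 12000 characters
model_max_length = 2048
num_prompt_tokens = 3000                  # as returned by the tokenizer (assumed here)
chars_per_token = int(len(prompt_text) / num_prompt_tokens)        # 4 chars/token, conservative via int()
prompt_text = prompt_text[-model_max_length * chars_per_token:]    # keep the last 8192 characters (tail holds the question)
# limit_prompt then re-tokenizes and repeats, up to 3 trials, until the prompt fits model_max_length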
loaders.py CHANGED
@@ -1,6 +1,8 @@
1
- def get_loaders(llama_type, model_name, reward_type):
2
  # NOTE: Some models need specific new prompt_type
3
  # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
 
 
4
  if llama_type:
5
  from transformers import LlamaForCausalLM, LlamaTokenizer
6
  model_loader = LlamaForCausalLM
@@ -39,7 +41,8 @@ def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resu
39
  tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
40
  local_files_only=local_files_only,
41
  resume_download=resume_download,
42
- use_auth_token=use_auth_token)
 
43
 
44
  tokenizer.pad_token_id = 0 # different from the eos token
45
  # when generating, we will use the logits of right-most token to predict the next token
 
1
+ def get_loaders(model_name, reward_type, llama_type=None):
2
  # NOTE: Some models need specific new prompt_type
3
  # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
4
+ if llama_type is None:
5
+ llama_type = "llama" in model_name.lower()
6
  if llama_type:
7
  from transformers import LlamaForCausalLM, LlamaTokenizer
8
  model_loader = LlamaForCausalLM
 
41
  tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
42
  local_files_only=local_files_only,
43
  resume_download=resume_download,
44
+ use_auth_token=use_auth_token,
45
+ padding_side='left')
46
 
47
  tokenizer.pad_token_id = 0 # different from the eos token
48
  # when generating, we will use the logits of right-most token to predict the next token
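A small usage sketch of the updated get_loaders signature (the model name is an example; the return value is assumed to be the model/tokenizer loader pair):

loaders = get_loaders('decapoda-research/llama-7b-hf', reward_type=False)
# llama_type is omitted, so it is inferred from the name: "llama" in model_name.lower() -> True,
# which selects LlamaForCausalLM / LlamaTokenizer; pass llama_type=False to override the heuristic.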
prompter.py CHANGED
@@ -1,30 +1,10 @@
 
 
1
  import time
2
- from enum import Enum
3
 
4
  non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
5
 
6
-
7
- class PromptType(Enum):
8
- plain = 0
9
- instruct = 1
10
- quality = 2
11
- human_bot = 3
12
- dai_faq = 4
13
- summarize = 5
14
- simple_instruct = 6
15
- instruct_vicuna = 7
16
- instruct_with_end = 8
17
- human_bot_orig = 9
18
- prompt_answer = 10
19
- open_assistant = 11
20
- wizard_lm = 12
21
- wizard_mega = 13
22
- instruct_vicuna2 = 14
23
- instruct_vicuna3 = 15
24
- wizard2 = 16
25
- wizard3 = 17
26
-
27
-
28
  prompt_type_to_model_name = {
29
  'plain': [
30
  'EleutherAI/gpt-j-6B',
@@ -45,17 +25,29 @@ prompt_type_to_model_name = {
45
  'mosaicml/mpt-7b-storywriter',
46
  'mosaicml/mpt-7b-instruct', # internal code handles instruct
47
  'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
48
- 'gptj', # internally handles prompting
49
- 'llama', # plain, or need to choose prompt_type for given TheBloke model
50
- 'gpt4all_llama', # internally handles prompting
51
  ],
 
52
  'prompt_answer': [
53
  'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
54
  'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
55
  'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
 
 
 
 
 
 
 
 
 
 
 
56
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
57
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
58
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
 
 
59
  ],
60
  'instruct': [],
61
  'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -68,6 +60,7 @@ prompt_type_to_model_name = {
68
  'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
69
  'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
70
  'h2oai/h2ogpt-research-oasst1-512-30b',
 
71
  'h2oai/h2ogpt-oasst1-falcon-40b',
72
  'h2oai/h2ogpt-oig-oasst1-falcon-40b',
73
  ],
@@ -79,7 +72,17 @@ prompt_type_to_model_name = {
79
  "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
80
  "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
81
  "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
 
 
 
 
 
82
  }
 
 
 
 
 
83
 
84
  inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
85
  inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
@@ -93,20 +96,53 @@ for p in PromptType:
93
  prompt_types.extend([p.name, p.value, str(p.value)])
94
 
95
 
96
- def get_prompt(prompt_type, chat, context, reduced):
97
- if prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
98
- PromptType.plain.name]:
99
- promptA = promptB = PreInstruct = PreInput = PreResponse = ''
100
- terminate_response = []
 
 
 
 
 
 
 
 
 
 
 
101
  chat_sep = ''
 
102
  humanstr = ''
103
  botstr = ''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  elif prompt_type == 'simple_instruct':
105
  promptA = promptB = PreInstruct = PreInput = PreResponse = None
106
  terminate_response = []
107
- chat_sep = '\n'
108
- humanstr = ''
109
- botstr = ''
110
  elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
111
  PromptType.instruct.name] + [PromptType.instruct_with_end.value,
112
  str(PromptType.instruct_with_end.value),
@@ -132,7 +168,7 @@ def get_prompt(prompt_type, chat, context, reduced):
132
  terminate_response = ['### End']
133
  else:
134
  terminate_response = None
135
- chat_sep = '\n'
136
  humanstr = PreInstruct
137
  botstr = PreResponse
138
  elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
@@ -154,7 +190,7 @@ def get_prompt(prompt_type, chat, context, reduced):
154
  ### Response:
155
  """
156
  terminate_response = None
157
- chat_sep = '\n'
158
  humanstr = PreInstruct # first thing human says
159
  botstr = PreResponse # first thing bot says
160
  elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
@@ -176,14 +212,14 @@ Current Time: {}
176
 
177
  """
178
  preprompt = PRE_PROMPT.format(cur_date, cur_time)
179
- start = human
180
- promptB = promptA = '%s%s ' % (preprompt, start)
181
 
182
- PreInstruct = ""
183
 
184
  PreInput = None
185
 
186
- if reduced:
187
  # when making context, want it to appear as-if LLM generated, which starts with space after :
188
  PreResponse = bot + ' '
189
  else:
@@ -191,10 +227,11 @@ Current Time: {}
191
  # if add space here, non-unique tokenization will often make LLM produce wrong output
192
  PreResponse = bot
193
 
194
- terminate_response = [start, PreResponse]
195
- chat_sep = '\n'
196
  humanstr = human # tag before human talks
197
  botstr = bot # tag before bot talks
 
198
  elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
199
  PromptType.dai_faq.name]:
200
  promptA = ''
@@ -210,7 +247,7 @@ Current Time: {}
210
  ### Driverless AI documentation answer:
211
  """
212
  terminate_response = ['\n\n']
213
- chat_sep = terminate_response
214
  humanstr = PreInstruct
215
  botstr = PreResponse
216
  elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
@@ -219,7 +256,7 @@ Current Time: {}
219
  PreInstruct = '## Main Text\n\n'
220
  PreResponse = '\n\n## Summary\n\n'
221
  terminate_response = None
222
- chat_sep = '\n'
223
  humanstr = PreInstruct
224
  botstr = PreResponse
225
  elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
@@ -239,7 +276,7 @@ Current Time: {}
239
  """
240
  terminate_response = [
241
  '### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
242
- chat_sep = '\n'
243
  humanstr = PreInstruct
244
  botstr = PreResponse
245
  elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
@@ -247,33 +284,50 @@ Current Time: {}
247
  preprompt = ''
248
  prompt_tokens = "<|prompt|>"
249
  answer_tokens = "<|answer|>"
250
- start = prompt_tokens
251
  promptB = promptA = '%s%s' % (preprompt, start)
252
- PreInstruct = ""
253
  PreInput = None
254
  PreResponse = answer_tokens
255
  eos = '<|endoftext|>' # neox eos
256
- terminate_response = [start, PreResponse, eos]
257
- chat_sep = eos
258
  humanstr = prompt_tokens
259
  botstr = answer_tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
261
  PromptType.open_assistant.name]:
262
  # From added_tokens.json
263
  preprompt = ''
264
  prompt_tokens = "<|prompter|>"
265
  answer_tokens = "<|assistant|>"
266
- start = prompt_tokens
267
  promptB = promptA = '%s%s' % (preprompt, start)
268
- PreInstruct = ""
269
  PreInput = None
270
  PreResponse = answer_tokens
271
  pend = "<|prefix_end|>"
272
  eos = "</s>"
273
- terminate_response = [start, PreResponse, pend, eos]
274
- chat_sep = eos
275
  humanstr = prompt_tokens
276
  botstr = answer_tokens
 
 
277
  elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
278
  PromptType.wizard_lm.name]:
279
  # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
@@ -285,7 +339,7 @@ Current Time: {}
285
  PreResponse = "\n\n### Response\n"
286
  eos = "</s>"
287
  terminate_response = [PreResponse, eos]
288
- chat_sep = eos
289
  humanstr = promptA
290
  botstr = PreResponse
291
  elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
@@ -301,13 +355,12 @@ Current Time: {}
301
  ### Assistant:
302
  """
303
  terminate_response = [PreResponse]
304
- chat_sep = '\n'
305
  humanstr = PreInstruct
306
  botstr = PreResponse
307
  elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
308
  PromptType.instruct_vicuna2.name]:
309
- promptA = promptB = "" if not (
310
- chat and reduced) else ''
311
 
312
  PreInstruct = """
313
  HUMAN:
@@ -320,13 +373,12 @@ ASSISTANT:
320
  """
321
  terminate_response = [
322
  'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
323
- chat_sep = '\n'
324
  humanstr = PreInstruct
325
  botstr = PreResponse
326
  elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
327
  PromptType.instruct_vicuna3.name]:
328
- promptA = promptB = "" if not (
329
- chat and reduced) else ''
330
 
331
  PreInstruct = """
332
  ### User:
@@ -339,13 +391,14 @@ ASSISTANT:
339
  """
340
  terminate_response = [
341
  '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
342
- chat_sep = '\n'
343
  humanstr = PreInstruct
344
  botstr = PreResponse
345
  elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
346
  PromptType.wizard2.name]:
347
  # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
348
- preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
 
349
  start = ''
350
  promptB = promptA = '%s%s' % (preprompt, start)
351
  PreInstruct = """
@@ -356,30 +409,136 @@ ASSISTANT:
356
  ### Response:
357
  """
358
  terminate_response = [PreResponse]
359
- chat_sep = '\n'
360
  humanstr = PreInstruct
361
  botstr = PreResponse
362
  elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
363
  PromptType.wizard3.name]:
364
  # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
365
- preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
 
366
  start = ''
367
  promptB = promptA = '%s%s' % (preprompt, start)
368
  PreInstruct = """USER: """
369
  PreInput = None
370
  PreResponse = """ASSISTANT: """
371
  terminate_response = [PreResponse]
372
- chat_sep = '\n'
373
  humanstr = PreInstruct
374
  botstr = PreResponse
375
376
  else:
377
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
378
 
379
- return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
380
381
 
382
- def generate_prompt(data_point, prompt_type, chat, reduced):
 
383
  context = data_point.get('context')
384
  if context is None:
385
  context = ''
@@ -387,11 +546,15 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
387
  input = data_point.get('input')
388
  output = data_point.get('output')
389
  prompt_type = data_point.get('prompt_type', prompt_type)
 
390
  assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
391
  promptA, promptB, PreInstruct, PreInput, PreResponse, \
392
- terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, chat, context, reduced)
 
 
393
 
394
- prompt = context if not reduced else ''
 
395
 
396
  if input and promptA:
397
  prompt += f"""{promptA}"""
@@ -400,37 +563,37 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
400
 
401
  if instruction and PreInstruct is not None and input and PreInput is not None:
402
  prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
403
- prompt = inject_newline(prompt_type, prompt)
404
  elif instruction and input and PreInstruct is None and PreInput is not None:
405
  prompt += f"""{PreInput}{instruction}
406
  {input}"""
407
- prompt = inject_newline(prompt_type, prompt)
408
  elif input and instruction and PreInput is None and PreInstruct is not None:
409
  prompt += f"""{PreInstruct}{instruction}
410
  {input}"""
411
- prompt = inject_newline(prompt_type, prompt)
412
  elif instruction and PreInstruct is not None:
413
  prompt += f"""{PreInstruct}{instruction}"""
414
- prompt = inject_newline(prompt_type, prompt)
415
  elif input and PreInput is not None:
416
  prompt += f"""{PreInput}{input}"""
417
- prompt = inject_newline(prompt_type, prompt)
418
  elif input and instruction and PreInput is not None:
419
  prompt += f"""{PreInput}{instruction}{input}"""
420
- prompt = inject_newline(prompt_type, prompt)
421
  elif input and instruction and PreInstruct is not None:
422
  prompt += f"""{PreInstruct}{instruction}{input}"""
423
- prompt = inject_newline(prompt_type, prompt)
424
  elif input and instruction:
425
  # i.e. for simple_instruct
426
  prompt += f"""{instruction}: {input}"""
427
- prompt = inject_newline(prompt_type, prompt)
428
  elif input:
429
  prompt += f"""{input}"""
430
- prompt = inject_newline(prompt_type, prompt)
431
  elif instruction:
432
  prompt += f"""{instruction}"""
433
- prompt = inject_newline(prompt_type, prompt)
434
 
435
  if PreResponse is not None:
436
  prompt += f"""{PreResponse}"""
@@ -441,23 +604,21 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
441
  if output:
442
  prompt += f"""{output}"""
443
 
444
- return prompt, pre_response, terminate_response, chat_sep
445
 
446
 
447
- def inject_newline(prompt_type, prompt):
448
- if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
449
  # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
450
- prompt += '\n'
451
  return prompt
452
 
453
 
454
  class Prompter(object):
455
- def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
456
  allowed_repeat_line_length=10):
457
  self.prompt_type = prompt_type
458
- data_point = dict(instruction='', input='', output='')
459
- _, self.pre_response, self.terminate_response, self.chat_sep = \
460
- generate_prompt(data_point, prompt_type, chat, False)
461
  self.debug = debug
462
  self.chat = chat
463
  self.stream_output = stream_output
@@ -466,23 +627,41 @@ class Prompter(object):
466
  self.prompt = None
467
  context = "" # not for chat context
468
  reduced = False # not for chat context
 
469
  self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
470
- self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
471
- get_prompt(prompt_type, chat, context, reduced)
472
-
473
- def generate_prompt(self, data_point):
474
- reduced = False
475
- prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
476
  if self.debug:
477
- print("prompt: ", prompt, flush=True)
 
 
 
 
 
 
 
478
  self.prompt = prompt
479
  return prompt
480
 
481
- def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
482
  if isinstance(outputs, str):
483
  outputs = [outputs]
484
  if self.debug:
485
- print("output:\n", '\n\n'.join(outputs), flush=True)
486
  if prompt is not None:
487
  self.prompt = prompt
488
 
@@ -493,7 +672,8 @@ class Prompter(object):
493
  if sanitize_bot_response:
494
  from better_profanity import profanity
495
  response = profanity.censor(response)
496
- response = response.strip("\n")
 
497
  return response
498
 
499
  def clean_repeats(response):
@@ -515,12 +695,12 @@ class Prompter(object):
515
  # then use most basic parsing like pipeline
516
  if self.botstr in output:
517
  if self.humanstr:
518
- output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
519
  else:
520
  # i.e. use after bot but only up to next bot
521
- output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
522
  else:
523
- # output = clean_response(output.strip())
524
  # assume just not printed yet
525
  output = ""
526
  else:
@@ -547,9 +727,9 @@ class Prompter(object):
547
  allow_terminate = True
548
  output = output[len(prompt):]
549
  # clean after subtract prompt out, so correct removal of pre_response
550
- output = clean_response(output).strip()
551
  if self.repeat_penalty:
552
- output = clean_repeats(output).strip()
553
  if self.terminate_response and allow_terminate:
554
  finds = []
555
  for term in self.terminate_response:
@@ -557,11 +737,9 @@ class Prompter(object):
557
  finds = [x for x in finds if x >= 0]
558
  if len(finds) > 0:
559
  termi = finds[0]
560
- output = output[:termi].strip()
561
  else:
562
- output = output.strip()
563
- else:
564
- output = output.strip()
565
  if multi_output:
566
  # prefix with output counter
567
  output = "\n=========== Output %d\n\n" % (1 + oi) + output
@@ -572,5 +750,5 @@ class Prompter(object):
572
  # join all outputs, only one extra new line between outputs
573
  output = '\n'.join(outputs)
574
  if self.debug:
575
- print("outputclean:\n", '\n\n'.join(outputs), flush=True)
576
  return output
 
1
+ import os
2
+ import ast
3
  import time
4
+ from enums import PromptType # also supports imports from this file from other files
5
 
6
  non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
7

8
  prompt_type_to_model_name = {
9
  'plain': [
10
  'EleutherAI/gpt-j-6B',
 
25
  'mosaicml/mpt-7b-storywriter',
26
  'mosaicml/mpt-7b-instruct', # internal code handles instruct
27
  'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
28
+ 'mosaicml/mpt-30b-instruct', # internal code handles instruct
 
 
29
  ],
30
+ 'gptj': ['gptj', 'gpt4all_llama'],
31
  'prompt_answer': [
32
  'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
33
  'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
34
  'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
35
+ 'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
36
+ 'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
37
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3',
38
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
39
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
40
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
41
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
42
+ 'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
43
+ 'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
44
+ ],
45
+ 'prompt_answer_openllama': [
46
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
47
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
48
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
49
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
50
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
51
  ],
52
  'instruct': [],
53
  'instruct_with_end': ['databricks/dolly-v2-12b'],
 
60
  'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
61
  'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
62
  'h2oai/h2ogpt-research-oasst1-512-30b',
63
+ 'h2oai/h2ogpt-research-oasst1-llama-65b',
64
  'h2oai/h2ogpt-oasst1-falcon-40b',
65
  'h2oai/h2ogpt-oig-oasst1-falcon-40b',
66
  ],
 
72
  "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
73
  "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
74
  "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
75
+ "instruct_simple": ['JosephusCheung/Guanaco'],
76
+ "wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
77
+ "wizard2": ['llama', 'mosaicml/mpt-30b-instruct'],
78
+ "vicuna11": ['lmsys/vicuna-33b-v1.3'],
79
+ # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
80
  }
81
+ if os.getenv('OPENAI_API_KEY'):
82
+ prompt_type_to_model_name.update({
83
+ "openai": ["text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001"],
84
+ "openai_chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
85
+ })
86
 
87
  inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
88
  inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
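
For reference, the two inverted maps above let a model name be looked up to find its default prompt_type. A minimal sketch (the 'plain' fallback here is an assumption for illustration, not the repo's actual fallback logic):

# model listed under 'human_bot' above, so the lookup returns 'human_bot'
model_name = 'h2oai/h2ogpt-research-oasst1-512-30b'
prompt_type_guess = inv_prompt_type_to_model_lower.get(model_name.strip().lower(), 'plain')
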
 
96
  prompt_types.extend([p.name, p.value, str(p.value)])
97
 
98
 
99
+ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context, return_dict=False):
100
+ prompt_dict_error = ''
101
+ generates_leading_space = False
102
+
103
+ if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
104
+ try:
105
+ prompt_dict = ast.literal_eval(prompt_dict)
106
+ except BaseException as e:
107
+ prompt_dict_error = str(e)
108
+ if prompt_dict_error:
109
+ promptA = None
110
+ promptB = None
111
+ PreInstruct = None
112
+ PreInput = ''
113
+ PreResponse = ''
114
+ terminate_response = None
115
  chat_sep = ''
116
+ chat_turn_sep = ''
117
  humanstr = ''
118
  botstr = ''
119
+ generates_leading_space = False
120
+ elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
121
+ PromptType.custom.name]:
122
+ promptA = prompt_dict.get('promptA', '')
123
+ promptB = prompt_dict.get('promptB', '')
124
+ PreInstruct = prompt_dict.get('PreInstruct', '')
125
+ PreInput = prompt_dict.get('PreInput', '')
126
+ PreResponse = prompt_dict.get('PreResponse', '')
127
+ terminate_response = prompt_dict.get('terminate_response', None)
128
+ chat_sep = prompt_dict.get('chat_sep', '\n')
129
+ chat_turn_sep = prompt_dict.get('chat_turn_sep', '\n')
130
+ humanstr = prompt_dict.get('humanstr', '')
131
+ botstr = prompt_dict.get('botstr', '')
132
+ elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
133
+ PromptType.plain.name]:
134
+ promptA = promptB = PreInstruct = PreInput = PreResponse = None
135
+ terminate_response = []
136
+ chat_turn_sep = chat_sep = ''
137
+ # plain should have None for human/bot, so nothing truncated out, not '' that would truncate after first token
138
+ humanstr = None
139
+ botstr = None
140
  elif prompt_type == 'simple_instruct':
141
  promptA = promptB = PreInstruct = PreInput = PreResponse = None
142
  terminate_response = []
143
+ chat_turn_sep = chat_sep = '\n'
144
+ humanstr = None
145
+ botstr = None
146
  elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
147
  PromptType.instruct.name] + [PromptType.instruct_with_end.value,
148
  str(PromptType.instruct_with_end.value),
 
168
  terminate_response = ['### End']
169
  else:
170
  terminate_response = None
171
+ chat_turn_sep = chat_sep = '\n'
172
  humanstr = PreInstruct
173
  botstr = PreResponse
174
  elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
 
190
  ### Response:
191
  """
192
  terminate_response = None
193
+ chat_turn_sep = chat_sep = '\n'
194
  humanstr = PreInstruct # first thing human says
195
  botstr = PreResponse # first thing bot says
196
  elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
 
212
 
213
  """
214
  preprompt = PRE_PROMPT.format(cur_date, cur_time)
215
+ start = ''
216
+ promptB = promptA = '%s%s' % (preprompt, start)
217
 
218
+ PreInstruct = human + ' '
219
 
220
  PreInput = None
221
 
222
+ if making_context:
223
  # when making context, want it to appear as-if LLM generated, which starts with space after :
224
  PreResponse = bot + ' '
225
  else:
 
227
  # if add space here, non-unique tokenization will often make LLM produce wrong output
228
  PreResponse = bot
229
 
230
+ terminate_response = ['\n' + human, '\n' + bot, human, bot, PreResponse]
231
+ chat_turn_sep = chat_sep = '\n'
232
  humanstr = human # tag before human talks
233
  botstr = bot # tag before bot talks
234
+ generates_leading_space = True
235
  elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
236
  PromptType.dai_faq.name]:
237
  promptA = ''
 
247
  ### Driverless AI documentation answer:
248
  """
249
  terminate_response = ['\n\n']
250
+ chat_turn_sep = chat_sep = terminate_response
251
  humanstr = PreInstruct
252
  botstr = PreResponse
253
  elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
 
256
  PreInstruct = '## Main Text\n\n'
257
  PreResponse = '\n\n## Summary\n\n'
258
  terminate_response = None
259
+ chat_turn_sep = chat_sep = '\n'
260
  humanstr = PreInstruct
261
  botstr = PreResponse
262
  elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
 
276
  """
277
  terminate_response = [
278
  '### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
279
+ chat_turn_sep = chat_sep = '\n'
280
  humanstr = PreInstruct
281
  botstr = PreResponse
282
  elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
 
284
  preprompt = ''
285
  prompt_tokens = "<|prompt|>"
286
  answer_tokens = "<|answer|>"
287
+ start = ''
288
  promptB = promptA = '%s%s' % (preprompt, start)
289
+ PreInstruct = prompt_tokens
290
  PreInput = None
291
  PreResponse = answer_tokens
292
  eos = '<|endoftext|>' # neox eos
 
 
293
  humanstr = prompt_tokens
294
  botstr = answer_tokens
295
+ terminate_response = [humanstr, PreResponse, eos]
296
+ chat_sep = ''
297
+ chat_turn_sep = eos
298
+ elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
299
+ PromptType.prompt_answer_openllama.name]:
300
+ preprompt = ''
301
+ prompt_tokens = "<|prompt|>"
302
+ answer_tokens = "<|answer|>"
303
+ start = ''
304
+ promptB = promptA = '%s%s' % (preprompt, start)
305
+ PreInstruct = prompt_tokens
306
+ PreInput = None
307
+ PreResponse = answer_tokens
308
+ eos = '</s>' # llama eos
309
+ humanstr = prompt_tokens
310
+ botstr = answer_tokens
311
+ terminate_response = [humanstr, PreResponse, eos]
312
+ chat_sep = ''
313
+ chat_turn_sep = eos
314
  elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
315
  PromptType.open_assistant.name]:
316
  # From added_tokens.json
317
  preprompt = ''
318
  prompt_tokens = "<|prompter|>"
319
  answer_tokens = "<|assistant|>"
320
+ start = ''
321
  promptB = promptA = '%s%s' % (preprompt, start)
322
+ PreInstruct = prompt_tokens
323
  PreInput = None
324
  PreResponse = answer_tokens
325
  pend = "<|prefix_end|>"
326
  eos = "</s>"
 
 
327
  humanstr = prompt_tokens
328
  botstr = answer_tokens
329
+ terminate_response = [humanstr, PreResponse, pend, eos]
330
+ chat_turn_sep = chat_sep = eos
331
  elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
332
  PromptType.wizard_lm.name]:
333
  # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
 
339
  PreResponse = "\n\n### Response\n"
340
  eos = "</s>"
341
  terminate_response = [PreResponse, eos]
342
+ chat_turn_sep = chat_sep = eos
343
  humanstr = promptA
344
  botstr = PreResponse
345
  elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
 
355
  ### Assistant:
356
  """
357
  terminate_response = [PreResponse]
358
+ chat_turn_sep = chat_sep = '\n'
359
  humanstr = PreInstruct
360
  botstr = PreResponse
361
  elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
362
  PromptType.instruct_vicuna2.name]:
363
+ promptA = promptB = "" if not (chat and reduced) else ''
 
364
 
365
  PreInstruct = """
366
  HUMAN:
 
373
  """
374
  terminate_response = [
375
  'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
376
+ chat_turn_sep = chat_sep = '\n'
377
  humanstr = PreInstruct
378
  botstr = PreResponse
379
  elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
380
  PromptType.instruct_vicuna3.name]:
381
+ promptA = promptB = "" if not (chat and reduced) else ''
 
382
 
383
  PreInstruct = """
384
  ### User:
 
391
  """
392
  terminate_response = [
393
  '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
394
+ chat_turn_sep = chat_sep = '\n'
395
  humanstr = PreInstruct
396
  botstr = PreResponse
397
  elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
398
  PromptType.wizard2.name]:
399
  # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
400
+ preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.""" if not (
401
+ chat and reduced) else ''
402
  start = ''
403
  promptB = promptA = '%s%s' % (preprompt, start)
404
  PreInstruct = """
 
409
  ### Response:
410
  """
411
  terminate_response = [PreResponse]
412
+ chat_turn_sep = chat_sep = '\n'
413
  humanstr = PreInstruct
414
  botstr = PreResponse
415
  elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
416
  PromptType.wizard3.name]:
417
  # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
418
+ preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" if not (
419
+ chat and reduced) else ''
420
  start = ''
421
  promptB = promptA = '%s%s' % (preprompt, start)
422
  PreInstruct = """USER: """
423
  PreInput = None
424
  PreResponse = """ASSISTANT: """
425
  terminate_response = [PreResponse]
426
+ chat_turn_sep = chat_sep = '\n'
427
  humanstr = PreInstruct
428
  botstr = PreResponse
429
+ elif prompt_type in [PromptType.wizard_vicuna.value, str(PromptType.wizard_vicuna.value),
430
+ PromptType.wizard_vicuna.name]:
431
+ preprompt = ''
432
+ start = ''
433
+ promptB = promptA = '%s%s' % (preprompt, start)
434
+ PreInstruct = """USER: """
435
+ PreInput = None
436
+ PreResponse = """ASSISTANT: """
437
+ terminate_response = [PreResponse]
438
+ chat_turn_sep = chat_sep = '\n'
439
+ humanstr = PreInstruct
440
+ botstr = PreResponse
441
+
442
+ elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
443
+ PromptType.instruct_simple.name]:
444
+ promptB = promptA = '' if not (chat and reduced) else ''
445
 
446
+ PreInstruct = """
447
+ ### Instruction:
448
+ """
449
+
450
+ PreInput = """
451
+ ### Input:
452
+ """
453
+
454
+ PreResponse = """
455
+ ### Response:
456
+ """
457
+ terminate_response = None
458
+ chat_turn_sep = chat_sep = '\n'
459
+ humanstr = PreInstruct
460
+ botstr = PreResponse
461
+ elif prompt_type in [PromptType.openai.value, str(PromptType.openai.value),
462
+ PromptType.openai.name]:
463
+ preprompt = """The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.""" if not (
464
+ chat and reduced) else ''
465
+ start = ''
466
+ promptB = promptA = '%s%s' % (preprompt, start)
467
+ PreInstruct = "\nHuman: "
468
+ PreInput = None
469
+ PreResponse = "\nAI:"
470
+ terminate_response = [PreResponse] + [" Human:", " AI:"]
471
+ chat_turn_sep = chat_sep = '\n'
472
+ humanstr = PreInstruct
473
+ botstr = PreResponse
474
+ elif prompt_type in [PromptType.gptj.value, str(PromptType.gptj.value),
475
+ PromptType.gptj.name]:
476
+ preprompt = "### Instruction:\n The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." if not (
477
+ chat and reduced) else ''
478
+ start = ''
479
+ promptB = promptA = '%s%s' % (preprompt, start)
480
+ PreInstruct = "\n### Prompt: "
481
+ PreInput = None
482
+ PreResponse = "\n### Response: "
483
+ terminate_response = [PreResponse] + ["Prompt:", "Response:"]
484
+ chat_turn_sep = chat_sep = '\n'
485
+ humanstr = PreInstruct
486
+ botstr = PreResponse
487
+ elif prompt_type in [PromptType.openai_chat.value, str(PromptType.openai_chat.value),
488
+ PromptType.openai_chat.name]:
489
+ # prompting and termination all handled by endpoint
490
+ preprompt = """"""
491
+ start = ''
492
+ promptB = promptA = '%s%s' % (preprompt, start)
493
+ PreInstruct = ""
494
+ PreInput = None
495
+ PreResponse = ""
496
+ terminate_response = []
497
+ chat_turn_sep = chat_sep = '\n'
498
+ humanstr = None
499
+ botstr = None
500
+ elif prompt_type in [PromptType.vicuna11.value, str(PromptType.vicuna11.value),
501
+ PromptType.vicuna11.name]:
502
+ preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ if not (
503
+ chat and reduced) else ''
504
+ start = ''
505
+ promptB = promptA = '%s%s' % (preprompt, start)
506
+ eos = '</s>'
507
+ PreInstruct = """USER: """
508
+ PreInput = None
509
+ PreResponse = """ASSISTANT:"""
510
+ terminate_response = [PreResponse]
511
+ chat_sep = ' '
512
+ chat_turn_sep = eos
513
+ humanstr = PreInstruct
514
+ botstr = PreResponse
515
+
516
+ if making_context:
517
+ # when making context, want it to appear as-if LLM generated, which starts with space after :
518
+ PreResponse = PreResponse + ' '
519
+ else:
520
+ # normally LLM adds space after this, because was how trained.
521
+ # if add space here, non-unique tokenization will often make LLM produce wrong output
522
+ PreResponse = PreResponse
523
  else:
524
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
525
 
526
+ if isinstance(terminate_response, (tuple, list)):
527
+ assert '' not in terminate_response, "Bad terminate_response"
528
+
529
+ ret_dict = dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
530
+ PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
531
+ chat_turn_sep=chat_turn_sep,
532
+ humanstr=humanstr, botstr=botstr,
533
+ generates_leading_space=generates_leading_space)
534
 
535
+ if return_dict:
536
+ return ret_dict, prompt_dict_error
537
+ else:
538
+ return tuple(list(ret_dict.values()))
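
For reference, a minimal sketch of calling the new get_prompt() with prompt_type='custom', where every template is taken from prompt_dict (the template values below are made up for illustration):

custom_dict = dict(promptA='', promptB='', PreInstruct='<human>: ', PreInput=None,
                   PreResponse='<bot>:', terminate_response=['<human>:'],
                   chat_sep='\n', chat_turn_sep='\n', humanstr='<human>:', botstr='<bot>:')
ret_dict, prompt_dict_error = get_prompt('custom', custom_dict, chat=False, context='',
                                         reduced=False, making_context=False, return_dict=True)
assert not prompt_dict_error  # parsing only fails when prompt_dict is a malformed string
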
539
 
540
+
541
+ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced, making_context):
542
  context = data_point.get('context')
543
  if context is None:
544
  context = ''
 
546
  input = data_point.get('input')
547
  output = data_point.get('output')
548
  prompt_type = data_point.get('prompt_type', prompt_type)
549
+ prompt_dict = data_point.get('prompt_dict', prompt_dict)
550
  assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
551
  promptA, promptB, PreInstruct, PreInput, PreResponse, \
552
+ terminate_response, chat_sep, chat_turn_sep, humanstr, botstr, \
553
+ generates_leading_space = get_prompt(prompt_type, prompt_dict, chat,
554
+ context, reduced, making_context)
555
 
556
+ # could avoid if reduce=True, but too complex for parent functions to handle
557
+ prompt = context
558
 
559
  if input and promptA:
560
  prompt += f"""{promptA}"""
 
563
 
564
  if instruction and PreInstruct is not None and input and PreInput is not None:
565
  prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
566
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
567
  elif instruction and input and PreInstruct is None and PreInput is not None:
568
  prompt += f"""{PreInput}{instruction}
569
  {input}"""
570
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
571
  elif input and instruction and PreInput is None and PreInstruct is not None:
572
  prompt += f"""{PreInstruct}{instruction}
573
  {input}"""
574
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
575
  elif instruction and PreInstruct is not None:
576
  prompt += f"""{PreInstruct}{instruction}"""
577
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
578
  elif input and PreInput is not None:
579
  prompt += f"""{PreInput}{input}"""
580
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
581
  elif input and instruction and PreInput is not None:
582
  prompt += f"""{PreInput}{instruction}{input}"""
583
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
584
  elif input and instruction and PreInstruct is not None:
585
  prompt += f"""{PreInstruct}{instruction}{input}"""
586
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
587
  elif input and instruction:
588
  # i.e. for simple_instruct
589
  prompt += f"""{instruction}: {input}"""
590
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
591
  elif input:
592
  prompt += f"""{input}"""
593
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
594
  elif instruction:
595
  prompt += f"""{instruction}"""
596
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
597
 
598
  if PreResponse is not None:
599
  prompt += f"""{PreResponse}"""
 
604
  if output:
605
  prompt += f"""{output}"""
606
 
607
+ return prompt, pre_response, terminate_response, chat_sep, chat_turn_sep
608
 
609
 
610
+ def inject_chatsep(prompt_type, prompt, chat_sep=None):
611
+ if chat_sep:
612
  # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
613
+ prompt += chat_sep
614
  return prompt
615
 
616
 
617
  class Prompter(object):
618
+ def __init__(self, prompt_type, prompt_dict, debug=False, chat=False, stream_output=False, repeat_penalty=True,
619
  allowed_repeat_line_length=10):
620
  self.prompt_type = prompt_type
621
+ self.prompt_dict = prompt_dict
 
 
622
  self.debug = debug
623
  self.chat = chat
624
  self.stream_output = stream_output
 
627
  self.prompt = None
628
  context = "" # not for chat context
629
  reduced = False # not for chat context
630
+ making_context = False # not for chat context
631
  self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
632
+ self.terminate_response, self.chat_sep, self.chat_turn_sep, self.humanstr, self.botstr, \
633
+ self.generates_leading_space = \
634
+ get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced, making_context)
635
+ self.pre_response = self.PreResponse
636
+
637
+ def generate_prompt(self, data_point, reduced=None):
638
+ """
639
+ data_point['context'] is assumed to be like a system prompt or pre-conversation, not inserted after user prompt
640
+ :param data_point:
641
+ :param reduced:
642
+ :return:
643
+ """
644
+ reduced = data_point.get('context') not in ['', None] if reduced is None else reduced
645
+ making_context = False # whether really making final prompt or just generating context
646
+ prompt, _, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced,
647
+ making_context)
648
  if self.debug:
649
+ print("prompt: %s" % prompt, flush=True)
650
+ # if we have context, reduced should always be True and we only prepend promptA/B here
651
+ if data_point.get('context'):
652
+ if data_point.get('input') and self.promptA:
653
+ prompt = self.promptA + prompt
654
+ elif self.promptB:
655
+ prompt = self.promptB + prompt
656
+
657
  self.prompt = prompt
658
  return prompt
659
 
660
+ def get_response(self, outputs, prompt=None, sanitize_bot_response=False):
661
  if isinstance(outputs, str):
662
  outputs = [outputs]
663
  if self.debug:
664
+ print("output:\n%s" % '\n\n'.join(outputs), flush=True)
665
  if prompt is not None:
666
  self.prompt = prompt
667
 
 
672
  if sanitize_bot_response:
673
  from better_profanity import profanity
674
  response = profanity.censor(response)
675
+ if self.generates_leading_space and isinstance(response, str) and len(response) > 0 and response[0] == ' ':
676
+ response = response[1:]
677
  return response
678
 
679
  def clean_repeats(response):
 
695
  # then use most basic parsing like pipeline
696
  if self.botstr in output:
697
  if self.humanstr:
698
+ output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
699
  else:
700
  # i.e. use after bot but only up to next bot
701
+ output = clean_response(output.split(self.botstr)[1].split(self.botstr)[0])
702
  else:
703
+ # output = clean_response(output)
704
  # assume just not printed yet
705
  output = ""
706
  else:
 
727
  allow_terminate = True
728
  output = output[len(prompt):]
729
  # clean after subtract prompt out, so correct removal of pre_response
730
+ output = clean_response(output)
731
  if self.repeat_penalty:
732
+ output = clean_repeats(output)
733
  if self.terminate_response and allow_terminate:
734
  finds = []
735
  for term in self.terminate_response:
 
737
  finds = [x for x in finds if x >= 0]
738
  if len(finds) > 0:
739
  termi = finds[0]
740
+ output = output[:termi]
741
  else:
742
+ output = output
 
 
743
  if multi_output:
744
  # prefix with output counter
745
  output = "\n=========== Output %d\n\n" % (1 + oi) + output
 
750
  # join all outputs, only one extra new line between outputs
751
  output = '\n'.join(outputs)
752
  if self.debug:
753
+ print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
754
  return output
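
For reference, a minimal sketch of the non-chat Prompter round trip (the fake model continuation and the '<human>:'/'<bot>:' tags are assumptions for illustration; the exact cleanup depends on get_response() above):

prompter = Prompter('human_bot', prompt_dict=None, chat=False, stream_output=False)
prompt = prompter.generate_prompt(dict(instruction='Who are you?', input='', output='', context=''))
fake_output = prompt + ' I am h2oGPT.\n<human>:'               # pretend model continuation
response = prompter.get_response(fake_output, prompt=prompt)   # expected to yield 'I am h2oGPT.'
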
requirements.txt CHANGED
@@ -1,50 +1,50 @@
1
  # for generate (gradio server) and finetune
2
- datasets==2.12.0
3
- sentencepiece==0.1.97
4
- gradio==3.31.0
5
- huggingface_hub==0.14.1
6
  appdirs==1.4.4
7
  fire==0.5.0
8
- docutils==0.19
9
  torch==2.0.1
10
  evaluate==0.4.0
11
  rouge_score==0.1.2
12
  sacrebleu==2.3.1
13
  scikit-learn==1.2.2
14
  alt-profanity-check==1.2.2
15
- better-profanity==0.6.1
16
- numpy==1.24.2
17
- pandas==2.0.0
18
  matplotlib==3.7.1
19
  loralib==0.1.1
20
  bitsandbytes==0.39.0
21
- accelerate==0.19.0
22
- git+https://github.com/huggingface/peft.git@3714aa2fff158fdfa637b2b65952580801d890b2
23
- transformers==4.28.1
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
26
 
27
  # optional for generate
28
  pynvml==11.5.0
29
- psutil==5.9.4
30
  boto3==1.26.101
31
  botocore==1.29.101
32
 
33
  # optional for finetune
34
- tensorboard==2.12.1
35
- neptune==1.1.1
36
 
37
  # for gradio client
38
- gradio_client==0.2.5
39
  beautifulsoup4==4.12.2
40
- markdown==3.4.1
41
 
42
  # data and testing
43
  pytest==7.2.2
44
  pytest-xdist==3.2.1
45
  nltk==3.8.1
46
  textstat==0.7.3
47
- pandoc==2.3
48
  #pypandoc==1.11
49
  pypandoc_binary==1.11
50
  openpyxl==3.1.2
@@ -53,17 +53,66 @@ bioc==2.0
53
 
54
  # falcon
55
  einops==0.6.1
56
  # optional for chat with PDF
57
- langchain==0.0.183
58
- pypdf==3.8.1
59
- tiktoken==0.3.3
60
  # avoid textract, requires old six
61
  #textract==1.6.5
62
 
63
  # for HF embeddings
64
  sentence_transformers==2.2.2
65
- # for OpenAI embeddings (requires key)
66
- openai==0.27.6
67
 
68
  # local vector db
69
  chromadb==0.3.25
@@ -75,14 +124,14 @@ chromadb==0.3.25
75
 
76
  # strong support for images
77
  # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
78
- unstructured[local-inference]==0.6.6
79
  #pdf2image==1.16.3
80
  #pytesseract==0.3.10
81
  pillow
82
 
83
  pdfminer.six==20221105
84
- urllib3==1.26.6
85
- requests_file==1.5.1
86
 
87
  #pdf2image==1.16.3
88
  #pytesseract==0.3.10
@@ -97,4 +146,8 @@ tabulate==0.9.0
97
  pip-licenses==4.3.0
98
 
99
  # weaviate vector db
100
- weaviate-client==3.19.2
1
  # for generate (gradio server) and finetune
2
+ datasets==2.13.0
3
+ sentencepiece==0.1.99
4
+ gradio==3.35.2
5
+ huggingface_hub==0.15.1
6
  appdirs==1.4.4
7
  fire==0.5.0
8
+ docutils==0.20.1
9
  torch==2.0.1
10
  evaluate==0.4.0
11
  rouge_score==0.1.2
12
  sacrebleu==2.3.1
13
  scikit-learn==1.2.2
14
  alt-profanity-check==1.2.2
15
+ better-profanity==0.7.0
16
+ numpy==1.24.3
17
+ pandas==2.0.2
18
  matplotlib==3.7.1
19
  loralib==0.1.1
20
  bitsandbytes==0.39.0
21
+ accelerate==0.20.3
22
+ git+https://github.com/huggingface/peft.git@0b62b4378b4ce9367932c73540349da9a41bdea8
23
+ transformers==4.30.2
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
26
 
27
  # optional for generate
28
  pynvml==11.5.0
29
+ psutil==5.9.5
30
  boto3==1.26.101
31
  botocore==1.29.101
32
 
33
  # optional for finetune
34
+ tensorboard==2.13.0
35
+ neptune==1.2.0
36
 
37
  # for gradio client
38
+ gradio_client==0.2.7
39
  beautifulsoup4==4.12.2
40
+ markdown==3.4.3
41
 
42
  # data and testing
43
  pytest==7.2.2
44
  pytest-xdist==3.2.1
45
  nltk==3.8.1
46
  textstat==0.7.3
47
+ # pandoc==2.3
48
  #pypandoc==1.11
49
  pypandoc_binary==1.11
50
  openpyxl==3.1.2
 
53
 
54
  # falcon
55
  einops==0.6.1
56
+ instructorembedding==1.0.1
57
+
58
+ # for gpt4all .env file, but avoid worrying about imports
59
+ python-dotenv==1.0.0
60
+
61
+ text-generation==0.6.0
62
+ # for tokenization when we don't have an HF tokenizer
63
+ tiktoken==0.4.0
64
+ # optional: for OpenAI endpoint or embeddings (requires key)
65
+ openai==0.27.8
66
+ # optional for chat with PDF
67
+ langchain==0.0.202
68
+ pypdf==3.9.1
69
+ # avoid textract, requires old six
70
+ #textract==1.6.5
71
+
72
+ # for HF embeddings
73
+ sentence_transformers==2.2.2
74
+
75
+ # local vector db
76
+ chromadb==0.3.25
77
+ # server vector db
78
+ #pymilvus==2.2.8
79
+
80
+ # weak URL support, for when opencv etc. can't be installed. If you comment this one in, then comment out unstructured[local-inference] below
81
+ # unstructured==0.6.6
82
+
83
+ # strong support for images
84
+ # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
85
+ unstructured[local-inference]==0.7.4
86
+ #pdf2image==1.16.3
87
+ #pytesseract==0.3.10
88
+ pillow
89
+
90
+ pdfminer.six==20221105
91
+ urllib3
92
+ requests_file
93
+
94
+ #pdf2image==1.16.3
95
+ #pytesseract==0.3.10
96
+ tabulate==0.9.0
97
+ # FYI pandoc already part of requirements.txt
98
+
99
+ # JSONLoader, but makes some trouble for some users
100
+ # jq==1.4.1
101
+
102
+ # to check licenses
103
+ # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
104
+ pip-licenses==4.3.0
105
+
106
+ # weaviate vector db
107
+ weaviate-client==3.20.0
108
  # optional for chat with PDF
109
+ langchain==0.0.202
110
+ pypdf==3.9.1
 
111
  # avoid textract, requires old six
112
  #textract==1.6.5
113
 
114
  # for HF embeddings
115
  sentence_transformers==2.2.2
 
 
116
 
117
  # local vector db
118
  chromadb==0.3.25
 
124
 
125
  # strong support for images
126
  # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
127
+ unstructured[local-inference]==0.7.4
128
  #pdf2image==1.16.3
129
  #pytesseract==0.3.10
130
  pillow
131
 
132
  pdfminer.six==20221105
133
+ urllib3
134
+ requests_file
135
 
136
  #pdf2image==1.16.3
137
  #pytesseract==0.3.10
 
146
  pip-licenses==4.3.0
147
 
148
  # weaviate vector db
149
+ weaviate-client==3.20.0
150
+ faiss-gpu==1.7.2
151
+ arxiv==1.4.7
152
+ pymupdf==1.22.3 # AGPL license
153
+ # extract-msg==0.41.1 # GPL3
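
For reference, a small sanity-check sketch (not part of the repo) that verifies the pinned requirements above are satisfied in the current environment, using pkg_resources:

import pkg_resources
with open('requirements.txt') as f:
    lines = [ln.split('#')[0].strip() for ln in f]
for req in [ln for ln in lines if ln and not ln.startswith('git+')]:
    try:
        pkg_resources.require(req)
    except Exception as e:
        print('Unsatisfied: %s (%s)' % (req, e))
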
stopping.py CHANGED
@@ -1,17 +1,18 @@
1
  import torch
2
  from transformers import StoppingCriteria, StoppingCriteriaList
3
 
4
- from prompter import PromptType
5
 
6
 
7
  class StoppingCriteriaSub(StoppingCriteria):
8
 
9
- def __init__(self, stops=[], encounters=[], device="cuda"):
10
  super().__init__()
11
  assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
12
  self.encounters = encounters
13
  self.stops = [stop.to(device) for stop in stops]
14
  self.num_stops = [0] * len(stops)
 
15
 
16
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
17
  for stopi, stop in enumerate(self.stops):
@@ -20,12 +21,16 @@ class StoppingCriteriaSub(StoppingCriteria):
20
  if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
21
  # print("Stopped", flush=True)
22
  return True
 
 
 
23
  # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
24
  # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
25
  return False
26
 
27
 
28
- def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
 
29
  if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
30
  if prompt_type == PromptType.human_bot.name:
31
  # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
@@ -66,7 +71,8 @@ def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:")
66
  stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
67
  # build stopper
68
  stopping_criteria = StoppingCriteriaList(
69
- [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
 
70
  else:
71
  stopping_criteria = StoppingCriteriaList()
72
  return stopping_criteria
 
1
  import torch
2
  from transformers import StoppingCriteria, StoppingCriteriaList
3
 
4
+ from enums import PromptType
5
 
6
 
7
  class StoppingCriteriaSub(StoppingCriteria):
8
 
9
+ def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
10
  super().__init__()
11
  assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
12
  self.encounters = encounters
13
  self.stops = [stop.to(device) for stop in stops]
14
  self.num_stops = [0] * len(stops)
15
+ self.model_max_length = model_max_length
16
 
17
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
18
  for stopi, stop in enumerate(self.stops):
 
21
  if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
22
  # print("Stopped", flush=True)
23
  return True
24
+ if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
25
+ # critical limit
26
+ return True
27
  # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
28
  # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
29
  return False
30
 
31
 
32
+ def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
33
+ # FIXME: prompt_dict unused currently
34
  if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
35
  if prompt_type == PromptType.human_bot.name:
36
  # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
 
71
  stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
72
  # build stopper
73
  stopping_criteria = StoppingCriteriaList(
74
+ [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
75
+ model_max_length=model_max_length)])
76
  else:
77
  stopping_criteria = StoppingCriteriaList()
78
  return stopping_criteria
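
For reference, a minimal sketch of wiring get_stopping() into HF generation (model, tokenizer and prompt are assumed to exist already; taking model_max_length from max_position_embeddings is an assumption about the model class):

stopping_criteria = get_stopping('human_bot', prompt_dict=None, tokenizer=tokenizer, device='cuda',
                                 model_max_length=getattr(model.config, 'max_position_embeddings', 2048))
inputs = tokenizer(prompt, return_tensors='pt').to('cuda')
outputs = model.generate(**inputs, max_new_tokens=256, stopping_criteria=stopping_criteria)
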
utils.py CHANGED
@@ -14,6 +14,7 @@ import time
14
  import traceback
15
  import zipfile
16
  from datetime import datetime
 
17
  import filelock
18
  import requests, uuid
19
  from typing import Tuple, Callable, Dict
@@ -68,6 +69,25 @@ def ping():
68
  pass
69
 
70

71
  def get_torch_allocated():
72
  import torch
73
  return torch.cuda.memory_allocated()
@@ -97,27 +117,29 @@ def system_info():
97
  system['CPU_C/%s' % k] = v
98
 
99
  # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
100
- from pynvml.smi import nvidia_smi
101
- nvsmi = nvidia_smi.getInstance()
102
-
103
- gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
104
- enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
105
- for k, v in gpu_power_dict.items():
106
- system['GPU_W/%s' % k] = v
107
-
108
- gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
109
- enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
110
- for k, v in gpu_temp_dict.items():
111
- system['GPU_C/%s' % k] = v
112
-
113
- gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
114
- enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
115
- gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
116
- enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
117
- gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
118
- for k, v in gpu_memory_frac_dict.items():
119
- system[f'GPU_M/%s' % k] = v
120
-
 
 
121
  system['hash'] = get_githash()
122
 
123
  return system
@@ -166,35 +188,39 @@ def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
166
  return zip_file, zip_file
167
 
168
 
169
- def save_generate_output(output=None, base_model=None, save_dir=None):
 
170
  try:
171
- return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
 
172
  except Exception as e:
173
  traceback.print_exc()
174
  print('Exception in saving: %s' % str(e))
175
 
176
 
177
- def _save_generate_output(output=None, base_model=None, save_dir=None):
 
178
  """
179
  Save conversation to .json, row by row.
180
  json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
181
  Appends if file exists
182
  """
 
 
183
  assert save_dir, "save_dir must be provided"
184
  if os.path.exists(save_dir) and not os.path.isdir(save_dir):
185
  raise RuntimeError("save_dir already exists and is not a directory!")
186
  os.makedirs(save_dir, exist_ok=True)
187
  import json
188
- if output[-10:] == '\n\n<human>:':
189
- # remove trailing <human>:
190
- output = output[:-10]
191
  with filelock.FileLock("save_dir.lock"):
192
  # lock logging in case have concurrency
193
  with open(os.path.join(save_dir, "history.json"), "a") as f:
194
  # just add [ at start, and ] at end, and have proper JSON dataset
195
  f.write(
196
  " " + json.dumps(
197
- dict(text=output, time=time.ctime(), base_model=base_model)
198
  ) + ",\n"
199
  )
200
 
@@ -800,6 +826,7 @@ def get_kwargs(func, exclude_names=None, **kwargs):
800
 
801
 
802
  import pkg_resources
 
803
  have_faiss = False
804
 
805
  try:
@@ -827,7 +854,7 @@ def hash_file(file):
827
  BUF_SIZE = 65536 # lets read stuff in 64kb chunks!
828
 
829
  md5 = hashlib.md5()
830
- #sha1 = hashlib.sha1()
831
 
832
  with open(file, 'rb') as f:
833
  while True:
@@ -835,9 +862,67 @@ def hash_file(file):
835
  if not data:
836
  break
837
  md5.update(data)
838
- #sha1.update(data)
839
  except BaseException as e:
840
  print("Cannot hash %s due to %s" % (file, str(e)))
841
  traceback.print_exc()
842
  md5 = None
843
  return md5.hexdigest()
14
  import traceback
15
  import zipfile
16
  from datetime import datetime
17
+
18
  import filelock
19
  import requests, uuid
20
  from typing import Tuple, Callable, Dict
 
69
  pass
70
 
71
 
72
+ def ping_gpu():
73
+ try:
74
+ print('Ping_GPU: %s %s' % (str(datetime.now()), system_info()), flush=True)
75
+ except AttributeError:
76
+ # some programs wrap print and will fail with flush passed
77
+ pass
78
+ try:
79
+ ping_gpu_memory()
80
+ except Exception as e:
81
+ print('Ping_GPU memory failure: %s' % str(e), flush=True)
82
+
83
+
84
+ def ping_gpu_memory():
85
+ from models.gpu_mem_track import MemTracker
86
+ gpu_tracker = MemTracker() # define a GPU tracker
87
+ from torch.cuda import memory_summary
88
+ gpu_tracker.track()
89
+
90
+
91
  def get_torch_allocated():
92
  import torch
93
  return torch.cuda.memory_allocated()
 
117
  system['CPU_C/%s' % k] = v
118
 
119
  # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
120
+ try:
121
+ from pynvml.smi import nvidia_smi
122
+ nvsmi = nvidia_smi.getInstance()
123
+
124
+ gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
125
+ enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
126
+ for k, v in gpu_power_dict.items():
127
+ system['GPU_W/%s' % k] = v
128
+
129
+ gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
130
+ enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
131
+ for k, v in gpu_temp_dict.items():
132
+ system['GPU_C/%s' % k] = v
133
+
134
+ gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
135
+ enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
136
+ gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
137
+ enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
138
+ gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
139
+ for k, v in gpu_memory_frac_dict.items():
140
+ system[f'GPU_M/%s' % k] = v
141
+ except ModuleNotFoundError:
142
+ pass
143
  system['hash'] = get_githash()
144
 
145
  return system
 
188
  return zip_file, zip_file
189
 
190
 
191
+ def save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
192
+ extra_dict={}):
193
  try:
194
+ return _save_generate_output(prompt=prompt, output=output, base_model=base_model, save_dir=save_dir,
195
+ where_from=where_from, extra_dict=extra_dict)
196
  except Exception as e:
197
  traceback.print_exc()
198
  print('Exception in saving: %s' % str(e))
199
 
200
 
201
+ def _save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
202
+ extra_dict={}):
203
  """
204
  Save conversation to .json, row by row.
205
  json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
206
  Appends if file exists
207
  """
208
+ prompt = '<not set>' if prompt is None else prompt
209
+ output = '<not set>' if output is None else output
210
  assert save_dir, "save_dir must be provided"
211
  if os.path.exists(save_dir) and not os.path.isdir(save_dir):
212
  raise RuntimeError("save_dir already exists and is not a directory!")
213
  os.makedirs(save_dir, exist_ok=True)
214
  import json
215
+ dict_to_save = dict(prompt=prompt, text=output, time=time.ctime(), base_model=base_model, where_from=where_from)
216
+ dict_to_save.update(extra_dict)
 
217
  with filelock.FileLock("save_dir.lock"):
218
  # lock logging in case have concurrency
219
  with open(os.path.join(save_dir, "history.json"), "a") as f:
220
  # just add [ at start, and ] at end, and have proper JSON dataset
221
  f.write(
222
  " " + json.dumps(
223
+ dict_to_save
224
  ) + ",\n"
225
  )
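
For reference, a sketch of reading the append-only log back as a proper JSON list, per the comment above about adding '[' at the start and ']' at the end (the path is an assumption):

import json
with open('save_dir/history.json') as f:
    text = f.read().strip().rstrip(',')
rows = json.loads('[' + text + ']')  # each row has prompt, text, time, base_model, where_from, plus extra_dict keys
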
226
 
 
826
 
827
 
828
  import pkg_resources
829
+
830
  have_faiss = False
831
 
832
  try:
 
854
  BUF_SIZE = 65536 # lets read stuff in 64kb chunks!
855
 
856
  md5 = hashlib.md5()
857
+ # sha1 = hashlib.sha1()
858
 
859
  with open(file, 'rb') as f:
860
  while True:
 
862
  if not data:
863
  break
864
  md5.update(data)
865
+ # sha1.update(data)
866
  except BaseException as e:
867
  print("Cannot hash %s due to %s" % (file, str(e)))
868
  traceback.print_exc()
869
  md5 = None
870
  return md5.hexdigest()
871
+
872
+
873
+ def start_faulthandler():
874
+ # If the server or any subprocess is hit with signal SIGUSR1, it'll print out all threads' stack traces, but won't quit or coredump
875
+ # If more than one fork tries to write at the same time, the output can look corrupted.
876
+ import faulthandler
877
+
878
+ # SIGUSR1 in h2oai/__init__.py as well
879
+ faulthandler.enable()
880
+ if hasattr(faulthandler, 'register'):
881
+ # windows/mac
882
+ import signal
883
+ faulthandler.register(signal.SIGUSR1)
884
+
885
+
886
+ def get_hf_server(inference_server):
887
+ inf_split = inference_server.split(" ")
888
+ assert len(inf_split) == 1 or len(inf_split) == 3
889
+ inference_server = inf_split[0]
890
+ if len(inf_split) == 3:
891
+ headers = {"authorization": "%s %s" % (inf_split[1], inf_split[2])}
892
+ else:
893
+ headers = None
894
+ return inference_server, headers
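
For reference, the two accepted forms of inference_server (the token value is a placeholder):

get_hf_server("http://192.168.1.5:8080")                  # -> ("http://192.168.1.5:8080", None)
get_hf_server("http://192.168.1.5:8080 Bearer MY_TOKEN")  # -> ("http://192.168.1.5:8080", {"authorization": "Bearer MY_TOKEN"})
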
895
+
896
+
897
+ class FakeTokenizer:
898
+ """
899
+ 1) For keeping track of model_max_length
900
+ 2) For when model doesn't directly expose tokenizer but need to count tokens
901
+ """
902
+
903
+ def __init__(self, model_max_length=2048, encoding_name="cl100k_base"):
904
+ # don't push the limit, since the fake tokenizer only estimates; underestimates of order 250 tokens have been seen
905
+ self.model_max_length = model_max_length - 250
906
+ self.encoding_name = encoding_name
907
+ # The first time this runs, it will require an internet connection to download. Later runs won't need an internet connection.
908
+ import tiktoken
909
+ self.encoding = tiktoken.get_encoding(self.encoding_name)
910
+
911
+ def encode(self, x, *args, return_tensors="pt", **kwargs):
912
+ input_ids = self.encoding.encode(x, disallowed_special=())
913
+ if return_tensors == 'pt' and isinstance(input_ids, list):
914
+ import torch
915
+ input_ids = torch.tensor(input_ids)
916
+ return dict(input_ids=input_ids)
917
+
918
+ def decode(self, x, *args, **kwargs):
919
+ # input is input_ids[0] form
920
+ return self.encoding.decode(x)
921
+
922
+ def num_tokens_from_string(self, prompt: str) -> int:
923
+ """Returns the number of tokens in a text string."""
924
+ num_tokens = len(self.encoding.encode(prompt))
925
+ return num_tokens
926
+
927
+ def __call__(self, x, *args, **kwargs):
928
+ return self.encode(x, *args, **kwargs)
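
For reference, a minimal usage sketch of FakeTokenizer (requires tiktoken; the first call may need an internet connection to fetch the encoding):

fake = FakeTokenizer(model_max_length=2048)
n_tokens = fake.num_tokens_from_string("How are you?")
input_ids = fake("How are you?", return_tensors='pt')['input_ids']
round_trip = fake.decode(input_ids.tolist())
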
utils_langchain.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import Any, Dict, List, Union, Optional
2
+ import time
3
+ import queue
4
+
5
+ from langchain.callbacks.base import BaseCallbackHandler
6
+ from langchain.schema import LLMResult
7
+
8
+
9
+ class StreamingGradioCallbackHandler(BaseCallbackHandler):
10
+ """
11
+ Similar to H2OTextIteratorStreamer, which is for the HF backend, but here for the LangChain backend
12
+ """
13
+ def __init__(self, timeout: Optional[float] = None, block=True):
14
+ super().__init__()
15
+ self.text_queue = queue.SimpleQueue()
16
+ self.stop_signal = None
17
+ self.do_stop = False
18
+ self.timeout = timeout
19
+ self.block = block
20
+
21
+ def on_llm_start(
22
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
23
+ ) -> None:
24
+ """Run when LLM starts running. Clean the queue."""
25
+ while not self.text_queue.empty():
26
+ try:
27
+ self.text_queue.get(block=False)
28
+ except queue.Empty:
29
+ continue
30
+
31
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
32
+ """Run on new LLM token. Only available when streaming is enabled."""
33
+ self.text_queue.put(token)
34
+
35
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
36
+ """Run when LLM ends running."""
37
+ self.text_queue.put(self.stop_signal)
38
+
39
+ def on_llm_error(
40
+ self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
41
+ ) -> None:
42
+ """Run when LLM errors."""
43
+ self.text_queue.put(self.stop_signal)
44
+
45
+ def __iter__(self):
46
+ return self
47
+
48
+ def __next__(self):
49
+ while True:
50
+ try:
51
+ value = self.stop_signal # value looks unused in PyCharm, but it is used below
52
+ if self.do_stop:
53
+ print("hit stop", flush=True)
54
+ # could raise or break, maybe best to raise and make parent see if any exception in thread
55
+ raise StopIteration()
56
+ # break
57
+ value = self.text_queue.get(block=self.block, timeout=self.timeout)
58
+ break
59
+ except queue.Empty:
60
+ time.sleep(0.01)
61
+ if value == self.stop_signal:
62
+ raise StopIteration()
63
+ else:
64
+ return value
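
For reference, a minimal sketch of consuming the handler: run the LLM in a background thread and iterate tokens in the foreground. The OpenAI LLM, the langchain==0.0.202 callbacks API, and an OPENAI_API_KEY in the environment are assumptions for illustration.

from threading import Thread
from langchain.llms import OpenAI

handler = StreamingGradioCallbackHandler(timeout=60)
llm = OpenAI(streaming=True, callbacks=[handler], temperature=0)
thread = Thread(target=llm, args=("Tell me a short joke.",))
thread.start()
text = ''
for token in handler:   # iteration ends when on_llm_end() puts the stop signal
    text += token
thread.join()
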