Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
•
b368114
1
Parent(s):
b64f5c9
Update with h2oGPT hash c37e5ee65166e4d964435193d5d8c23aaa8d3f09
Browse files- client_test.py +20 -7
- enums.py +8 -0
- evaluate_params.py +1 -0
- gen.py +59 -23
- gpt_langchain.py +145 -38
- gradio_runner.py +15 -1
- prompter.py +25 -0
- utils.py +17 -1
client_test.py
CHANGED
@@ -69,6 +69,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
69 |
top_k_docs=3,
|
70 |
langchain_mode='Disabled',
|
71 |
langchain_action=LangChainAction.QUERY.value,
|
|
|
72 |
prompt_dict=None):
|
73 |
from collections import OrderedDict
|
74 |
kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
|
@@ -95,6 +96,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
95 |
iinput_nochat='', # only for chat=False
|
96 |
langchain_mode=langchain_mode,
|
97 |
langchain_action=langchain_action,
|
|
|
98 |
top_k_docs=top_k_docs,
|
99 |
chunk=True,
|
100 |
chunk_size=512,
|
@@ -203,6 +205,7 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
|
|
203 |
iinput_nochat='',
|
204 |
langchain_mode='Disabled',
|
205 |
langchain_action=LangChainAction.QUERY.value,
|
|
|
206 |
top_k_docs=4,
|
207 |
document_subset=DocumentChoices.Relevant.name,
|
208 |
document_choice=[],
|
@@ -225,23 +228,30 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
|
|
225 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
226 |
def test_client_chat(prompt_type='human_bot'):
|
227 |
return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
|
228 |
-
langchain_mode='Disabled',
|
|
|
|
|
229 |
|
230 |
|
231 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
232 |
def test_client_chat_stream(prompt_type='human_bot'):
|
233 |
return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
234 |
stream_output=True, max_new_tokens=512,
|
235 |
-
langchain_mode='Disabled',
|
|
|
|
|
236 |
|
237 |
|
238 |
-
def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens,
|
|
|
239 |
prompt_dict=None):
|
240 |
client = get_client(serialize=False)
|
241 |
|
242 |
kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
|
243 |
-
max_new_tokens=max_new_tokens,
|
|
|
244 |
langchain_action=langchain_action,
|
|
|
245 |
prompt_dict=prompt_dict)
|
246 |
return run_client(client, prompt, args, kwargs)
|
247 |
|
@@ -285,15 +295,18 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
|
285 |
def test_client_nochat_stream(prompt_type='human_bot'):
|
286 |
return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
287 |
stream_output=True, max_new_tokens=512,
|
288 |
-
langchain_mode='Disabled',
|
|
|
|
|
289 |
|
290 |
|
291 |
-
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
|
|
|
292 |
client = get_client(serialize=False)
|
293 |
|
294 |
kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
|
295 |
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
|
296 |
-
langchain_action=langchain_action)
|
297 |
return run_client_gen(client, prompt, args, kwargs)
|
298 |
|
299 |
|
|
|
69 |
top_k_docs=3,
|
70 |
langchain_mode='Disabled',
|
71 |
langchain_action=LangChainAction.QUERY.value,
|
72 |
+
langchain_agents=[],
|
73 |
prompt_dict=None):
|
74 |
from collections import OrderedDict
|
75 |
kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
|
|
|
96 |
iinput_nochat='', # only for chat=False
|
97 |
langchain_mode=langchain_mode,
|
98 |
langchain_action=langchain_action,
|
99 |
+
langchain_agents=langchain_agents,
|
100 |
top_k_docs=top_k_docs,
|
101 |
chunk=True,
|
102 |
chunk_size=512,
|
|
|
205 |
iinput_nochat='',
|
206 |
langchain_mode='Disabled',
|
207 |
langchain_action=LangChainAction.QUERY.value,
|
208 |
+
langchain_agents=[],
|
209 |
top_k_docs=4,
|
210 |
document_subset=DocumentChoices.Relevant.name,
|
211 |
document_choice=[],
|
|
|
228 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
229 |
def test_client_chat(prompt_type='human_bot'):
|
230 |
return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
|
231 |
+
langchain_mode='Disabled',
|
232 |
+
langchain_action=LangChainAction.QUERY.value,
|
233 |
+
langchain_agents=[])
|
234 |
|
235 |
|
236 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
237 |
def test_client_chat_stream(prompt_type='human_bot'):
|
238 |
return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
239 |
stream_output=True, max_new_tokens=512,
|
240 |
+
langchain_mode='Disabled',
|
241 |
+
langchain_action=LangChainAction.QUERY.value,
|
242 |
+
langchain_agents=[])
|
243 |
|
244 |
|
245 |
+
def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens,
|
246 |
+
langchain_mode, langchain_action, langchain_agents,
|
247 |
prompt_dict=None):
|
248 |
client = get_client(serialize=False)
|
249 |
|
250 |
kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
|
251 |
+
max_new_tokens=max_new_tokens,
|
252 |
+
langchain_mode=langchain_mode,
|
253 |
langchain_action=langchain_action,
|
254 |
+
langchain_agents=langchain_agents,
|
255 |
prompt_dict=prompt_dict)
|
256 |
return run_client(client, prompt, args, kwargs)
|
257 |
|
|
|
295 |
def test_client_nochat_stream(prompt_type='human_bot'):
|
296 |
return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
297 |
stream_output=True, max_new_tokens=512,
|
298 |
+
langchain_mode='Disabled',
|
299 |
+
langchain_action=LangChainAction.QUERY.value,
|
300 |
+
langchain_agents=[])
|
301 |
|
302 |
|
303 |
+
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
|
304 |
+
langchain_mode, langchain_action, langchain_agents):
|
305 |
client = get_client(serialize=False)
|
306 |
|
307 |
kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
|
308 |
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
|
309 |
+
langchain_action=langchain_action, langchain_agents=langchain_agents)
|
310 |
return run_client_gen(client, prompt, args, kwargs)
|
311 |
|
312 |
|
enums.py
CHANGED
@@ -31,6 +31,7 @@ class PromptType(Enum):
|
|
31 |
mptinstruct = 25
|
32 |
mptchat = 26
|
33 |
falcon = 27
|
|
|
34 |
|
35 |
|
36 |
class DocumentChoices(Enum):
|
@@ -71,6 +72,13 @@ class LangChainAction(Enum):
|
|
71 |
SUMMARIZE_REFINE = "Summarize_refine"
|
72 |
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
75 |
|
76 |
# from site-packages/langchain/llms/openai.py
|
|
|
31 |
mptinstruct = 25
|
32 |
mptchat = 26
|
33 |
falcon = 27
|
34 |
+
guanaco = 28
|
35 |
|
36 |
|
37 |
class DocumentChoices(Enum):
|
|
|
72 |
SUMMARIZE_REFINE = "Summarize_refine"
|
73 |
|
74 |
|
75 |
+
class LangChainAgent(Enum):
|
76 |
+
"""LangChain agents"""
|
77 |
+
|
78 |
+
SEARCH = "Search"
|
79 |
+
# CSV = "csv" # WIP
|
80 |
+
|
81 |
+
|
82 |
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
83 |
|
84 |
# from site-packages/langchain/llms/openai.py
|
evaluate_params.py
CHANGED
@@ -31,6 +31,7 @@ eval_func_param_names = ['instruction',
|
|
31 |
'iinput_nochat',
|
32 |
'langchain_mode',
|
33 |
'langchain_action',
|
|
|
34 |
'top_k_docs',
|
35 |
'chunk',
|
36 |
'chunk_size',
|
|
|
31 |
'iinput_nochat',
|
32 |
'langchain_mode',
|
33 |
'langchain_action',
|
34 |
+
'langchain_agents',
|
35 |
'top_k_docs',
|
36 |
'chunk',
|
37 |
'chunk_size',
|
gen.py
CHANGED
@@ -29,11 +29,11 @@ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is
|
|
29 |
|
30 |
from evaluate_params import eval_func_param_names, no_default_param_names
|
31 |
from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
|
32 |
-
source_postfix, LangChainAction
|
33 |
from loaders import get_loaders
|
34 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
35 |
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
|
36 |
-
have_langchain
|
37 |
|
38 |
start_faulthandler()
|
39 |
import_matplotlib()
|
@@ -54,6 +54,8 @@ langchain_modes = [x.value for x in list(LangChainMode)]
|
|
54 |
|
55 |
langchain_actions = [x.value for x in list(LangChainAction)]
|
56 |
|
|
|
|
|
57 |
scratch_base_dir = '/tmp/'
|
58 |
|
59 |
|
@@ -134,7 +136,7 @@ def main(
|
|
134 |
extra_lora_options: typing.List[str] = [],
|
135 |
extra_server_options: typing.List[str] = [],
|
136 |
|
137 |
-
score_model: str = '
|
138 |
|
139 |
eval_filename: str = None,
|
140 |
eval_prompts_only_num: int = 0,
|
@@ -143,15 +145,18 @@ def main(
|
|
143 |
|
144 |
langchain_mode: str = None,
|
145 |
langchain_action: str = LangChainAction.QUERY.value,
|
|
|
146 |
force_langchain_evaluate: bool = False,
|
147 |
visible_langchain_modes: list = ['UserData', 'MyData'],
|
148 |
# WIP:
|
149 |
# visible_langchain_actions: list = langchain_actions.copy(),
|
150 |
visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
|
|
|
151 |
document_subset: str = DocumentChoices.Relevant.name,
|
152 |
document_choice: list = [],
|
153 |
user_path: str = None,
|
154 |
detect_user_path_changes_every_query: bool = False,
|
|
|
155 |
load_db_if_exists: bool = True,
|
156 |
keep_sources_in_context: bool = False,
|
157 |
db_type: str = 'chroma',
|
@@ -196,6 +201,8 @@ def main(
|
|
196 |
Or Address can be "openai_chat" or "openai" for OpenAI API
|
197 |
e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo
|
198 |
e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
|
|
|
|
|
199 |
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
|
200 |
:param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
|
201 |
:param model_lock: Lock models to specific combinations, for ease of use and extending to many models
|
@@ -271,18 +278,24 @@ def main(
|
|
271 |
:param extra_model_options: extra models to show in list in gradio
|
272 |
:param extra_lora_options: extra LORA to show in list in gradio
|
273 |
:param extra_server_options: extra servers to show in list in gradio
|
274 |
-
:param score_model: which model to score responses
|
|
|
|
|
|
|
275 |
:param eval_filename: json file to use for evaluation, if None is sharegpt
|
276 |
:param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
|
277 |
:param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
|
278 |
:param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
|
279 |
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
|
|
|
280 |
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
|
281 |
:param langchain_action: Mode langchain operations in on documents.
|
282 |
Query: Make query of document(s)
|
283 |
Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce
|
284 |
Summarize_all: Summarize document(s) using entire document at once
|
285 |
Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary
|
|
|
|
|
286 |
:param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
|
287 |
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
|
288 |
If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
|
@@ -293,17 +306,18 @@ def main(
|
|
293 |
But wiki_full is expensive and requires preparation
|
294 |
To allow scratch space only live in session, add 'MyData' to list
|
295 |
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
|
296 |
-
FIXME: Avoid 'All' for now, not implemented
|
297 |
:param visible_langchain_actions: Which actions to allow
|
|
|
298 |
:param document_subset: Default document choice when taking subset of collection
|
299 |
:param document_choice: Chosen document(s) by internal name
|
|
|
300 |
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
|
301 |
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
|
302 |
:param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
|
303 |
:param use_openai_embedding: Whether to use OpenAI embeddings for vector db
|
304 |
:param use_openai_model: Whether to use OpenAI model for use with vector db
|
305 |
:param hf_embedding_model: Which HF embedding model to use for vector db
|
306 |
-
Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-
|
307 |
Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
|
308 |
Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
|
309 |
We support automatically changing of embeddings for chroma, with a backup of db made if this is done
|
@@ -327,6 +341,7 @@ def main(
|
|
327 |
captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
|
328 |
captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
|
329 |
Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
|
|
|
330 |
:param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
|
331 |
parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
|
332 |
Recommended if using larger caption model
|
@@ -394,6 +409,8 @@ def main(
|
|
394 |
visible_langchain_modes += [langchain_mode]
|
395 |
|
396 |
assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
|
|
|
|
|
397 |
|
398 |
# if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler
|
399 |
if LangChainMode.MY_DATA.value not in visible_langchain_modes:
|
@@ -413,7 +430,8 @@ def main(
|
|
413 |
" set user_path and ensure allow_upload_to_user_data=True" % langchain_mode, flush=True)
|
414 |
else:
|
415 |
raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
|
416 |
-
if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value,
|
|
|
417 |
raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
|
418 |
if langchain_mode is None:
|
419 |
# if not set yet, disable
|
@@ -474,7 +492,7 @@ def main(
|
|
474 |
# HF accounted for later in get_max_max_new_tokens()
|
475 |
save_dir = os.getenv('SAVE_DIR', save_dir)
|
476 |
score_model = os.getenv('SCORE_MODEL', score_model)
|
477 |
-
if score_model == 'None'
|
478 |
score_model = ''
|
479 |
concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
|
480 |
api_open = bool(int(os.getenv('API_OPEN', str(int(api_open)))))
|
@@ -482,6 +500,7 @@ def main(
|
|
482 |
|
483 |
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
|
484 |
if n_gpus == 0:
|
|
|
485 |
gpu_id = None
|
486 |
load_8bit = False
|
487 |
load_4bit = False
|
@@ -499,7 +518,11 @@ def main(
|
|
499 |
if hf_embedding_model is None:
|
500 |
# if no GPUs, use simpler embedding model to avoid cost in time
|
501 |
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
|
|
|
|
|
502 |
else:
|
|
|
|
|
503 |
if hf_embedding_model is None:
|
504 |
# if still None, then set default
|
505 |
hf_embedding_model = 'hkunlp/instructor-large'
|
@@ -967,11 +990,13 @@ def get_model(
|
|
967 |
client = gr_client or hf_client
|
968 |
# Don't return None, None for model, tokenizer so triggers
|
969 |
return client, tokenizer, 'http'
|
970 |
-
if isinstance(inference_server, str) and
|
971 |
-
|
972 |
-
|
973 |
-
|
974 |
-
|
|
|
|
|
975 |
return inference_server, tokenizer, inference_server
|
976 |
assert not inference_server, "Malformed inference_server=%s" % inference_server
|
977 |
if base_model in non_hf_types:
|
@@ -1278,6 +1303,7 @@ def evaluate(
|
|
1278 |
iinput_nochat,
|
1279 |
langchain_mode,
|
1280 |
langchain_action,
|
|
|
1281 |
top_k_docs,
|
1282 |
chunk,
|
1283 |
chunk_size,
|
@@ -1298,6 +1324,7 @@ def evaluate(
|
|
1298 |
raise_generate_gpu_exceptions=None,
|
1299 |
chat_context=None,
|
1300 |
lora_weights=None,
|
|
|
1301 |
load_db_if_exists=True,
|
1302 |
dbs=None,
|
1303 |
user_path=None,
|
@@ -1452,6 +1479,8 @@ def evaluate(
|
|
1452 |
# THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
|
1453 |
assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
|
1454 |
assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
|
|
|
|
|
1455 |
if langchain_mode in ['MyData'] and my_db_state is not None and len(my_db_state) > 0 and my_db_state[0] is not None:
|
1456 |
db1 = my_db_state[0]
|
1457 |
elif dbs is not None and langchain_mode in dbs:
|
@@ -1484,6 +1513,7 @@ def evaluate(
|
|
1484 |
inference_server=inference_server,
|
1485 |
stream_output=stream_output,
|
1486 |
prompter=prompter,
|
|
|
1487 |
load_db_if_exists=load_db_if_exists,
|
1488 |
db=db1,
|
1489 |
user_path=user_path,
|
@@ -1498,6 +1528,7 @@ def evaluate(
|
|
1498 |
chunk_size=chunk_size,
|
1499 |
langchain_mode=langchain_mode,
|
1500 |
langchain_action=langchain_action,
|
|
|
1501 |
document_subset=document_subset,
|
1502 |
document_choice=document_choice,
|
1503 |
db_type=db_type,
|
@@ -1526,6 +1557,7 @@ def evaluate(
|
|
1526 |
inference_server=inference_server,
|
1527 |
langchain_mode=langchain_mode,
|
1528 |
langchain_action=langchain_action,
|
|
|
1529 |
document_subset=document_subset,
|
1530 |
document_choice=document_choice,
|
1531 |
num_prompt_tokens=num_prompt_tokens,
|
@@ -1549,12 +1581,12 @@ def evaluate(
|
|
1549 |
clear_torch_cache()
|
1550 |
return
|
1551 |
|
1552 |
-
if inference_server.startswith('
|
1553 |
-
|
1554 |
-
|
1555 |
where_from = "openai_client"
|
|
|
1556 |
|
1557 |
-
openai.api_key = os.getenv("OPENAI_API_KEY")
|
1558 |
terminate_response = prompter.terminate_response or []
|
1559 |
stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
|
1560 |
stop_sequences = [x for x in stop_sequences if x]
|
@@ -1567,7 +1599,7 @@ def evaluate(
|
|
1567 |
n=num_return_sequences,
|
1568 |
presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
|
1569 |
)
|
1570 |
-
if inference_server == 'openai':
|
1571 |
response = openai.Completion.create(
|
1572 |
model=base_model,
|
1573 |
prompt=prompt,
|
@@ -1590,7 +1622,9 @@ def evaluate(
|
|
1590 |
yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
|
1591 |
sanitize_bot_response=sanitize_bot_response),
|
1592 |
sources='')
|
1593 |
-
elif inference_server == 'openai_chat':
|
|
|
|
|
1594 |
response = openai.ChatCompletion.create(
|
1595 |
model=base_model,
|
1596 |
messages=[
|
@@ -1643,6 +1677,7 @@ def evaluate(
|
|
1643 |
where_from = "gr_client"
|
1644 |
client_langchain_mode = 'Disabled'
|
1645 |
client_langchain_action = LangChainAction.QUERY.value
|
|
|
1646 |
gen_server_kwargs = dict(temperature=temperature,
|
1647 |
top_p=top_p,
|
1648 |
top_k=top_k,
|
@@ -1695,6 +1730,7 @@ def evaluate(
|
|
1695 |
iinput_nochat=gr_iinput, # only for chat=False
|
1696 |
langchain_mode=client_langchain_mode,
|
1697 |
langchain_action=client_langchain_action,
|
|
|
1698 |
top_k_docs=top_k_docs,
|
1699 |
chunk=chunk,
|
1700 |
chunk_size=chunk_size,
|
@@ -2276,8 +2312,8 @@ y = np.random.randint(0, 1, 100)
|
|
2276 |
|
2277 |
# move to correct position
|
2278 |
for example in examples:
|
2279 |
-
example += [chat, '', '', LangChainMode.DISABLED.value, LangChainAction.QUERY.value,
|
2280 |
-
top_k_docs, chunk, chunk_size,
|
2281 |
]
|
2282 |
# adjust examples if non-chat mode
|
2283 |
if not chat:
|
@@ -2383,14 +2419,14 @@ def check_locals(**kwargs):
|
|
2383 |
|
2384 |
|
2385 |
def get_model_max_length(model_state):
|
2386 |
-
if not isinstance(model_state['tokenizer'], (str,
|
2387 |
return model_state['tokenizer'].model_max_length
|
2388 |
else:
|
2389 |
return 2048
|
2390 |
|
2391 |
|
2392 |
def get_max_max_new_tokens(model_state, **kwargs):
|
2393 |
-
if not isinstance(model_state['tokenizer'], (str,
|
2394 |
max_max_new_tokens = model_state['tokenizer'].model_max_length
|
2395 |
else:
|
2396 |
max_max_new_tokens = None
|
|
|
29 |
|
30 |
from evaluate_params import eval_func_param_names, no_default_param_names
|
31 |
from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
|
32 |
+
source_postfix, LangChainAction, LangChainAgent
|
33 |
from loaders import get_loaders
|
34 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
35 |
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
|
36 |
+
have_langchain, set_openai
|
37 |
|
38 |
start_faulthandler()
|
39 |
import_matplotlib()
|
|
|
54 |
|
55 |
langchain_actions = [x.value for x in list(LangChainAction)]
|
56 |
|
57 |
+
langchain_agents_list = [x.value for x in list(LangChainAgent)]
|
58 |
+
|
59 |
scratch_base_dir = '/tmp/'
|
60 |
|
61 |
|
|
|
136 |
extra_lora_options: typing.List[str] = [],
|
137 |
extra_server_options: typing.List[str] = [],
|
138 |
|
139 |
+
score_model: str = 'auto',
|
140 |
|
141 |
eval_filename: str = None,
|
142 |
eval_prompts_only_num: int = 0,
|
|
|
145 |
|
146 |
langchain_mode: str = None,
|
147 |
langchain_action: str = LangChainAction.QUERY.value,
|
148 |
+
langchain_agents: list = [],
|
149 |
force_langchain_evaluate: bool = False,
|
150 |
visible_langchain_modes: list = ['UserData', 'MyData'],
|
151 |
# WIP:
|
152 |
# visible_langchain_actions: list = langchain_actions.copy(),
|
153 |
visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
|
154 |
+
visible_langchain_agents: list = langchain_agents_list.copy(),
|
155 |
document_subset: str = DocumentChoices.Relevant.name,
|
156 |
document_choice: list = [],
|
157 |
user_path: str = None,
|
158 |
detect_user_path_changes_every_query: bool = False,
|
159 |
+
use_llm_if_no_docs: bool = False,
|
160 |
load_db_if_exists: bool = True,
|
161 |
keep_sources_in_context: bool = False,
|
162 |
db_type: str = 'chroma',
|
|
|
201 |
Or Address can be "openai_chat" or "openai" for OpenAI API
|
202 |
e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo
|
203 |
e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
|
204 |
+
Or Address can be "vllm:IP:port" or "vllm:IP:port" for OpenAI-compliant vLLM endpoint
|
205 |
+
Note: vllm_chat not supported by vLLM project.
|
206 |
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
|
207 |
:param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
|
208 |
:param model_lock: Lock models to specific combinations, for ease of use and extending to many models
|
|
|
278 |
:param extra_model_options: extra models to show in list in gradio
|
279 |
:param extra_lora_options: extra LORA to show in list in gradio
|
280 |
:param extra_server_options: extra servers to show in list in gradio
|
281 |
+
:param score_model: which model to score responses
|
282 |
+
None: no response scoring
|
283 |
+
'auto': auto mode, '' (no model) for CPU, 'OpenAssistant/reward-model-deberta-v3-large-v2' for GPU,
|
284 |
+
because on CPU takes too much compute just for scoring response
|
285 |
:param eval_filename: json file to use for evaluation, if None is sharegpt
|
286 |
:param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
|
287 |
:param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
|
288 |
:param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
|
289 |
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
|
290 |
+
None: auto mode, check if langchain package exists, at least do ChatLLM if so, else Disabled
|
291 |
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
|
292 |
:param langchain_action: Mode langchain operations in on documents.
|
293 |
Query: Make query of document(s)
|
294 |
Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce
|
295 |
Summarize_all: Summarize document(s) using entire document at once
|
296 |
Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary
|
297 |
+
:param langchain_agents: Which agents to use
|
298 |
+
'search': Use Web Search as context for LLM response, e.g. SERP if have SERPAPI_API_KEY in env
|
299 |
:param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
|
300 |
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
|
301 |
If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
|
|
|
306 |
But wiki_full is expensive and requires preparation
|
307 |
To allow scratch space only live in session, add 'MyData' to list
|
308 |
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
|
|
|
309 |
:param visible_langchain_actions: Which actions to allow
|
310 |
+
:param visible_langchain_agents: Which agents to allow
|
311 |
:param document_subset: Default document choice when taking subset of collection
|
312 |
:param document_choice: Chosen document(s) by internal name
|
313 |
+
:param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData
|
314 |
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
|
315 |
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
|
316 |
:param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
|
317 |
:param use_openai_embedding: Whether to use OpenAI embeddings for vector db
|
318 |
:param use_openai_model: Whether to use OpenAI model for use with vector db
|
319 |
:param hf_embedding_model: Which HF embedding model to use for vector db
|
320 |
+
Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-v2 if no GPUs
|
321 |
Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
|
322 |
Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
|
323 |
We support automatically changing of embeddings for chroma, with a backup of db made if this is done
|
|
|
341 |
captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
|
342 |
captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
|
343 |
Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
|
344 |
+
Disabled for CPU since BLIP requires CUDA
|
345 |
:param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
|
346 |
parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
|
347 |
Recommended if using larger caption model
|
|
|
409 |
visible_langchain_modes += [langchain_mode]
|
410 |
|
411 |
assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
|
412 |
+
assert len(
|
413 |
+
set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
|
414 |
|
415 |
# if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler
|
416 |
if LangChainMode.MY_DATA.value not in visible_langchain_modes:
|
|
|
430 |
" set user_path and ensure allow_upload_to_user_data=True" % langchain_mode, flush=True)
|
431 |
else:
|
432 |
raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
|
433 |
+
if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value,
|
434 |
+
LangChainMode.CHAT_LLM.value]:
|
435 |
raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
|
436 |
if langchain_mode is None:
|
437 |
# if not set yet, disable
|
|
|
492 |
# HF accounted for later in get_max_max_new_tokens()
|
493 |
save_dir = os.getenv('SAVE_DIR', save_dir)
|
494 |
score_model = os.getenv('SCORE_MODEL', score_model)
|
495 |
+
if str(score_model) == 'None':
|
496 |
score_model = ''
|
497 |
concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
|
498 |
api_open = bool(int(os.getenv('API_OPEN', str(int(api_open)))))
|
|
|
500 |
|
501 |
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
|
502 |
if n_gpus == 0:
|
503 |
+
enable_captions = False
|
504 |
gpu_id = None
|
505 |
load_8bit = False
|
506 |
load_4bit = False
|
|
|
518 |
if hf_embedding_model is None:
|
519 |
# if no GPUs, use simpler embedding model to avoid cost in time
|
520 |
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
|
521 |
+
if score_model == 'auto':
|
522 |
+
score_model = ''
|
523 |
else:
|
524 |
+
if score_model == 'auto':
|
525 |
+
score_model = 'OpenAssistant/reward-model-deberta-v3-large-v2'
|
526 |
if hf_embedding_model is None:
|
527 |
# if still None, then set default
|
528 |
hf_embedding_model = 'hkunlp/instructor-large'
|
|
|
990 |
client = gr_client or hf_client
|
991 |
# Don't return None, None for model, tokenizer so triggers
|
992 |
return client, tokenizer, 'http'
|
993 |
+
if isinstance(inference_server, str) and (
|
994 |
+
inference_server.startswith('openai') or inference_server.startswith('vllm')):
|
995 |
+
if inference_server.startswith('openai'):
|
996 |
+
assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
|
997 |
+
# Don't return None, None for model, tokenizer so triggers
|
998 |
+
# include small token cushion
|
999 |
+
tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
|
1000 |
return inference_server, tokenizer, inference_server
|
1001 |
assert not inference_server, "Malformed inference_server=%s" % inference_server
|
1002 |
if base_model in non_hf_types:
|
|
|
1303 |
iinput_nochat,
|
1304 |
langchain_mode,
|
1305 |
langchain_action,
|
1306 |
+
langchain_agents,
|
1307 |
top_k_docs,
|
1308 |
chunk,
|
1309 |
chunk_size,
|
|
|
1324 |
raise_generate_gpu_exceptions=None,
|
1325 |
chat_context=None,
|
1326 |
lora_weights=None,
|
1327 |
+
use_llm_if_no_docs=False,
|
1328 |
load_db_if_exists=True,
|
1329 |
dbs=None,
|
1330 |
user_path=None,
|
|
|
1479 |
# THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
|
1480 |
assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
|
1481 |
assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
|
1482 |
+
assert len(
|
1483 |
+
set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
|
1484 |
if langchain_mode in ['MyData'] and my_db_state is not None and len(my_db_state) > 0 and my_db_state[0] is not None:
|
1485 |
db1 = my_db_state[0]
|
1486 |
elif dbs is not None and langchain_mode in dbs:
|
|
|
1513 |
inference_server=inference_server,
|
1514 |
stream_output=stream_output,
|
1515 |
prompter=prompter,
|
1516 |
+
use_llm_if_no_docs=use_llm_if_no_docs,
|
1517 |
load_db_if_exists=load_db_if_exists,
|
1518 |
db=db1,
|
1519 |
user_path=user_path,
|
|
|
1528 |
chunk_size=chunk_size,
|
1529 |
langchain_mode=langchain_mode,
|
1530 |
langchain_action=langchain_action,
|
1531 |
+
langchain_agents=langchain_agents,
|
1532 |
document_subset=document_subset,
|
1533 |
document_choice=document_choice,
|
1534 |
db_type=db_type,
|
|
|
1557 |
inference_server=inference_server,
|
1558 |
langchain_mode=langchain_mode,
|
1559 |
langchain_action=langchain_action,
|
1560 |
+
langchain_agents=langchain_agents,
|
1561 |
document_subset=document_subset,
|
1562 |
document_choice=document_choice,
|
1563 |
num_prompt_tokens=num_prompt_tokens,
|
|
|
1581 |
clear_torch_cache()
|
1582 |
return
|
1583 |
|
1584 |
+
if inference_server.startswith('vllm') or inference_server.startswith('openai') or inference_server.startswith(
|
1585 |
+
'http'):
|
1586 |
+
if inference_server.startswith('vllm') or inference_server.startswith('openai'):
|
1587 |
where_from = "openai_client"
|
1588 |
+
openai, inf_type = set_openai(inference_server)
|
1589 |
|
|
|
1590 |
terminate_response = prompter.terminate_response or []
|
1591 |
stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
|
1592 |
stop_sequences = [x for x in stop_sequences if x]
|
|
|
1599 |
n=num_return_sequences,
|
1600 |
presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
|
1601 |
)
|
1602 |
+
if inf_type == 'vllm' or inference_server == 'openai':
|
1603 |
response = openai.Completion.create(
|
1604 |
model=base_model,
|
1605 |
prompt=prompt,
|
|
|
1622 |
yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
|
1623 |
sanitize_bot_response=sanitize_bot_response),
|
1624 |
sources='')
|
1625 |
+
elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
|
1626 |
+
if inf_type == 'vllm_chat':
|
1627 |
+
raise NotImplementedError('%s not supported by vLLM' % inf_type)
|
1628 |
response = openai.ChatCompletion.create(
|
1629 |
model=base_model,
|
1630 |
messages=[
|
|
|
1677 |
where_from = "gr_client"
|
1678 |
client_langchain_mode = 'Disabled'
|
1679 |
client_langchain_action = LangChainAction.QUERY.value
|
1680 |
+
client_langchain_agents = []
|
1681 |
gen_server_kwargs = dict(temperature=temperature,
|
1682 |
top_p=top_p,
|
1683 |
top_k=top_k,
|
|
|
1730 |
iinput_nochat=gr_iinput, # only for chat=False
|
1731 |
langchain_mode=client_langchain_mode,
|
1732 |
langchain_action=client_langchain_action,
|
1733 |
+
langchain_agents=client_langchain_agents,
|
1734 |
top_k_docs=top_k_docs,
|
1735 |
chunk=chunk,
|
1736 |
chunk_size=chunk_size,
|
|
|
2312 |
|
2313 |
# move to correct position
|
2314 |
for example in examples:
|
2315 |
+
example += [chat, '', '', LangChainMode.DISABLED.value, LangChainAction.QUERY.value, [],
|
2316 |
+
top_k_docs, chunk, chunk_size, DocumentChoices.Relevant.name, []
|
2317 |
]
|
2318 |
# adjust examples if non-chat mode
|
2319 |
if not chat:
|
|
|
2419 |
|
2420 |
|
2421 |
def get_model_max_length(model_state):
|
2422 |
+
if not isinstance(model_state['tokenizer'], (str, type(None))):
|
2423 |
return model_state['tokenizer'].model_max_length
|
2424 |
else:
|
2425 |
return 2048
|
2426 |
|
2427 |
|
2428 |
def get_max_max_new_tokens(model_state, **kwargs):
|
2429 |
+
if not isinstance(model_state['tokenizer'], (str, type(None))):
|
2430 |
max_max_new_tokens = model_state['tokenizer'].model_max_length
|
2431 |
else:
|
2432 |
max_max_new_tokens = None
|
gpt_langchain.py
CHANGED
@@ -21,6 +21,7 @@ import filelock
|
|
21 |
from joblib import delayed
|
22 |
from langchain.callbacks import streaming_stdout
|
23 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
|
|
24 |
from tqdm import tqdm
|
25 |
|
26 |
from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
|
@@ -30,7 +31,7 @@ from gen import get_model, SEED
|
|
30 |
from prompter import non_hf_types, PromptType, Prompter
|
31 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
32 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
|
33 |
-
have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf
|
34 |
from utils_langchain import StreamingGradioCallbackHandler
|
35 |
|
36 |
import_matplotlib()
|
@@ -276,15 +277,7 @@ from typing import Any, Dict, List, Optional, Set
|
|
276 |
|
277 |
from pydantic import Extra, Field, root_validator
|
278 |
|
279 |
-
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
280 |
-
|
281 |
-
"""Wrapper around Huggingface text generation inference API."""
|
282 |
-
from functools import partial
|
283 |
-
from typing import Any, Dict, List, Optional
|
284 |
-
|
285 |
-
from pydantic import Extra, Field, root_validator
|
286 |
-
|
287 |
-
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
288 |
from langchain.llms.base import LLM
|
289 |
|
290 |
|
@@ -356,6 +349,7 @@ class GradioInference(LLM):
|
|
356 |
gr_client = self.client
|
357 |
client_langchain_mode = 'Disabled'
|
358 |
client_langchain_action = LangChainAction.QUERY.value
|
|
|
359 |
top_k_docs = 1
|
360 |
chunk = True
|
361 |
chunk_size = 512
|
@@ -385,6 +379,7 @@ class GradioInference(LLM):
|
|
385 |
iinput_nochat='', # only for chat=False
|
386 |
langchain_mode=client_langchain_mode,
|
387 |
langchain_action=client_langchain_action,
|
|
|
388 |
top_k_docs=top_k_docs,
|
389 |
chunk=chunk,
|
390 |
chunk_size=chunk_size,
|
@@ -566,6 +561,92 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
|
|
566 |
|
567 |
|
568 |
from langchain.chat_models import ChatOpenAI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
|
570 |
|
571 |
class H2OChatOpenAI(ChatOpenAI):
|
@@ -599,14 +680,26 @@ def get_llm(use_openai_model=False,
|
|
599 |
sanitize_bot_response=False,
|
600 |
verbose=False,
|
601 |
):
|
602 |
-
if use_openai_model or inference_server
|
603 |
if use_openai_model and model_name is None:
|
604 |
model_name = "gpt-3.5-turbo"
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
cls = H2OChatOpenAI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
610 |
callbacks = [StreamingGradioCallbackHandler()]
|
611 |
llm = cls(model_name=model_name,
|
612 |
temperature=temperature if do_sample else 0,
|
@@ -616,11 +709,18 @@ def get_llm(use_openai_model=False,
|
|
616 |
frequency_penalty=0,
|
617 |
presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
|
618 |
callbacks=callbacks if stream_output else None,
|
|
|
|
|
|
|
|
|
|
|
|
|
619 |
)
|
620 |
streamer = callbacks[0] if stream_output else None
|
621 |
if inference_server in ['openai', 'openai_chat']:
|
622 |
prompt_type = inference_server
|
623 |
else:
|
|
|
624 |
prompt_type = prompt_type or 'plain'
|
625 |
elif inference_server:
|
626 |
assert inference_server.startswith(
|
@@ -916,7 +1016,6 @@ def get_dai_docs(from_hf=False, get_pickle=True):
|
|
916 |
return sources
|
917 |
|
918 |
|
919 |
-
|
920 |
image_types = ["png", "jpg", "jpeg"]
|
921 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
922 |
"md",
|
@@ -927,7 +1026,8 @@ non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
|
927 |
]
|
928 |
# "msg", GPL3
|
929 |
|
930 |
-
if have_libreoffice:
|
|
|
931 |
non_image_types.extend(["docx", "doc", "xls", "xlsx"])
|
932 |
|
933 |
file_types = non_image_types + image_types
|
@@ -936,9 +1036,11 @@ file_types = non_image_types + image_types
|
|
936 |
def add_meta(docs1, file):
|
937 |
file_extension = pathlib.Path(file).suffix
|
938 |
hashid = hash_file(file)
|
|
|
939 |
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
940 |
docs1 = [docs1]
|
941 |
-
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for
|
|
|
942 |
|
943 |
|
944 |
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
@@ -1011,11 +1113,11 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
1011 |
add_meta(docs1, file)
|
1012 |
docs1 = clean_doc(docs1)
|
1013 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
|
1014 |
-
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
|
1015 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
1016 |
add_meta(docs1, file)
|
1017 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1018 |
-
elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and have_libreoffice:
|
1019 |
docs1 = UnstructuredExcelLoader(file_path=file).load()
|
1020 |
add_meta(docs1, file)
|
1021 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
@@ -1760,6 +1862,7 @@ def _run_qa_db(query=None,
|
|
1760 |
cut_distanct=1.1,
|
1761 |
sanitize_bot_response=False,
|
1762 |
show_rank=False,
|
|
|
1763 |
load_db_if_exists=False,
|
1764 |
db=None,
|
1765 |
do_sample=False,
|
@@ -1775,6 +1878,7 @@ def _run_qa_db(query=None,
|
|
1775 |
num_return_sequences=1,
|
1776 |
langchain_mode=None,
|
1777 |
langchain_action=None,
|
|
|
1778 |
document_subset=DocumentChoices.Relevant.name,
|
1779 |
document_choice=[],
|
1780 |
n_jobs=-1,
|
@@ -1857,20 +1961,21 @@ def _run_qa_db(query=None,
|
|
1857 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
1858 |
yield formatted_doc_chunks, ''
|
1859 |
return
|
1860 |
-
if not
|
1861 |
-
|
1862 |
-
|
1863 |
-
|
1864 |
-
|
1865 |
-
|
1866 |
-
|
1867 |
-
|
1868 |
-
|
1869 |
-
|
1870 |
-
|
1871 |
-
|
1872 |
-
|
1873 |
-
|
|
|
1874 |
|
1875 |
if chain is None and model_name not in non_hf_types:
|
1876 |
# here if no docs at all and not HF type
|
@@ -1948,6 +2053,7 @@ def get_chain(query=None,
|
|
1948 |
db=None,
|
1949 |
langchain_mode=None,
|
1950 |
langchain_action=None,
|
|
|
1951 |
document_subset=DocumentChoices.Relevant.name,
|
1952 |
document_choice=[],
|
1953 |
n_jobs=-1,
|
@@ -1961,6 +2067,7 @@ def get_chain(query=None,
|
|
1961 |
auto_reduce_chunks=True,
|
1962 |
max_chunks=100,
|
1963 |
):
|
|
|
1964 |
# determine whether use of context out of docs is planned
|
1965 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
1966 |
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
|
@@ -2092,8 +2199,8 @@ def get_chain(query=None,
|
|
2092 |
for result in zip(db_documents, db_metadatas)]
|
2093 |
|
2094 |
# order documents
|
2095 |
-
doc_hashes = [x
|
2096 |
-
doc_chunk_ids = [x
|
2097 |
docs_with_score = [x for _, _, x in
|
2098 |
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
|
2099 |
]
|
@@ -2302,6 +2409,7 @@ def clean_doc(docs1):
|
|
2302 |
|
2303 |
def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
|
2304 |
if not chunk:
|
|
|
2305 |
return sources
|
2306 |
if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
|
2307 |
# if just one document
|
@@ -2320,8 +2428,7 @@ def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
|
|
2320 |
source_chunks = splitter.split_documents(sources)
|
2321 |
|
2322 |
# currently in order, but when pull from db won't be, so mark order and document by hash
|
2323 |
-
|
2324 |
-
[x.metadata.update(dict(doc_hash=doc_hash, chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
|
2325 |
|
2326 |
return source_chunks
|
2327 |
|
|
|
21 |
from joblib import delayed
|
22 |
from langchain.callbacks import streaming_stdout
|
23 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
24 |
+
from langchain.schema import LLMResult
|
25 |
from tqdm import tqdm
|
26 |
|
27 |
from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
|
|
|
31 |
from prompter import non_hf_types, PromptType, Prompter
|
32 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
33 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
|
34 |
+
have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf, set_openai
|
35 |
from utils_langchain import StreamingGradioCallbackHandler
|
36 |
|
37 |
import_matplotlib()
|
|
|
277 |
|
278 |
from pydantic import Extra, Field, root_validator
|
279 |
|
280 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
from langchain.llms.base import LLM
|
282 |
|
283 |
|
|
|
349 |
gr_client = self.client
|
350 |
client_langchain_mode = 'Disabled'
|
351 |
client_langchain_action = LangChainAction.QUERY.value
|
352 |
+
client_langchain_agents = []
|
353 |
top_k_docs = 1
|
354 |
chunk = True
|
355 |
chunk_size = 512
|
|
|
379 |
iinput_nochat='', # only for chat=False
|
380 |
langchain_mode=client_langchain_mode,
|
381 |
langchain_action=client_langchain_action,
|
382 |
+
langchain_agents=client_langchain_agents,
|
383 |
top_k_docs=top_k_docs,
|
384 |
chunk=chunk,
|
385 |
chunk_size=chunk_size,
|
|
|
561 |
|
562 |
|
563 |
from langchain.chat_models import ChatOpenAI
|
564 |
+
from langchain.llms import OpenAI
|
565 |
+
from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
|
566 |
+
update_token_usage
|
567 |
+
|
568 |
+
|
569 |
+
class H2OOpenAI(OpenAI):
|
570 |
+
"""
|
571 |
+
New class to handle vLLM's use of OpenAI, no vllm_chat supported, so only need here
|
572 |
+
Handles prompting that OpenAI doesn't need, stopping as well
|
573 |
+
"""
|
574 |
+
stop_sequences: Any = None
|
575 |
+
sanitize_bot_response: bool = False
|
576 |
+
prompter: Any = None
|
577 |
+
tokenizer: Any = None
|
578 |
+
|
579 |
+
@classmethod
|
580 |
+
def all_required_field_names(cls) -> Set:
|
581 |
+
all_required_field_names = super(OpenAI, cls).all_required_field_names()
|
582 |
+
all_required_field_names.update(
|
583 |
+
{'top_p', 'frequency_penalty', 'presence_penalty', 'stop_sequences', 'sanitize_bot_response', 'prompter',
|
584 |
+
'tokenizer'})
|
585 |
+
return all_required_field_names
|
586 |
+
|
587 |
+
def _generate(
|
588 |
+
self,
|
589 |
+
prompts: List[str],
|
590 |
+
stop: Optional[List[str]] = None,
|
591 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
592 |
+
**kwargs: Any,
|
593 |
+
) -> LLMResult:
|
594 |
+
stop = self.stop_sequences if not stop else self.stop_sequences + stop
|
595 |
+
|
596 |
+
# HF inference server needs control over input tokens
|
597 |
+
assert self.tokenizer is not None
|
598 |
+
from h2oai_pipeline import H2OTextGenerationPipeline
|
599 |
+
for prompti, prompt in enumerate(prompts):
|
600 |
+
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
601 |
+
# NOTE: OpenAI/vLLM server does not add prompting, so must do here
|
602 |
+
data_point = dict(context='', instruction=prompt, input='')
|
603 |
+
prompt = self.prompter.generate_prompt(data_point)
|
604 |
+
prompts[prompti] = prompt
|
605 |
+
|
606 |
+
params = self._invocation_params
|
607 |
+
params = {**params, **kwargs}
|
608 |
+
sub_prompts = self.get_sub_prompts(params, prompts, stop)
|
609 |
+
choices = []
|
610 |
+
token_usage: Dict[str, int] = {}
|
611 |
+
# Get the token usage from the response.
|
612 |
+
# Includes prompt, completion, and total tokens used.
|
613 |
+
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
614 |
+
text = ''
|
615 |
+
for _prompts in sub_prompts:
|
616 |
+
if self.streaming:
|
617 |
+
text_with_prompt = ""
|
618 |
+
prompt = _prompts[0]
|
619 |
+
if len(_prompts) > 1:
|
620 |
+
raise ValueError("Cannot stream results with multiple prompts.")
|
621 |
+
params["stream"] = True
|
622 |
+
response = _streaming_response_template()
|
623 |
+
first = True
|
624 |
+
for stream_resp in completion_with_retry(
|
625 |
+
self, prompt=_prompts, **params
|
626 |
+
):
|
627 |
+
if first:
|
628 |
+
stream_resp["choices"][0]["text"] = prompt + stream_resp["choices"][0]["text"]
|
629 |
+
first = False
|
630 |
+
text_chunk = stream_resp["choices"][0]["text"]
|
631 |
+
text_with_prompt += text_chunk
|
632 |
+
text = self.prompter.get_response(text_with_prompt, prompt=prompt,
|
633 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
634 |
+
if run_manager:
|
635 |
+
run_manager.on_llm_new_token(
|
636 |
+
text_chunk,
|
637 |
+
verbose=self.verbose,
|
638 |
+
logprobs=stream_resp["choices"][0]["logprobs"],
|
639 |
+
)
|
640 |
+
_update_response(response, stream_resp)
|
641 |
+
choices.extend(response["choices"])
|
642 |
+
else:
|
643 |
+
response = completion_with_retry(self, prompt=_prompts, **params)
|
644 |
+
choices.extend(response["choices"])
|
645 |
+
if not self.streaming:
|
646 |
+
# Can't update token usage if streaming
|
647 |
+
update_token_usage(_keys, response, token_usage)
|
648 |
+
choices[0]['text'] = text
|
649 |
+
return self.create_llm_result(choices, prompts, token_usage)
|
650 |
|
651 |
|
652 |
class H2OChatOpenAI(ChatOpenAI):
|
|
|
680 |
sanitize_bot_response=False,
|
681 |
verbose=False,
|
682 |
):
|
683 |
+
if use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'):
|
684 |
if use_openai_model and model_name is None:
|
685 |
model_name = "gpt-3.5-turbo"
|
686 |
+
openai, inf_type = set_openai(
|
687 |
+
inference_server) # FIXME: Will later import be ignored? I think so, so should be fine
|
688 |
+
kwargs_extra = {}
|
689 |
+
if inference_server == 'openai_chat' or inf_type == 'vllm_chat':
|
690 |
cls = H2OChatOpenAI
|
691 |
+
else:
|
692 |
+
cls = H2OOpenAI
|
693 |
+
if inf_type == 'vllm':
|
694 |
+
terminate_response = prompter.terminate_response or []
|
695 |
+
stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
|
696 |
+
stop_sequences = [x for x in stop_sequences if x]
|
697 |
+
kwargs_extra = dict(stop_sequences=stop_sequences,
|
698 |
+
sanitize_bot_response=sanitize_bot_response,
|
699 |
+
prompter=prompter,
|
700 |
+
tokenizer=tokenizer,
|
701 |
+
client=None)
|
702 |
+
|
703 |
callbacks = [StreamingGradioCallbackHandler()]
|
704 |
llm = cls(model_name=model_name,
|
705 |
temperature=temperature if do_sample else 0,
|
|
|
709 |
frequency_penalty=0,
|
710 |
presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
|
711 |
callbacks=callbacks if stream_output else None,
|
712 |
+
openai_api_key=openai.api_key,
|
713 |
+
openai_api_base=openai.api_base,
|
714 |
+
logit_bias=None if inf_type =='vllm' else {},
|
715 |
+
max_retries=2,
|
716 |
+
streaming=stream_output,
|
717 |
+
**kwargs_extra
|
718 |
)
|
719 |
streamer = callbacks[0] if stream_output else None
|
720 |
if inference_server in ['openai', 'openai_chat']:
|
721 |
prompt_type = inference_server
|
722 |
else:
|
723 |
+
# vllm goes here
|
724 |
prompt_type = prompt_type or 'plain'
|
725 |
elif inference_server:
|
726 |
assert inference_server.startswith(
|
|
|
1016 |
return sources
|
1017 |
|
1018 |
|
|
|
1019 |
image_types = ["png", "jpg", "jpeg"]
|
1020 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
1021 |
"md",
|
|
|
1026 |
]
|
1027 |
# "msg", GPL3
|
1028 |
|
1029 |
+
if have_libreoffice or True:
|
1030 |
+
# or True so it tries to load, e.g. on MAC/Windows, even if don't have libreoffice since works without that
|
1031 |
non_image_types.extend(["docx", "doc", "xls", "xlsx"])
|
1032 |
|
1033 |
file_types = non_image_types + image_types
|
|
|
1036 |
def add_meta(docs1, file):
|
1037 |
file_extension = pathlib.Path(file).suffix
|
1038 |
hashid = hash_file(file)
|
1039 |
+
doc_hash = str(uuid.uuid4())[:10]
|
1040 |
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
1041 |
docs1 = [docs1]
|
1042 |
+
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid, doc_hash=doc_hash)) for
|
1043 |
+
x in docs1]
|
1044 |
|
1045 |
|
1046 |
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
|
1113 |
add_meta(docs1, file)
|
1114 |
docs1 = clean_doc(docs1)
|
1115 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
|
1116 |
+
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and (have_libreoffice or True):
|
1117 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
1118 |
add_meta(docs1, file)
|
1119 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1120 |
+
elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and (have_libreoffice or True):
|
1121 |
docs1 = UnstructuredExcelLoader(file_path=file).load()
|
1122 |
add_meta(docs1, file)
|
1123 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
|
|
1862 |
cut_distanct=1.1,
|
1863 |
sanitize_bot_response=False,
|
1864 |
show_rank=False,
|
1865 |
+
use_llm_if_no_docs=False,
|
1866 |
load_db_if_exists=False,
|
1867 |
db=None,
|
1868 |
do_sample=False,
|
|
|
1878 |
num_return_sequences=1,
|
1879 |
langchain_mode=None,
|
1880 |
langchain_action=None,
|
1881 |
+
langchain_agents=None,
|
1882 |
document_subset=DocumentChoices.Relevant.name,
|
1883 |
document_choice=[],
|
1884 |
n_jobs=-1,
|
|
|
1961 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
1962 |
yield formatted_doc_chunks, ''
|
1963 |
return
|
1964 |
+
if not use_llm_if_no_docs:
|
1965 |
+
if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
1966 |
+
LangChainAction.SUMMARIZE_ALL.value,
|
1967 |
+
LangChainAction.SUMMARIZE_REFINE.value]:
|
1968 |
+
ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.'
|
1969 |
+
extra = ''
|
1970 |
+
yield ret, extra
|
1971 |
+
return
|
1972 |
+
if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
|
1973 |
+
LangChainMode.CHAT_LLM.value,
|
1974 |
+
LangChainMode.LLM.value]:
|
1975 |
+
ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
|
1976 |
+
extra = ''
|
1977 |
+
yield ret, extra
|
1978 |
+
return
|
1979 |
|
1980 |
if chain is None and model_name not in non_hf_types:
|
1981 |
# here if no docs at all and not HF type
|
|
|
2053 |
db=None,
|
2054 |
langchain_mode=None,
|
2055 |
langchain_action=None,
|
2056 |
+
langchain_agents=None,
|
2057 |
document_subset=DocumentChoices.Relevant.name,
|
2058 |
document_choice=[],
|
2059 |
n_jobs=-1,
|
|
|
2067 |
auto_reduce_chunks=True,
|
2068 |
max_chunks=100,
|
2069 |
):
|
2070 |
+
assert langchain_agents is not None # should be at least []
|
2071 |
# determine whether use of context out of docs is planned
|
2072 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
2073 |
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
|
|
|
2199 |
for result in zip(db_documents, db_metadatas)]
|
2200 |
|
2201 |
# order documents
|
2202 |
+
doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
|
2203 |
+
doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
|
2204 |
docs_with_score = [x for _, _, x in
|
2205 |
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
|
2206 |
]
|
|
|
2409 |
|
2410 |
def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
|
2411 |
if not chunk:
|
2412 |
+
[x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(sources)]
|
2413 |
return sources
|
2414 |
if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
|
2415 |
# if just one document
|
|
|
2428 |
source_chunks = splitter.split_documents(sources)
|
2429 |
|
2430 |
# currently in order, but when pull from db won't be, so mark order and document by hash
|
2431 |
+
[x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
|
|
|
2432 |
|
2433 |
return source_chunks
|
2434 |
|
gradio_runner.py
CHANGED
@@ -58,7 +58,7 @@ from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt
|
|
58 |
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
59 |
ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip
|
60 |
from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
|
61 |
-
get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
|
62 |
from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
|
63 |
|
64 |
from apscheduler.schedulers.background import BackgroundScheduler
|
@@ -101,6 +101,7 @@ def go_gradio(**kwargs):
|
|
101 |
db_type = kwargs['db_type']
|
102 |
visible_langchain_modes = kwargs['visible_langchain_modes']
|
103 |
visible_langchain_actions = kwargs['visible_langchain_actions']
|
|
|
104 |
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
|
105 |
allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
|
106 |
enable_sources_list = kwargs['enable_sources_list']
|
@@ -361,6 +362,14 @@ def go_gradio(**kwargs):
|
|
361 |
value=allowed_actions[0] if len(allowed_actions) > 0 else None,
|
362 |
label="Action",
|
363 |
visible=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
col_tabs = gr.Column(elem_id="col_container", scale=10)
|
365 |
with (col_tabs, gr.Tabs()):
|
366 |
with gr.TabItem("Chat"):
|
@@ -469,6 +478,7 @@ def go_gradio(**kwargs):
|
|
469 |
value=None,
|
470 |
interactive=True,
|
471 |
multiselect=False,
|
|
|
472 |
)
|
473 |
with gr.Column(scale=4):
|
474 |
pass
|
@@ -1035,6 +1045,8 @@ def go_gradio(**kwargs):
|
|
1035 |
user_kwargs['langchain_mode'] = 'Disabled'
|
1036 |
if 'langchain_action' not in user_kwargs:
|
1037 |
user_kwargs['langchain_action'] = LangChainAction.QUERY.value
|
|
|
|
|
1038 |
|
1039 |
set1 = set(list(default_kwargs1.keys()))
|
1040 |
set2 = set(eval_func_param_names)
|
@@ -1216,6 +1228,7 @@ def go_gradio(**kwargs):
|
|
1216 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1217 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1218 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
|
|
1219 |
document_subset1 = args_list[eval_func_param_names.index('document_subset')]
|
1220 |
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1221 |
if not prompt_type1:
|
@@ -1312,6 +1325,7 @@ def go_gradio(**kwargs):
|
|
1312 |
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
1313 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1314 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
|
|
1315 |
document_subset1 = args_list[eval_func_param_names.index('document_subset')]
|
1316 |
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1317 |
if not history:
|
|
|
58 |
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
59 |
ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip
|
60 |
from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
|
61 |
+
get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list
|
62 |
from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
|
63 |
|
64 |
from apscheduler.schedulers.background import BackgroundScheduler
|
|
|
101 |
db_type = kwargs['db_type']
|
102 |
visible_langchain_modes = kwargs['visible_langchain_modes']
|
103 |
visible_langchain_actions = kwargs['visible_langchain_actions']
|
104 |
+
visible_langchain_agents = kwargs['visible_langchain_agents']
|
105 |
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
|
106 |
allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
|
107 |
enable_sources_list = kwargs['enable_sources_list']
|
|
|
362 |
value=allowed_actions[0] if len(allowed_actions) > 0 else None,
|
363 |
label="Action",
|
364 |
visible=True)
|
365 |
+
allowed_agents = [x for x in langchain_agents_list if x in visible_langchain_agents]
|
366 |
+
langchain_agents = gr.Dropdown(
|
367 |
+
langchain_agents_list,
|
368 |
+
value=kwargs['langchain_agents'],
|
369 |
+
label="Agents",
|
370 |
+
multiselect=True,
|
371 |
+
interactive=True,
|
372 |
+
visible=False) # WIP
|
373 |
col_tabs = gr.Column(elem_id="col_container", scale=10)
|
374 |
with (col_tabs, gr.Tabs()):
|
375 |
with gr.TabItem("Chat"):
|
|
|
478 |
value=None,
|
479 |
interactive=True,
|
480 |
multiselect=False,
|
481 |
+
visible=True,
|
482 |
)
|
483 |
with gr.Column(scale=4):
|
484 |
pass
|
|
|
1045 |
user_kwargs['langchain_mode'] = 'Disabled'
|
1046 |
if 'langchain_action' not in user_kwargs:
|
1047 |
user_kwargs['langchain_action'] = LangChainAction.QUERY.value
|
1048 |
+
if 'langchain_agents' not in user_kwargs:
|
1049 |
+
user_kwargs['langchain_agents'] = []
|
1050 |
|
1051 |
set1 = set(list(default_kwargs1.keys()))
|
1052 |
set2 = set(eval_func_param_names)
|
|
|
1228 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1229 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1230 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1231 |
+
langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
|
1232 |
document_subset1 = args_list[eval_func_param_names.index('document_subset')]
|
1233 |
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1234 |
if not prompt_type1:
|
|
|
1325 |
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
1326 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1327 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1328 |
+
langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
|
1329 |
document_subset1 = args_list[eval_func_param_names.index('document_subset')]
|
1330 |
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1331 |
if not history:
|
prompter.py
CHANGED
@@ -582,6 +582,20 @@ ASSISTANT:
|
|
582 |
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
583 |
PreResponse = PreResponse
|
584 |
# generates_leading_space = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
585 |
else:
|
586 |
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
587 |
|
@@ -810,9 +824,20 @@ class Prompter(object):
|
|
810 |
if oi > 0:
|
811 |
# post fix outputs with seperator
|
812 |
output += '\n'
|
|
|
813 |
outputs[oi] = output
|
814 |
# join all outputs, only one extra new line between outputs
|
815 |
output = '\n'.join(outputs)
|
816 |
if self.debug:
|
817 |
print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
|
818 |
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
582 |
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
583 |
PreResponse = PreResponse
|
584 |
# generates_leading_space = True
|
585 |
+
elif prompt_type in [PromptType.guanaco.value, str(PromptType.guanaco.value),
|
586 |
+
PromptType.guanaco.name]:
|
587 |
+
# https://huggingface.co/TheBloke/guanaco-65B-GPTQ
|
588 |
+
promptA = promptB = "" if not (chat and reduced) else ''
|
589 |
+
|
590 |
+
PreInstruct = """### Human: """
|
591 |
+
|
592 |
+
PreInput = None
|
593 |
+
|
594 |
+
PreResponse = """### Assistant:"""
|
595 |
+
terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
596 |
+
chat_turn_sep = chat_sep = '\n'
|
597 |
+
humanstr = PreInstruct
|
598 |
+
botstr = PreResponse
|
599 |
else:
|
600 |
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
601 |
|
|
|
824 |
if oi > 0:
|
825 |
# post fix outputs with seperator
|
826 |
output += '\n'
|
827 |
+
output = self.fix_text(self.prompt_type, output)
|
828 |
outputs[oi] = output
|
829 |
# join all outputs, only one extra new line between outputs
|
830 |
output = '\n'.join(outputs)
|
831 |
if self.debug:
|
832 |
print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
|
833 |
return output
|
834 |
+
|
835 |
+
@staticmethod
|
836 |
+
def fix_text(prompt_type1, text1):
|
837 |
+
if prompt_type1 == 'human_bot':
|
838 |
+
# hack bug in vLLM with stopping, stops right, but doesn't return last token
|
839 |
+
hfix = '<human'
|
840 |
+
if text1.endswith(hfix):
|
841 |
+
text1 = text1[:-len(hfix)]
|
842 |
+
return text1
|
843 |
+
|
utils.py
CHANGED
@@ -950,7 +950,6 @@ try:
|
|
950 |
except (pkg_resources.DistributionNotFound, AssertionError):
|
951 |
have_langchain = False
|
952 |
|
953 |
-
|
954 |
import distutils.spawn
|
955 |
|
956 |
have_tesseract = distutils.spawn.find_executable("tesseract")
|
@@ -985,3 +984,20 @@ except (pkg_resources.DistributionNotFound, AssertionError):
|
|
985 |
|
986 |
# disable, hangs too often
|
987 |
have_playwright = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
950 |
except (pkg_resources.DistributionNotFound, AssertionError):
|
951 |
have_langchain = False
|
952 |
|
|
|
953 |
import distutils.spawn
|
954 |
|
955 |
have_tesseract = distutils.spawn.find_executable("tesseract")
|
|
|
984 |
|
985 |
# disable, hangs too often
|
986 |
have_playwright = False
|
987 |
+
|
988 |
+
|
989 |
+
def set_openai(inference_server):
|
990 |
+
if inference_server.startswith('vllm'):
|
991 |
+
import openai_vllm
|
992 |
+
openai_vllm.api_key = "EMPTY"
|
993 |
+
inf_type = inference_server.split(':')[0]
|
994 |
+
ip_vllm = inference_server.split(':')[1]
|
995 |
+
port_vllm = inference_server.split(':')[2]
|
996 |
+
openai_vllm.api_base = f"http://{ip_vllm}:{port_vllm}/v1"
|
997 |
+
return openai_vllm, inf_type
|
998 |
+
else:
|
999 |
+
import openai
|
1000 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
1001 |
+
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
|
1002 |
+
inf_type = inference_server
|
1003 |
+
return openai, inf_type
|