pseudotensor committed
Commit 454e203
1 parent: eac73aa
Update with h2oGPT hash ad9d685b188cece0b9c69716ea8e320b74f0caf7
Browse files
- client_test.py +26 -10
- enums.py +24 -6
- evaluate_params.py +5 -0
- gen.py +186 -58
- gpt4all_llm.py +18 -8
- gpt_langchain.py +314 -145
- gradio_runner.py +470 -178
- gradio_utils/__init__.py +0 -0
- gradio_utils/__pycache__/__init__.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/css.cpython-310.pyc +0 -0
- gradio_utils/css.py +4 -0
- h2oai_pipeline.py +4 -1
- iterators/__pycache__/timeout_iterator.cpython-310.pyc +0 -0
- iterators/timeout_iterator.py +1 -1
- prompter.py +53 -0
- requirements.txt +16 -16
- utils.py +100 -7
client_test.py
CHANGED
@@ -48,7 +48,7 @@ import markdown  # pip install markdown
 import pytest
 from bs4 import BeautifulSoup  # pip install beautifulsoup4

-from enums import
+from enums import DocumentSubset, LangChainAction

 debug = False

@@ -68,7 +68,9 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
 max_new_tokens=50,
 top_k_docs=3,
 langchain_mode='Disabled',
+add_chat_history_to_context=True,
 langchain_action=LangChainAction.QUERY.value,
+langchain_agents=[],
 prompt_dict=None):
 from collections import OrderedDict
 kwargs = OrderedDict(instruction=prompt if chat else '',  # only for chat=True
@@ -94,11 +96,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
 instruction_nochat=prompt if not chat else '',
 iinput_nochat='',  # only for chat=False
 langchain_mode=langchain_mode,
+add_chat_history_to_context=add_chat_history_to_context,
 langchain_action=langchain_action,
+langchain_agents=langchain_agents,
 top_k_docs=top_k_docs,
 chunk=True,
 chunk_size=512,
-document_subset=
+document_subset=DocumentSubset.Relevant.name,
 document_choice=[],
 )
 from evaluate_params import eval_func_param_names
@@ -202,9 +206,11 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
 instruction_nochat=prompt,
 iinput_nochat='',
 langchain_mode='Disabled',
+add_chat_history_to_context=True,
 langchain_action=LangChainAction.QUERY.value,
+langchain_agents=[],
 top_k_docs=4,
-document_subset=
+document_subset=DocumentSubset.Relevant.name,
 document_choice=[],
 )

@@ -225,23 +231,30 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
 @pytest.mark.skip(reason="For manual use against some server, no server launched")
 def test_client_chat(prompt_type='human_bot'):
 return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
-langchain_mode='Disabled',
+langchain_mode='Disabled',
+langchain_action=LangChainAction.QUERY.value,
+langchain_agents=[])


 @pytest.mark.skip(reason="For manual use against some server, no server launched")
 def test_client_chat_stream(prompt_type='human_bot'):
 return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
 stream_output=True, max_new_tokens=512,
-langchain_mode='Disabled',
+langchain_mode='Disabled',
+langchain_action=LangChainAction.QUERY.value,
+langchain_agents=[])


-def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens,
+def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens,
+langchain_mode, langchain_action, langchain_agents,
 prompt_dict=None):
 client = get_client(serialize=False)

 kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
-max_new_tokens=max_new_tokens,
+max_new_tokens=max_new_tokens,
+langchain_mode=langchain_mode,
 langchain_action=langchain_action,
+langchain_agents=langchain_agents,
 prompt_dict=prompt_dict)
 return run_client(client, prompt, args, kwargs)

@@ -285,15 +298,18 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
 def test_client_nochat_stream(prompt_type='human_bot'):
 return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
 stream_output=True, max_new_tokens=512,
-langchain_mode='Disabled',
+langchain_mode='Disabled',
+langchain_action=LangChainAction.QUERY.value,
+langchain_agents=[])


-def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
+def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
+langchain_mode, langchain_action, langchain_agents):
 client = get_client(serialize=False)

 kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
 max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
-langchain_action=langchain_action)
+langchain_action=langchain_action, langchain_agents=langchain_agents)
 return run_client_gen(client, prompt, args, kwargs)
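For orientation, a minimal sketch of what a caller looks like against the updated client API, mirroring the new client_test.py defaults; the server URL and prompt are placeholders, and the response parsing assumes the usual str(dict) payload of /submit_nochat_api shown later in gen.py:

# Illustrative client call (not part of the commit); URL and prompt are placeholders.
import ast
from gradio_client import Client
from enums import DocumentSubset, LangChainAction

client = Client("http://localhost:7860")
kwargs = dict(instruction_nochat="Who are you?",
              iinput_nochat='',
              langchain_mode='Disabled',
              add_chat_history_to_context=True,              # new in this commit
              langchain_action=LangChainAction.QUERY.value,
              langchain_agents=[],                           # new in this commit
              top_k_docs=4,
              document_subset=DocumentSubset.Relevant.name,
              document_choice=[])
# /submit_nochat_api takes the whole kwargs dict serialized as a string
res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
print(ast.literal_eval(res)['response'])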
enums.py
CHANGED
@@ -31,25 +31,30 @@ class PromptType(Enum):
 mptinstruct = 25
 mptchat = 26
 falcon = 27
+guanaco = 28
+llama2 = 29


-class
+class DocumentSubset(Enum):
 Relevant = 0
-
-
+RelSources = 1
+TopKSources = 2


 non_query_commands = [
-
-
+DocumentSubset.RelSources.name,
+DocumentSubset.TopKSources.name
 ]


+class DocumentChoice(Enum):
+ALL = 'All'
+
+
 class LangChainMode(Enum):
 """LangChain mode"""

 DISABLED = "Disabled"
-CHAT_LLM = "ChatLLM"
 LLM = "LLM"
 ALL = "All"
 WIKI = "wiki"
@@ -60,6 +65,12 @@ class LangChainMode(Enum):
 H2O_DAI_DOCS = "DriverlessAI docs"


+# modes should not be removed from visible list or added by name
+langchain_modes_intrinsic = [LangChainMode.DISABLED.value,
+LangChainMode.LLM.value,
+LangChainMode.MY_DATA.value]
+
+
 class LangChainAction(Enum):
 """LangChain action"""

@@ -71,6 +82,13 @@ class LangChainAction(Enum):
 SUMMARIZE_REFINE = "Summarize_refine"


+class LangChainAgent(Enum):
+"""LangChain agents"""
+
+SEARCH = "Search"
+# CSV = "csv"  # WIP
+
+
 no_server_str = no_lora_str = no_model_str = '[None/Remove]'

 # from site-packages/langchain/llms/openai.py
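A short sketch, separate from the commit, of how the new enums travel through the API: subsets are passed by .name, document choices by .value, and requested agents are validated against the LangChainAgent values exactly as gen.py now does:

# Illustrative only; mirrors how gen.py consumes the new enums.
from enums import DocumentSubset, DocumentChoice, LangChainAgent, non_query_commands

langchain_agents_list = [x.value for x in list(LangChainAgent)]  # ['Search']
payload = dict(document_subset=DocumentSubset.Relevant.name,     # 'Relevant'
               document_choice=[DocumentChoice.ALL.value],       # ['All']
               langchain_agents=[])                              # must be a subset of langchain_agents_list
assert len(set(payload['langchain_agents']).difference(langchain_agents_list)) == 0
print(non_query_commands)  # ['RelSources', 'TopKSources'] act as commands, not queries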
evaluate_params.py
CHANGED
@@ -1,3 +1,6 @@
+input_args_list = ['model_state', 'my_db_state', 'selection_docs_state']
+
+
 no_default_param_names = [
 'instruction',
 'iinput',
@@ -30,7 +33,9 @@ eval_func_param_names = ['instruction',
 'instruction_nochat',
 'iinput_nochat',
 'langchain_mode',
+'add_chat_history_to_context',
 'langchain_action',
+'langchain_agents',
 'top_k_docs',
 'chunk',
 'chunk_size',
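The ordering above matters because the API passes evaluate()'s inputs positionally; a small illustrative helper (not repo code) showing how eval_func_param_names and the new input_args_list recover keyword arguments from that positional tuple:

# Illustrative only: map positional API args back to names.
from evaluate_params import eval_func_param_names, input_args_list

def args_to_kwargs(args):
    assert len(args) == len(eval_func_param_names)
    return dict(zip(eval_func_param_names, args))

# model_state, my_db_state and the new selection_docs_state ride along separately,
# in the order given by input_args_list.
print(input_args_list)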
gen.py
CHANGED
@@ -8,7 +8,6 @@ import sys
 import os
 import time
 import traceback
-import types
 import typing
 import warnings
 from datetime import datetime
@@ -28,12 +27,12 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

 from evaluate_params import eval_func_param_names, no_default_param_names
-from enums import
-source_postfix, LangChainAction
+from enums import DocumentSubset, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
+source_postfix, LangChainAction, LangChainAgent, DocumentChoice
 from loaders import get_loaders
 from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
 import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
-have_langchain
+have_langchain, set_openai, load_collection_enum

 start_faulthandler()
 import_matplotlib()
@@ -50,10 +49,10 @@ from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
 from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt
 from stopping import get_stopping

-langchain_modes = [x.value for x in list(LangChainMode)]
-
 langchain_actions = [x.value for x in list(LangChainAction)]

+langchain_agents_list = [x.value for x in list(LangChainAgent)]
+
 scratch_base_dir = '/tmp/'
@@ -114,6 +113,7 @@ def main(
 show_examples: bool = None,
 verbose: bool = False,
 h2ocolors: bool = True,
+dark: bool = False,  # light tends to be best
 height: int = 600,
 show_lora: bool = True,
 login_mode_if_model0: bool = False,
@@ -134,7 +134,7 @@ def main(
 extra_lora_options: typing.List[str] = [],
 extra_server_options: typing.List[str] = [],

-score_model: str = '
+score_model: str = 'auto',

 eval_filename: str = None,
 eval_prompts_only_num: int = 0,
@@ -143,22 +143,30 @@ def main(

 langchain_mode: str = None,
 langchain_action: str = LangChainAction.QUERY.value,
+langchain_agents: list = [],
 force_langchain_evaluate: bool = False,
+langchain_modes: list = [x.value for x in list(LangChainMode)],
 visible_langchain_modes: list = ['UserData', 'MyData'],
 # WIP:
 # visible_langchain_actions: list = langchain_actions.copy(),
 visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
-
-
+visible_langchain_agents: list = langchain_agents_list.copy(),
+document_subset: str = DocumentSubset.Relevant.name,
+document_choice: list = [DocumentChoice.ALL.value],
 user_path: str = None,
+langchain_mode_paths: dict = {'UserData': None},
 detect_user_path_changes_every_query: bool = False,
+use_llm_if_no_docs: bool = False,
 load_db_if_exists: bool = True,
 keep_sources_in_context: bool = False,
 db_type: str = 'chroma',
 use_openai_embedding: bool = False,
 use_openai_model: bool = False,
 hf_embedding_model: str = None,
+cut_distance: float = 1.64,
+add_chat_history_to_context: bool = True,
 allow_upload_to_user_data: bool = True,
+reload_langchain_state: bool = True,
 allow_upload_to_my_data: bool = True,
 enable_url_upload: bool = True,
 enable_text_upload: bool = True,
@@ -175,6 +183,7 @@ def main(
 pre_load_caption_model: bool = False,
 caption_gpu: bool = True,
 enable_ocr: bool = False,
+enable_pdf_ocr: str = 'auto',
 ):
 """

@@ -196,6 +205,8 @@ def main(
 Or Address can be "openai_chat" or "openai" for OpenAI API
 e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo
 e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
+Or Address can be "vllm:IP:port" or "vllm:IP:port" for OpenAI-compliant vLLM endpoint
+Note: vllm_chat not supported by vLLM project.
 :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
 :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
 :param model_lock: Lock models to specific combinations, for ease of use and extending to many models
@@ -252,6 +263,7 @@ def main(
 :param show_examples: whether to show clickable examples in gradio
 :param verbose: whether to show verbose prints
 :param h2ocolors: whether to use H2O.ai theme
+:param dark: whether to use dark mode for UI by default (still controlled in UI)
 :param height: height of chat window
 :param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
 :param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
@@ -271,49 +283,73 @@ def main(
 :param extra_model_options: extra models to show in list in gradio
 :param extra_lora_options: extra LORA to show in list in gradio
 :param extra_server_options: extra servers to show in list in gradio
-:param score_model: which model to score responses
+:param score_model: which model to score responses
+None: no response scoring
+'auto': auto mode, '' (no model) for CPU, 'OpenAssistant/reward-model-deberta-v3-large-v2' for GPU,
+because on CPU takes too much compute just for scoring response
 :param eval_filename: json file to use for evaluation, if None is sharegpt
 :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
 :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
 :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
 :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
+None: auto mode, check if langchain package exists, at least do LLM if so, else Disabled
 WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
 :param langchain_action: Mode langchain operations in on documents.
 Query: Make query of document(s)
 Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce
 Summarize_all: Summarize document(s) using entire document at once
 Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary
+:param langchain_agents: Which agents to use
+'search': Use Web Search as context for LLM response, e.g. SERP if have SERPAPI_API_KEY in env
 :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
 :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
 If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
+:param langchain_mode_paths: dict of langchain_mode keys and disk path values to use for source of documents
+E.g. "{'UserData2': 'userpath2'}"
+Can be None even if existing DB, to avoid new documents being added from that path, source links that are on disk still work.
+If user_path is not None, that path is used for 'UserData' instead of the value in this dict
 :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
 Expensive for large number of files, so not done by default. By default only detect changes during db loading.
+:param langchain_modes: names of collections/dbs to potentially have
 :param visible_langchain_modes: dbs to generate at launch to be ready for LLM
 Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
 But wiki_full is expensive and requires preparation
 To allow scratch space only live in session, add 'MyData' to list
 Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
-
+If have own user modes, need to add these here or add in UI.
+A state file is stored in visible_langchain_modes.pkl containing last UI-selected values of:
+langchain_modes, visible_langchain_modes, and langchain_mode_paths
+Delete the file if you want to start fresh,
+but in any case the user_path passed in CLI is used for UserData even if was None or different
 :param visible_langchain_actions: Which actions to allow
+:param visible_langchain_agents: Which agents to allow
 :param document_subset: Default document choice when taking subset of collection
-:param document_choice: Chosen document(s) by internal name
+:param document_choice: Chosen document(s) by internal name, 'All' means use all docs
+:param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData or custom
 :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
 :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
 :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
 :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
 :param use_openai_model: Whether to use OpenAI model for use with vector db
 :param hf_embedding_model: Which HF embedding model to use for vector db
-Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-
+Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-v2 if no GPUs
 Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
 Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
 We support automatically changing of embeddings for chroma, with a backup of db made if this is done
-:param
+:param cut_distance: Distance to cut off references with larger distances when showing references.
+1.64 is good to avoid dropping references for all-MiniLM-L6-v2, but instructor-large will always show excessive references.
+For all-MiniLM-L6-v2, a value of 1.5 can push out even more references, or a large value of 100 can avoid any loss of references.
+:param add_chat_history_to_context: Include chat context when performing action
+Not supported yet for openai_chat when using document collection instead of LLM
+Also not supported when using CLI mode
+:param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db (UserData or custom user dbs)
+:param reload_langchain_state: Whether to reload visible_langchain_modes.pkl file that contains any new user collections.
 :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
 :param enable_url_upload: Whether to allow upload from URL
 :param enable_text_upload: Whether to allow upload of text
 :param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
 :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
-:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so
+:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to be in context length
 :param top_k_docs: number of chunks to give LLM
 :param reverse_docs: whether to reverse docs order so most relevant is closest to question.
 Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too.
@@ -327,11 +363,15 @@ def main(
 captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
 captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
 Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
+Disabled for CPU since BLIP requires CUDA
 :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
 parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
 Recommended if using larger caption model
 :param caption_gpu: If support caption, then use GPU if exists
 :param enable_ocr: Whether to support OCR on images
+:param enable_pdf_ocr: 'auto' means only use OCR if normal text extraction fails. Useful for pure image-based PDFs with text
+'on' means always do OCR as additional parsing of same documents
+'off' means don't do OCR (e.g. because it's slow even if 'auto' only would trigger if nothing else worked)
 :return:
 """
 if base_model is None:
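The docstring above introduces score_model='auto'; a compact, illustrative restatement of the resolution rule that the CPU/GPU branches added further down implement:

# Illustrative restatement of the new score_model='auto' behavior (not repo code).
def resolve_score_model(score_model, n_gpus):
    if str(score_model) == 'None':
        return ''  # scoring explicitly disabled
    if score_model == 'auto':
        # no scoring model on CPU (too costly), reward model on GPU
        return '' if n_gpus == 0 else 'OpenAssistant/reward-model-deberta-v3-large-v2'
    return score_model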
@@ -393,7 +433,29 @@ def main(
 if langchain_mode is not None:
 visible_langchain_modes += [langchain_mode]

+# update
+if isinstance(langchain_mode_paths, str):
+langchain_mode_paths = ast.literal_eval(langchain_mode_paths)
+assert isinstance(langchain_mode_paths, dict)
+if user_path:
+langchain_mode_paths['UserData'] = user_path
+makedirs(user_path)
+
+if is_public:
+allow_upload_to_user_data = False
+if LangChainMode.USER_DATA.value in visible_langchain_modes:
+visible_langchain_modes.remove(LangChainMode.USER_DATA.value)
+
+# in-place, for non-scratch dbs
+if allow_upload_to_user_data:
+update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, '')
+# always listen to CLI-passed user_path if passed
+if user_path:
+langchain_mode_paths['UserData'] = user_path
+
 assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
+assert len(
+set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents

 # if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler
 if LangChainMode.MY_DATA.value not in visible_langchain_modes:
@@ -404,21 +466,22 @@ def main(
 # auto-set langchain_mode
 if have_langchain and langchain_mode is None:
 # start in chat mode, in case just want to chat and don't want to get "No documents to query" by default.
-langchain_mode = LangChainMode.
+langchain_mode = LangChainMode.LLM.value
-if allow_upload_to_user_data and not is_public and
+if allow_upload_to_user_data and not is_public and langchain_mode_paths['UserData']:
 print("Auto set langchain_mode=%s. Could use UserData instead." % langchain_mode, flush=True)
 elif allow_upload_to_my_data:
 print("Auto set langchain_mode=%s. Could use MyData instead."
 " To allow UserData to pull files from disk,"
-" set user_path and ensure allow_upload_to_user_data=True" % langchain_mode,
+" set user_path or langchain_mode_paths, and ensure allow_upload_to_user_data=True" % langchain_mode,
+flush=True)
 else:
 raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
-if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value
+if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value]:
 raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
 if langchain_mode is None:
 # if not set yet, disable
 langchain_mode = LangChainMode.DISABLED.value
-print("Auto set langchain_mode=%s" % langchain_mode, flush=True)
+print("Auto set langchain_mode=%s Have langchain package: %s" % (langchain_mode, have_langchain), flush=True)

 if is_public:
 allow_upload_to_user_data = False
@@ -474,7 +537,7 @@ def main(
 # HF accounted for later in get_max_max_new_tokens()
 save_dir = os.getenv('SAVE_DIR', save_dir)
 score_model = os.getenv('SCORE_MODEL', score_model)
-if score_model == 'None'
+if str(score_model) == 'None':
 score_model = ''
 concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
 api_open = bool(int(os.getenv('API_OPEN', str(int(api_open)))))
@@ -482,6 +545,7 @@ def main(

 n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
 if n_gpus == 0:
+enable_captions = False
 gpu_id = None
 load_8bit = False
 load_4bit = False
@@ -499,7 +563,11 @@ def main(
 if hf_embedding_model is None:
 # if no GPUs, use simpler embedding model to avoid cost in time
 hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
+if score_model == 'auto':
+score_model = ''
 else:
+if score_model == 'auto':
+score_model = 'OpenAssistant/reward-model-deberta-v3-large-v2'
 if hf_embedding_model is None:
 # if still None, then set default
 hf_embedding_model = 'hkunlp/instructor-large'
@@ -524,8 +592,6 @@ def main(

 if offload_folder:
 makedirs(offload_folder)
-if user_path:
-makedirs(user_path)

 placeholder_instruction, placeholder_input, \
 stream_output, show_examples, \
@@ -551,7 +617,7 @@ def main(
 verbose,
 )

-git_hash = get_githash()
+git_hash = get_githash() if is_public or os.getenv('GET_GITHASH') else "GET_GITHASH"
 locals_dict = locals()
 locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
 if verbose:
@@ -565,7 +631,7 @@ def main(
 get_some_dbs_from_hf()
 dbs = {}
 for langchain_mode1 in visible_langchain_modes:
-if langchain_mode1 in ['MyData']:
+if langchain_mode1 in ['MyData']:  # FIXME: Remove other custom temp dbs
 # don't use what is on disk, remove it instead
 for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
 if os.path.isdir(gpath1):
@@ -580,7 +646,7 @@ def main(
 db = prep_langchain(persist_directory1,
 load_db_if_exists,
 db_type, use_openai_embedding,
-langchain_mode1,
+langchain_mode1, langchain_mode_paths,
 hf_embedding_model,
 kwargs_make_db=locals())
 finally:
@@ -599,6 +665,14 @@ def main(
 model_state_none = dict(model=None, tokenizer=None, device=None,
 base_model=None, tokenizer_base_model=None, lora_weights=None,
 inference_server=None, prompt_type=None, prompt_dict=None)
+my_db_state0 = {LangChainMode.MY_DATA.value: [None, None]}
+selection_docs_state0 = dict(visible_langchain_modes=visible_langchain_modes,
+langchain_mode_paths=langchain_mode_paths,
+langchain_modes=langchain_modes)
+selection_docs_state = selection_docs_state0
+langchain_modes0 = langchain_modes
+langchain_mode_paths0 = langchain_mode_paths
+visible_langchain_modes0 = visible_langchain_modes

 if cli:
 from cli import run_cli
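A sketch of how the new --langchain_mode_paths option is consumed, following the block added to main() above; the dict literal and paths are placeholders taken from the docstring example:

# Illustrative only; mirrors the handling added to main().
import ast

def normalize_mode_paths(langchain_mode_paths, user_path=None):
    if isinstance(langchain_mode_paths, str):
        # e.g. --langchain_mode_paths="{'UserData2': 'userpath2'}"
        langchain_mode_paths = ast.literal_eval(langchain_mode_paths)
    assert isinstance(langchain_mode_paths, dict)
    if user_path:
        # a CLI-passed user_path always wins for the shared 'UserData' collection
        langchain_mode_paths['UserData'] = user_path
    return langchain_mode_paths

print(normalize_mode_paths("{'UserData2': 'userpath2'}", user_path='user_path'))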
@@ -967,11 +1041,13 @@ def get_model(
 client = gr_client or hf_client
 # Don't return None, None for model, tokenizer so triggers
 return client, tokenizer, 'http'
-if isinstance(inference_server, str) and
-
-
-
-
+if isinstance(inference_server, str) and (
+inference_server.startswith('openai') or inference_server.startswith('vllm')):
+if inference_server.startswith('openai'):
+assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
+# Don't return None, None for model, tokenizer so triggers
+# include small token cushion
+tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
 return inference_server, tokenizer, inference_server
 assert not inference_server, "Malformed inference_server=%s" % inference_server
 if base_model in non_hf_types:
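With an OpenAI or vLLM inference server there is no local tokenizer, so get_model() now returns a FakeTokenizer sized from model_token_mapping with a small cushion; an illustrative use of that sizing rule (the gpt-3.5-turbo key is assumed to be present in the mapping):

# Illustrative sizing only; FakeTokenizer and model_token_mapping come from utils/enums.
from enums import model_token_mapping
from utils import FakeTokenizer

base_model = 'gpt-3.5-turbo'
# keep ~50 tokens of headroom so prompt plus generation stay inside the server limit
tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
print(tokenizer.model_max_length)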
@@ -1255,6 +1331,7 @@ def get_score_model(score_model: str = None,
 def evaluate(
 model_state,
 my_db_state,
+selection_docs_state,
 # START NOTE: Examples must have same order of parameters
 instruction,
 iinput,
@@ -1277,7 +1354,9 @@ def evaluate(
 instruction_nochat,
 iinput_nochat,
 langchain_mode,
+add_chat_history_to_context,
 langchain_action,
+langchain_agents,
 top_k_docs,
 chunk,
 chunk_size,
@@ -1291,6 +1370,9 @@ def evaluate(
 save_dir=None,
 sanitize_bot_response=False,
 model_state0=None,
+langchain_modes0=None,
+langchain_mode_paths0=None,
+visible_langchain_modes0=None,
 memory_restriction_level=None,
 max_max_new_tokens=None,
 is_public=None,
@@ -1298,13 +1380,14 @@ def evaluate(
 raise_generate_gpu_exceptions=None,
 chat_context=None,
 lora_weights=None,
+use_llm_if_no_docs=False,
 load_db_if_exists=True,
 dbs=None,
-user_path=None,
 detect_user_path_changes_every_query=None,
 use_openai_embedding=None,
 use_openai_model=None,
 hf_embedding_model=None,
+cut_distance=None,
 db_type=None,
 n_jobs=None,
 first_para=None,
@@ -1333,6 +1416,16 @@ def evaluate(
 assert chunk_size is not None and isinstance(chunk_size, int)
 assert n_jobs is not None
 assert first_para is not None
+assert isinstance(add_chat_history_to_context, bool)
+
+if selection_docs_state is not None:
+langchain_modes = selection_docs_state.get('langchain_modes', langchain_modes0)
+langchain_mode_paths = selection_docs_state.get('langchain_mode_paths', langchain_mode_paths0)
+visible_langchain_modes = selection_docs_state.get('visible_langchain_modes', visible_langchain_modes0)
+else:
+langchain_modes = langchain_modes0
+langchain_mode_paths = langchain_mode_paths0
+visible_langchain_modes = visible_langchain_modes0

 if debug:
 locals_dict = locals().copy()
@@ -1452,18 +1545,24 @@ def evaluate(
 # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
 assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
 assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
-
-
-
-
+assert len(
+set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
+if dbs is not None and langchain_mode in dbs:
+db = dbs[langchain_mode]
+elif my_db_state is not None and langchain_mode in my_db_state:
+db1 = my_db_state[langchain_mode]
+if db1 is not None and len(db1) == 2:
+db = db1[0]
+else:
+db = None
 else:
-
-do_langchain_path = langchain_mode not in [False, 'Disabled', '
+db = None
+do_langchain_path = langchain_mode not in [False, 'Disabled', 'LLM'] or \
 base_model in non_hf_types or \
 force_langchain_evaluate
 if do_langchain_path:
 outr = ""
-# use smaller
+# use smaller cut_distance for wiki_full since so many matches could be obtained, and often irrelevant unless close
 from gpt_langchain import run_qa_db
 gen_hyper_langchain = dict(do_sample=do_sample,
 temperature=temperature,
@@ -1484,11 +1583,13 @@ def evaluate(
 inference_server=inference_server,
 stream_output=stream_output,
 prompter=prompter,
+use_llm_if_no_docs=use_llm_if_no_docs,
 load_db_if_exists=load_db_if_exists,
-db=
-
+db=db,
+langchain_mode_paths=langchain_mode_paths,
 detect_user_path_changes_every_query=detect_user_path_changes_every_query,
-
+cut_distance=1.1 if langchain_mode in ['wiki_full'] else cut_distance,
+add_chat_history_to_context=add_chat_history_to_context,
 use_openai_embedding=use_openai_embedding,
 use_openai_model=use_openai_model,
 hf_embedding_model=hf_embedding_model,
@@ -1498,6 +1599,7 @@ def evaluate(
 chunk_size=chunk_size,
 langchain_mode=langchain_mode,
 langchain_action=langchain_action,
+langchain_agents=langchain_agents,
 document_subset=document_subset,
 document_choice=document_choice,
 db_type=db_type,
@@ -1526,6 +1628,7 @@ def evaluate(
 inference_server=inference_server,
 langchain_mode=langchain_mode,
 langchain_action=langchain_action,
+langchain_agents=langchain_agents,
 document_subset=document_subset,
 document_choice=document_choice,
 num_prompt_tokens=num_prompt_tokens,
@@ -1549,12 +1652,12 @@ def evaluate(
 clear_torch_cache()
 return

-if inference_server.startswith('
-
-
+if inference_server.startswith('vllm') or inference_server.startswith('openai') or inference_server.startswith(
+'http'):
+if inference_server.startswith('vllm') or inference_server.startswith('openai'):
 where_from = "openai_client"
+openai, inf_type = set_openai(inference_server)

-openai.api_key = os.getenv("OPENAI_API_KEY")
 terminate_response = prompter.terminate_response or []
 stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
 stop_sequences = [x for x in stop_sequences if x]
@@ -1567,7 +1670,7 @@ def evaluate(
 n=num_return_sequences,
 presence_penalty=1.07 - repetition_penalty + 0.6,  # so good default
 )
-if inference_server == 'openai':
+if inf_type == 'vllm' or inference_server == 'openai':
 response = openai.Completion.create(
 model=base_model,
 prompt=prompt,
@@ -1590,7 +1693,9 @@ def evaluate(
 yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
 sanitize_bot_response=sanitize_bot_response),
 sources='')
-elif inference_server == 'openai_chat':
+elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
+if inf_type == 'vllm_chat':
+raise NotImplementedError('%s not supported by vLLM' % inf_type)
 response = openai.ChatCompletion.create(
 model=base_model,
 messages=[
@@ -1642,7 +1747,9 @@ def evaluate(
 chat_client = False
 where_from = "gr_client"
 client_langchain_mode = 'Disabled'
+client_add_chat_history_to_context = True
 client_langchain_action = LangChainAction.QUERY.value
+client_langchain_agents = []
 gen_server_kwargs = dict(temperature=temperature,
 top_p=top_p,
 top_k=top_k,
@@ -1694,12 +1801,14 @@ def evaluate(
 instruction_nochat=gr_prompt if not chat_client else '',
 iinput_nochat=gr_iinput,  # only for chat=False
 langchain_mode=client_langchain_mode,
+add_chat_history_to_context=client_add_chat_history_to_context,
 langchain_action=client_langchain_action,
+langchain_agents=client_langchain_agents,
 top_k_docs=top_k_docs,
 chunk=chunk,
 chunk_size=chunk_size,
-document_subset=
-document_choice=[],
+document_subset=DocumentSubset.Relevant.name,
+document_choice=[DocumentChoice.ALL.value],
 )
 api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
 if not stream_output:
@@ -1993,7 +2102,7 @@ def evaluate(


 inputs_list_names = list(inspect.signature(evaluate).parameters)
-state_names = ['model_state', 'my_db_state']
+state_names = ['model_state', 'my_db_state', 'selection_docs_state']
 inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]


@@ -2276,8 +2385,8 @@ y = np.random.randint(0, 1, 100)

 # move to correct position
 for example in examples:
-example += [chat, '', '', LangChainMode.DISABLED.value, LangChainAction.QUERY.value,
-top_k_docs, chunk, chunk_size,
 ]
 # adjust examples if non-chat mode
 if not chat:
@@ -2337,7 +2446,7 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l
 truncation=True,
 max_length=max_length_tokenize).to(smodel.device)
 try:
-score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
 except torch.cuda.OutOfMemoryError as e:
 print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
 del inputs
@@ -2383,14 +2492,14 @@ def check_locals(**kwargs):


 def get_model_max_length(model_state):
-if not isinstance(model_state['tokenizer'], (str,
 return model_state['tokenizer'].model_max_length
 else:
 return 2048


 def get_max_max_new_tokens(model_state, **kwargs):
-if not isinstance(model_state['tokenizer'], (str,
 max_max_new_tokens = model_state['tokenizer'].model_max_length
 else:
 max_max_new_tokens = None
@@ -2422,12 +2531,15 @@ def get_minmax_top_k_docs(is_public):
 return min_top_k_docs, max_top_k_docs, label_top_k_docs


-def history_to_context(history, langchain_mode1,
 memory_restriction_level1, keep_sources_in_context1):
 """
 consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
 :param history:
 :param langchain_mode1:
 :param prompt_type1:
 :param prompt_dict1:
 :param chat1:
@@ -2440,7 +2552,7 @@ def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, cha
 _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
 for_context=True, model_max_length=model_max_length1)
 context1 = ''
-if max_prompt_length is not None and
 context1 = ''
 # - 1 below because current instruction already in history from user()
 for histi in range(0, len(history) - 1):
@@ -2476,6 +2588,22 @@ def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, cha
 return context1


 def entrypoint_main():
 """
 Examples:
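evaluate() now routes both OpenAI and vLLM endpoints through the openai client via set_openai(inference_server), which returns the configured module and an inf_type tag. The helper itself is not shown in this commit; the following is only a plausible sketch, assuming the 'vllm:IP:port' address format from the docstring and pre-1.0 openai module attributes:

# Hedged sketch of a set_openai()-style helper; not the repository's implementation.
import os

def set_openai_sketch(inference_server):
    import openai  # pre-1.0 style module-level configuration assumed
    if inference_server.startswith('vllm'):
        # e.g. 'vllm:192.168.1.10:5000' -> OpenAI-compatible endpoint served by vLLM
        inf_type, ip, port = inference_server.split(':')
        openai.api_base = 'http://%s:%s/v1' % (ip, port)
        openai.api_key = 'EMPTY'  # assumption: vLLM's OpenAI server ignores the key
    else:
        inf_type = inference_server  # 'openai' or 'openai_chat'
        openai.api_key = os.getenv('OPENAI_API_KEY')
    return openai, inf_type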
|
2387 |
for example in examples:
|
2388 |
+
example += [chat, '', '', LangChainMode.DISABLED.value, True, LangChainAction.QUERY.value, [],
|
2389 |
+
top_k_docs, chunk, chunk_size, DocumentSubset.Relevant.name, []
|
2390 |
]
|
2391 |
# adjust examples if non-chat mode
|
2392 |
if not chat:
|
|
|
2446 |
truncation=True,
|
2447 |
max_length=max_length_tokenize).to(smodel.device)
|
2448 |
try:
|
2449 |
+
score = torch.sigmoid(smodel(**inputs.to(smodel.device)).logits[0].float()).cpu().detach().numpy()[0]
|
2450 |
except torch.cuda.OutOfMemoryError as e:
|
2451 |
print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
|
2452 |
del inputs
|
|
|
2492 |
|
2493 |
|
2494 |
def get_model_max_length(model_state):
|
2495 |
+
if not isinstance(model_state['tokenizer'], (str, type(None))):
|
2496 |
return model_state['tokenizer'].model_max_length
|
2497 |
else:
|
2498 |
return 2048
|
2499 |
|
2500 |
|
2501 |
def get_max_max_new_tokens(model_state, **kwargs):
|
2502 |
+
if not isinstance(model_state['tokenizer'], (str, type(None))):
|
2503 |
max_max_new_tokens = model_state['tokenizer'].model_max_length
|
2504 |
else:
|
2505 |
max_max_new_tokens = None
|
|
|
2531 |
return min_top_k_docs, max_top_k_docs, label_top_k_docs
|
2532 |
|
2533 |
|
2534 |
+
def history_to_context(history, langchain_mode1,
|
2535 |
+
add_chat_history_to_context,
|
2536 |
+
prompt_type1, prompt_dict1, chat1, model_max_length1,
|
2537 |
memory_restriction_level1, keep_sources_in_context1):
|
2538 |
"""
|
2539 |
consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
|
2540 |
:param history:
|
2541 |
:param langchain_mode1:
|
2542 |
+
:param add_chat_history_to_context:
|
2543 |
:param prompt_type1:
|
2544 |
:param prompt_dict1:
|
2545 |
:param chat1:
|
|
|
2552 |
_, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
|
2553 |
for_context=True, model_max_length=model_max_length1)
|
2554 |
context1 = ''
|
2555 |
+
if max_prompt_length is not None and add_chat_history_to_context:
|
2556 |
context1 = ''
|
2557 |
# - 1 below because current instruction already in history from user()
|
2558 |
for histi in range(0, len(history) - 1):
|
|
|
2588 |
return context1
|
2589 |
|
2590 |
|
2591 |
+
def update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, extra):
|
2592 |
+
# update from saved state on disk
|
2593 |
+
langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = \
|
2594 |
+
load_collection_enum(extra)
|
2595 |
+
|
2596 |
+
visible_langchain_modes_temp = visible_langchain_modes.copy() + visible_langchain_modes_from_file
|
2597 |
+
visible_langchain_modes.clear() # don't lose original reference
|
2598 |
+
[visible_langchain_modes.append(x) for x in visible_langchain_modes_temp if x not in visible_langchain_modes]
|
2599 |
+
|
2600 |
+
langchain_mode_paths.update(langchain_mode_paths_from_file)
|
2601 |
+
|
2602 |
+
langchain_modes_temp = langchain_modes.copy() + langchain_modes_from_file
|
2603 |
+
langchain_modes.clear() # don't lose original reference
|
2604 |
+
[langchain_modes.append(x) for x in langchain_modes_temp if x not in langchain_modes]
|
2605 |
+
|
2606 |
+
|
2607 |
def entrypoint_main():
|
2608 |
"""
|
2609 |
Examples:
|
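For orientation, below is a minimal sketch of how a caller could exercise the new add_chat_history_to_context and langchain_agents fields over the stable /submit_nochat_api route referenced above. This is an illustration only, not part of the commit: the server URL and question text are placeholders, the enum string values ('Query', 'Relevant', 'All') are assumptions based on the enum names visible in this diff, and unspecified evaluate parameters are assumed to fall back to server defaults.

# hypothetical client-side usage sketch (assumptions noted above)
from gradio_client import Client

client = Client("http://localhost:7860")  # assumed local h2oGPT server
kwargs = dict(
    instruction_nochat="What is the title of the uploaded document?",
    iinput_nochat='',
    langchain_mode='UserData',
    add_chat_history_to_context=True,   # new field in this commit
    langchain_action='Query',           # assumed value of LangChainAction.QUERY
    langchain_agents=[],                # new field in this commit
    top_k_docs=3,
    chunk=True,
    chunk_size=512,
    document_subset='Relevant',         # assumed value of DocumentSubset.Relevant.name
    document_choice=['All'],            # assumed value of DocumentChoice.ALL
)
# same string-dict passing convention as the /submit_nochat_api route above
res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
print(res)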
gpt4all_llm.py
CHANGED
@@ -95,15 +95,17 @@ def get_llm_gpt4all(model_name,
                     streaming=False,
                     callbacks=None,
                     prompter=None,
+                    context='',
+                    iinput='',
                     verbose=False,
                     ):
     assert prompter is not None
     env_gpt4all_file = ".env_gpt4all"
     env_kwargs = dotenv_values(env_gpt4all_file)
-
+    max_tokens = env_kwargs.pop('max_tokens', 2048 - max_new_tokens)
     default_kwargs = dict(context_erase=0.5,
                           n_batch=1,
-
+                          max_tokens=max_tokens,
                           n_predict=max_new_tokens,
                           repeat_last_n=64 if repetition_penalty != 1.0 else 0,
                           repeat_penalty=repetition_penalty,
@@ -117,7 +119,8 @@ def get_llm_gpt4all(model_name,
         cls = H2OLlamaCpp
         model_path = env_kwargs.pop('model_path_llama') if model is None else model
         model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
-        model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming,
+        model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming,
+                                 prompter=prompter, context=context, iinput=iinput))
         llm = cls(**model_kwargs)
         llm.client.verbose = verbose
     elif model_name == 'gpt4all_llama':
@@ -125,14 +128,16 @@ def get_llm_gpt4all(model_name,
         model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
         model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
         model_kwargs.update(
-            dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming,
+            dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming,
+                 prompter=prompter, context=context, iinput=iinput))
         llm = cls(**model_kwargs)
     elif model_name == 'gptj':
         cls = H2OGPT4All
         model_path = env_kwargs.pop('model_path_gptj') if model is None else model
         model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
         model_kwargs.update(
-            dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming,
+            dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming,
+                 prompter=prompter, context=context, iinput=iinput))
         llm = cls(**model_kwargs)
     else:
         raise RuntimeError("No such model_name %s" % model_name)
@@ -142,6 +147,8 @@ def get_llm_gpt4all(model_name,
 class H2OGPT4All(gpt4all.GPT4All):
     model: Any
     prompter: Any
+    context: Any = ''
+    iinput: Any = ''
     """Path to the pre-trained GPT4All model file."""

     @root_validator()
@@ -187,10 +194,11 @@ class H2OGPT4All(gpt4all.GPT4All):
             **kwargs,
             ) -> str:
         # Roughly 4 chars per token if natural language
-
+        n_ctx = 2048
+        prompt = prompt[-self.max_tokens * 4:]

         # use instruct prompting
-        data_point = dict(context=
+        data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
         prompt = self.prompter.generate_prompt(data_point)

         verbose = False
@@ -206,6 +214,8 @@ from langchain.llms import LlamaCpp
 class H2OLlamaCpp(LlamaCpp):
     model_path: Any
     prompter: Any
+    context: Any
+    iinput: Any
     """Path to the pre-trained GPT4All model file."""

     @root_validator()
@@ -276,7 +286,7 @@ class H2OLlamaCpp(LlamaCpp):
             print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)

         # use instruct prompting
-        data_point = dict(context=
+        data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
         prompt = self.prompter.generate_prompt(data_point)

         if verbose:
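A minimal sketch of how the new context and iinput fields reach the model prompt via the Prompter data_point shown in the hunks above. Illustration only, not part of the commit: the Prompter constructor arguments and the 'human_bot' prompt type are assumptions.

from prompter import Prompter

# assumed constructor; prompt_dict is only needed for custom prompt types
prompter = Prompter('human_bot', '', chat=False, stream_output=False)
# context/iinput map to the same data_point the GPT4All/LlamaCpp wrappers now build internally
data_point = dict(context='You are a concise assistant.',
                  instruction='Summarize the release notes.',
                  input='')
prompt = prompter.generate_prompt(data_point)
print(prompt)  # the wrapped prompt that is sent to the model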
gpt_langchain.py
CHANGED
@@ -21,16 +21,17 @@ import filelock
|
|
21 |
from joblib import delayed
|
22 |
from langchain.callbacks import streaming_stdout
|
23 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
|
|
24 |
from tqdm import tqdm
|
25 |
|
26 |
-
from enums import
|
27 |
-
LangChainAction, LangChainMode
|
28 |
from evaluate_params import gen_hyper
|
29 |
from gen import get_model, SEED
|
30 |
from prompter import non_hf_types, PromptType, Prompter
|
31 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
32 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
|
33 |
-
have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf
|
34 |
from utils_langchain import StreamingGradioCallbackHandler
|
35 |
|
36 |
import_matplotlib()
|
@@ -95,11 +96,15 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss',
|
|
95 |
db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
96 |
hf_embedding_model, verbose=False)
|
97 |
if db is None:
|
|
|
|
|
|
|
|
|
98 |
db = Chroma.from_documents(documents=sources,
|
99 |
embedding=embedding,
|
100 |
persist_directory=persist_directory,
|
101 |
collection_name=collection_name,
|
102 |
-
|
103 |
db.persist()
|
104 |
clear_embedding(db)
|
105 |
save_embed(db, use_openai_embedding, hf_embedding_model)
|
@@ -276,15 +281,7 @@ from typing import Any, Dict, List, Optional, Set
|
|
276 |
|
277 |
from pydantic import Extra, Field, root_validator
|
278 |
|
279 |
-
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
280 |
-
|
281 |
-
"""Wrapper around Huggingface text generation inference API."""
|
282 |
-
from functools import partial
|
283 |
-
from typing import Any, Dict, List, Optional
|
284 |
-
|
285 |
-
from pydantic import Extra, Field, root_validator
|
286 |
-
|
287 |
-
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
288 |
from langchain.llms.base import LLM
|
289 |
|
290 |
|
@@ -312,6 +309,8 @@ class GradioInference(LLM):
|
|
312 |
sanitize_bot_response: bool = False
|
313 |
|
314 |
prompter: Any = None
|
|
|
|
|
315 |
client: Any = None
|
316 |
|
317 |
class Config:
|
@@ -355,13 +354,15 @@ class GradioInference(LLM):
|
|
355 |
stream_output = self.stream
|
356 |
gr_client = self.client
|
357 |
client_langchain_mode = 'Disabled'
|
|
|
358 |
client_langchain_action = LangChainAction.QUERY.value
|
|
|
359 |
top_k_docs = 1
|
360 |
chunk = True
|
361 |
chunk_size = 512
|
362 |
client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
|
363 |
-
iinput='', # only for chat=True
|
364 |
-
context=
|
365 |
# streaming output is supported, loops over and outputs each generation in streaming mode
|
366 |
# but leave stream_output=False for simple input/output mode
|
367 |
stream_output=stream_output,
|
@@ -382,14 +383,16 @@ class GradioInference(LLM):
|
|
382 |
chat=self.chat_client,
|
383 |
|
384 |
instruction_nochat=prompt if not self.chat_client else '',
|
385 |
-
iinput_nochat='',
|
386 |
langchain_mode=client_langchain_mode,
|
|
|
387 |
langchain_action=client_langchain_action,
|
|
|
388 |
top_k_docs=top_k_docs,
|
389 |
chunk=chunk,
|
390 |
chunk_size=chunk_size,
|
391 |
-
document_subset=
|
392 |
-
document_choice=[],
|
393 |
)
|
394 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
395 |
if not stream_output:
|
@@ -459,6 +462,8 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
|
|
459 |
stream: bool = False
|
460 |
sanitize_bot_response: bool = False
|
461 |
prompter: Any = None
|
|
|
|
|
462 |
tokenizer: Any = None
|
463 |
client: Any = None
|
464 |
|
@@ -500,7 +505,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
|
|
500 |
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
501 |
|
502 |
# NOTE: TGI server does not add prompting, so must do here
|
503 |
-
data_point = dict(context=
|
504 |
prompt = self.prompter.generate_prompt(data_point)
|
505 |
|
506 |
gen_server_kwargs = dict(do_sample=self.do_sample,
|
@@ -566,6 +571,94 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
|
|
566 |
|
567 |
|
568 |
from langchain.chat_models import ChatOpenAI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
|
570 |
|
571 |
class H2OChatOpenAI(ChatOpenAI):
|
@@ -596,17 +689,36 @@ def get_llm(use_openai_model=False,
|
|
596 |
prompt_type=None,
|
597 |
prompt_dict=None,
|
598 |
prompter=None,
|
|
|
|
|
599 |
sanitize_bot_response=False,
|
600 |
verbose=False,
|
601 |
):
|
602 |
-
if
|
|
|
|
|
603 |
if use_openai_model and model_name is None:
|
604 |
model_name = "gpt-3.5-turbo"
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
cls = H2OChatOpenAI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
610 |
callbacks = [StreamingGradioCallbackHandler()]
|
611 |
llm = cls(model_name=model_name,
|
612 |
temperature=temperature if do_sample else 0,
|
@@ -616,11 +728,18 @@ def get_llm(use_openai_model=False,
|
|
616 |
frequency_penalty=0,
|
617 |
presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
|
618 |
callbacks=callbacks if stream_output else None,
|
|
|
|
|
|
|
|
|
|
|
|
|
619 |
)
|
620 |
streamer = callbacks[0] if stream_output else None
|
621 |
if inference_server in ['openai', 'openai_chat']:
|
622 |
prompt_type = inference_server
|
623 |
else:
|
|
|
624 |
prompt_type = prompt_type or 'plain'
|
625 |
elif inference_server:
|
626 |
assert inference_server.startswith(
|
@@ -669,6 +788,8 @@ def get_llm(use_openai_model=False,
|
|
669 |
callbacks=callbacks if stream_output else None,
|
670 |
stream=stream_output,
|
671 |
prompter=prompter,
|
|
|
|
|
672 |
client=gr_client,
|
673 |
sanitize_bot_response=sanitize_bot_response,
|
674 |
)
|
@@ -689,6 +810,8 @@ def get_llm(use_openai_model=False,
|
|
689 |
callbacks=callbacks if stream_output else None,
|
690 |
stream=stream_output,
|
691 |
prompter=prompter,
|
|
|
|
|
692 |
tokenizer=tokenizer,
|
693 |
client=hf_client,
|
694 |
timeout=max_time,
|
@@ -721,6 +844,8 @@ def get_llm(use_openai_model=False,
|
|
721 |
verbose=verbose,
|
722 |
streaming=stream_output,
|
723 |
prompter=prompter,
|
|
|
|
|
724 |
)
|
725 |
else:
|
726 |
if model is None:
|
@@ -763,6 +888,8 @@ def get_llm(use_openai_model=False,
|
|
763 |
from h2oai_pipeline import H2OTextGenerationPipeline
|
764 |
pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
|
765 |
prompter=prompter,
|
|
|
|
|
766 |
prompt_type=prompt_type,
|
767 |
prompt_dict=prompt_dict,
|
768 |
sanitize_bot_response=sanitize_bot_response,
|
@@ -916,7 +1043,6 @@ def get_dai_docs(from_hf=False, get_pickle=True):
|
|
916 |
return sources
|
917 |
|
918 |
|
919 |
-
|
920 |
image_types = ["png", "jpg", "jpeg"]
|
921 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
922 |
"md",
|
@@ -927,7 +1053,8 @@ non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
|
927 |
]
|
928 |
# "msg", GPL3
|
929 |
|
930 |
-
if have_libreoffice:
|
|
|
931 |
non_image_types.extend(["docx", "doc", "xls", "xlsx"])
|
932 |
|
933 |
file_types = non_image_types + image_types
|
@@ -936,9 +1063,11 @@ file_types = non_image_types + image_types
|
|
936 |
def add_meta(docs1, file):
|
937 |
file_extension = pathlib.Path(file).suffix
|
938 |
hashid = hash_file(file)
|
|
|
939 |
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
940 |
docs1 = [docs1]
|
941 |
-
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for
|
|
|
942 |
|
943 |
|
944 |
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
@@ -946,7 +1075,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
946 |
is_url=False, is_txt=False,
|
947 |
enable_captions=True,
|
948 |
captions_model=None,
|
949 |
-
enable_ocr=False, caption_loader=None,
|
950 |
headsize=50):
|
951 |
if file is None:
|
952 |
if fail_any_exception:
|
@@ -963,6 +1092,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
963 |
base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
|
964 |
base_path = os.path.join(dir_name, base_name)
|
965 |
if is_url:
|
|
|
966 |
if file.lower().startswith('arxiv:'):
|
967 |
query = file.lower().split('arxiv:')
|
968 |
if len(query) == 2 and have_arxiv:
|
@@ -1011,11 +1141,11 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
1011 |
add_meta(docs1, file)
|
1012 |
docs1 = clean_doc(docs1)
|
1013 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
|
1014 |
-
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
|
1015 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
1016 |
add_meta(docs1, file)
|
1017 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1018 |
-
elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and have_libreoffice:
|
1019 |
docs1 = UnstructuredExcelLoader(file_path=file).load()
|
1020 |
add_meta(docs1, file)
|
1021 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
@@ -1114,21 +1244,54 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
1114 |
from dotenv import dotenv_values
|
1115 |
env_kwargs = dotenv_values(env_gpt4all_file)
|
1116 |
pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
|
|
|
|
|
1117 |
if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
|
1118 |
# GPL, only use if installed
|
1119 |
from langchain.document_loaders import PyMuPDFLoader
|
1120 |
# load() still chunks by pages, but every page has title at start to help
|
1121 |
doc1 = PyMuPDFLoader(file).load()
|
|
|
|
|
|
|
1122 |
doc1 = clean_doc(doc1)
|
1123 |
-
|
1124 |
doc1 = UnstructuredPDFLoader(file).load()
|
|
|
|
|
|
|
1125 |
# seems to not need cleaning in most cases
|
1126 |
-
|
1127 |
# open-source fallback
|
1128 |
# load() still chunks by pages, but every page has title at start to help
|
1129 |
doc1 = PyPDFLoader(file).load()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1130 |
doc1 = clean_doc(doc1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1131 |
# Some PDFs return nothing or junk from PDFMinerLoader
|
|
|
|
|
|
|
|
|
|
|
|
|
1132 |
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
|
1133 |
add_meta(doc1, file)
|
1134 |
elif file.lower().endswith('.csv'):
|
@@ -1181,7 +1344,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
|
|
1181 |
is_url=False, is_txt=False,
|
1182 |
enable_captions=True,
|
1183 |
captions_model=None,
|
1184 |
-
enable_ocr=False, caption_loader=None):
|
1185 |
if verbose:
|
1186 |
if is_url:
|
1187 |
print("Ingesting URL: %s" % file, flush=True)
|
@@ -1199,6 +1362,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
|
|
1199 |
enable_captions=enable_captions,
|
1200 |
captions_model=captions_model,
|
1201 |
enable_ocr=enable_ocr,
|
|
|
1202 |
caption_loader=caption_loader)
|
1203 |
except BaseException as e:
|
1204 |
print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
|
@@ -1207,7 +1371,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
|
|
1207 |
else:
|
1208 |
exception_doc = Document(
|
1209 |
page_content='',
|
1210 |
-
metadata={"source": file, "exception": '%s
|
1211 |
"traceback": traceback.format_exc()})
|
1212 |
res = [exception_doc]
|
1213 |
if return_file:
|
@@ -1228,6 +1392,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1228 |
captions_model=None,
|
1229 |
caption_loader=None,
|
1230 |
enable_ocr=False,
|
|
|
1231 |
existing_files=[],
|
1232 |
existing_hash_ids={},
|
1233 |
):
|
@@ -1249,11 +1414,15 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1249 |
[globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
|
1250 |
for ftype in non_image_types]
|
1251 |
else:
|
1252 |
-
if isinstance(path_or_paths, str)
|
1253 |
-
path_or_paths
|
|
|
|
|
|
|
|
|
1254 |
# list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
|
1255 |
-
assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)),
|
1256 |
-
path_or_paths)
|
1257 |
# reform out of allowed types
|
1258 |
globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
|
1259 |
# could do below:
|
@@ -1305,6 +1474,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1305 |
captions_model=captions_model,
|
1306 |
caption_loader=caption_loader,
|
1307 |
enable_ocr=enable_ocr,
|
|
|
1308 |
)
|
1309 |
|
1310 |
if n_jobs != 1 and len(globs_non_image_types) > 1:
|
@@ -1337,7 +1507,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1337 |
with open(fil, 'rb') as f:
|
1338 |
documents.extend(pickle.load(f))
|
1339 |
# remove temp pickle
|
1340 |
-
|
1341 |
else:
|
1342 |
documents = reduce(concat, documents)
|
1343 |
return documents
|
@@ -1345,7 +1515,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1345 |
|
1346 |
def prep_langchain(persist_directory,
|
1347 |
load_db_if_exists,
|
1348 |
-
db_type, use_openai_embedding, langchain_mode,
|
1349 |
hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
|
1350 |
"""
|
1351 |
do prep first time, involving downloads
|
@@ -1355,6 +1525,7 @@ def prep_langchain(persist_directory,
|
|
1355 |
assert langchain_mode not in ['MyData'], "Should not prep scratch data"
|
1356 |
|
1357 |
db_dir_exists = os.path.isdir(persist_directory)
|
|
|
1358 |
|
1359 |
if db_dir_exists and user_path is None:
|
1360 |
print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
|
@@ -1490,7 +1661,7 @@ def make_db(**langchain_kwargs):
|
|
1490 |
langchain_kwargs[k] = defaults_db[k]
|
1491 |
# final check for missing
|
1492 |
missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
|
1493 |
-
assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
|
1494 |
# only keep actual used
|
1495 |
langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
|
1496 |
return _make_db(**langchain_kwargs)
|
@@ -1524,13 +1695,14 @@ def _make_db(use_openai_embedding=False,
|
|
1524 |
first_para=False, text_limit=None,
|
1525 |
chunk=True, chunk_size=512,
|
1526 |
langchain_mode=None,
|
1527 |
-
|
1528 |
db_type='faiss',
|
1529 |
load_db_if_exists=True,
|
1530 |
db=None,
|
1531 |
n_jobs=-1,
|
1532 |
verbose=False):
|
1533 |
persist_directory = get_persist_directory(langchain_mode)
|
|
|
1534 |
# see if can get persistent chroma db
|
1535 |
db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
1536 |
hf_embedding_model, verbose=verbose)
|
@@ -1538,23 +1710,8 @@ def _make_db(use_openai_embedding=False,
|
|
1538 |
db = db_trial
|
1539 |
|
1540 |
sources = []
|
1541 |
-
if not db
|
1542 |
-
|
1543 |
-
langchain_mode in ['UserData']:
|
1544 |
-
# Should not make MyData db this way, why avoided, only upload from UI
|
1545 |
-
assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
|
1546 |
-
if verbose:
|
1547 |
-
if langchain_mode in ['UserData']:
|
1548 |
-
if user_path is not None:
|
1549 |
-
print("Checking if changed or new sources in %s, and generating sources them" % user_path,
|
1550 |
-
flush=True)
|
1551 |
-
elif db is None:
|
1552 |
-
print("user_path not passed and no db, no sources", flush=True)
|
1553 |
-
else:
|
1554 |
-
print("user_path not passed, using only existing db, no new sources", flush=True)
|
1555 |
-
else:
|
1556 |
-
print("Generating %s sources" % langchain_mode, flush=True)
|
1557 |
-
if langchain_mode in ['wiki_full', 'All', "'All'"]:
|
1558 |
from read_wiki_full import get_all_documents
|
1559 |
small_test = None
|
1560 |
print("Generating new wiki", flush=True)
|
@@ -1564,55 +1721,48 @@ def _make_db(use_openai_embedding=False,
|
|
1564 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1565 |
print("Chunked new wiki", flush=True)
|
1566 |
sources.extend(sources1)
|
1567 |
-
|
1568 |
sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
|
1569 |
if chunk:
|
1570 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1571 |
sources.extend(sources1)
|
1572 |
-
|
1573 |
# sources = get_github_docs("dagster-io", "dagster")
|
1574 |
sources1 = get_github_docs("h2oai", "h2ogpt")
|
1575 |
# FIXME: always chunk for now
|
1576 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1577 |
sources.extend(sources1)
|
1578 |
-
|
1579 |
sources1 = get_dai_docs(from_hf=True)
|
1580 |
if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
|
1581 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1582 |
sources.extend(sources1)
|
1583 |
-
|
1584 |
-
|
1585 |
-
|
1586 |
-
|
1587 |
-
|
1588 |
-
|
1589 |
-
|
1590 |
-
|
1591 |
-
|
1592 |
-
|
1593 |
-
|
1594 |
-
|
1595 |
-
|
1596 |
-
|
1597 |
-
|
1598 |
-
|
1599 |
-
|
1600 |
-
|
1601 |
-
|
1602 |
-
|
1603 |
-
|
1604 |
-
|
1605 |
-
|
1606 |
-
|
1607 |
-
|
1608 |
-
|
1609 |
-
# from langchain.document_loaders import UnstructuredURLLoader
|
1610 |
-
# loader = UnstructuredURLLoader(urls=urls)
|
1611 |
-
urls = ["https://www.birdsongsf.com/who-we-are/"]
|
1612 |
-
from langchain.document_loaders import PlaywrightURLLoader
|
1613 |
-
loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
|
1614 |
-
sources1 = loader.load()
|
1615 |
-
sources.extend(sources1)
|
1616 |
if not sources:
|
1617 |
if verbose:
|
1618 |
if db is not None:
|
@@ -1635,7 +1785,7 @@ def _make_db(use_openai_embedding=False,
|
|
1635 |
else:
|
1636 |
print("Did not generate db since no sources", flush=True)
|
1637 |
new_sources_metadata = [x.metadata for x in sources]
|
1638 |
-
elif user_path is not None
|
1639 |
print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
|
1640 |
db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
|
1641 |
use_openai_embedding=use_openai_embedding,
|
@@ -1733,7 +1883,7 @@ def run_qa_db(**kwargs):
|
|
1733 |
kwargs['answer_with_sources'] = True
|
1734 |
kwargs['show_rank'] = False
|
1735 |
missing_kwargs = [x for x in func_names if x not in kwargs]
|
1736 |
-
assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
|
1737 |
# only keep actual used
|
1738 |
kwargs = {k: v for k, v in kwargs.items() if k in func_names}
|
1739 |
try:
|
@@ -1747,7 +1897,7 @@ def _run_qa_db(query=None,
|
|
1747 |
context=None,
|
1748 |
use_openai_model=False, use_openai_embedding=False,
|
1749 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1750 |
-
|
1751 |
detect_user_path_changes_every_query=False,
|
1752 |
db_type='faiss',
|
1753 |
model_name=None, model=None, tokenizer=None, inference_server=None,
|
@@ -1757,9 +1907,11 @@ def _run_qa_db(query=None,
|
|
1757 |
prompt_type=None,
|
1758 |
prompt_dict=None,
|
1759 |
answer_with_sources=True,
|
1760 |
-
|
|
|
1761 |
sanitize_bot_response=False,
|
1762 |
show_rank=False,
|
|
|
1763 |
load_db_if_exists=False,
|
1764 |
db=None,
|
1765 |
do_sample=False,
|
@@ -1775,8 +1927,9 @@ def _run_qa_db(query=None,
|
|
1775 |
num_return_sequences=1,
|
1776 |
langchain_mode=None,
|
1777 |
langchain_action=None,
|
1778 |
-
|
1779 |
-
|
|
|
1780 |
n_jobs=-1,
|
1781 |
verbose=False,
|
1782 |
cli=False,
|
@@ -1795,7 +1948,7 @@ def _run_qa_db(query=None,
|
|
1795 |
:param top_k_docs:
|
1796 |
:param chunk:
|
1797 |
:param chunk_size:
|
1798 |
-
:param
|
1799 |
:param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
|
1800 |
:param model_name: model name, used to switch behaviors
|
1801 |
:param model: pre-initialized model, else will make new one
|
@@ -1803,6 +1956,7 @@ def _run_qa_db(query=None,
|
|
1803 |
:param answer_with_sources
|
1804 |
:return:
|
1805 |
"""
|
|
|
1806 |
if model is not None:
|
1807 |
assert model_name is not None # require so can make decisions
|
1808 |
assert query is not None
|
@@ -1817,6 +1971,8 @@ def _run_qa_db(query=None,
|
|
1817 |
else:
|
1818 |
prompt_dict = ''
|
1819 |
assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
|
|
|
|
|
1820 |
llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
|
1821 |
model=model,
|
1822 |
tokenizer=tokenizer,
|
@@ -1836,11 +1992,13 @@ def _run_qa_db(query=None,
|
|
1836 |
prompt_type=prompt_type,
|
1837 |
prompt_dict=prompt_dict,
|
1838 |
prompter=prompter,
|
|
|
|
|
1839 |
sanitize_bot_response=sanitize_bot_response,
|
1840 |
verbose=verbose,
|
1841 |
)
|
1842 |
|
1843 |
-
|
1844 |
scores = []
|
1845 |
chain = None
|
1846 |
|
@@ -1852,25 +2010,29 @@ def _run_qa_db(query=None,
|
|
1852 |
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
|
1853 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
1854 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
1855 |
-
docs, chain, scores,
|
1856 |
if document_subset in non_query_commands:
|
1857 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
|
|
|
|
|
|
|
|
1858 |
yield formatted_doc_chunks, ''
|
1859 |
return
|
1860 |
-
if not
|
1861 |
-
|
1862 |
-
|
1863 |
-
|
1864 |
-
|
1865 |
-
|
1866 |
-
|
1867 |
-
|
1868 |
-
|
1869 |
-
|
1870 |
-
|
1871 |
-
|
1872 |
-
|
1873 |
-
|
1874 |
|
1875 |
if chain is None and model_name not in non_hf_types:
|
1876 |
# here if no docs at all and not HF type
|
@@ -1921,7 +2083,7 @@ def _run_qa_db(query=None,
|
|
1921 |
else:
|
1922 |
answer = chain()
|
1923 |
|
1924 |
-
if not
|
1925 |
ret = answer['output_text']
|
1926 |
extra = ''
|
1927 |
yield ret, extra
|
@@ -1933,9 +2095,10 @@ def _run_qa_db(query=None,
|
|
1933 |
|
1934 |
def get_chain(query=None,
|
1935 |
iinput=None,
|
|
|
1936 |
use_openai_model=False, use_openai_embedding=False,
|
1937 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1938 |
-
|
1939 |
detect_user_path_changes_every_query=False,
|
1940 |
db_type='faiss',
|
1941 |
model_name=None,
|
@@ -1943,13 +2106,15 @@ def get_chain(query=None,
|
|
1943 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
1944 |
prompt_type=None,
|
1945 |
prompt_dict=None,
|
1946 |
-
|
|
|
1947 |
load_db_if_exists=False,
|
1948 |
db=None,
|
1949 |
langchain_mode=None,
|
1950 |
langchain_action=None,
|
1951 |
-
|
1952 |
-
|
|
|
1953 |
n_jobs=-1,
|
1954 |
# beyond run_db_query:
|
1955 |
llm=None,
|
@@ -1961,14 +2126,15 @@ def get_chain(query=None,
|
|
1961 |
auto_reduce_chunks=True,
|
1962 |
max_chunks=100,
|
1963 |
):
|
|
|
1964 |
# determine whether use of context out of docs is planned
|
1965 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
1966 |
-
if langchain_mode in ['Disabled', '
|
1967 |
-
|
1968 |
else:
|
1969 |
-
|
1970 |
else:
|
1971 |
-
|
1972 |
|
1973 |
# https://github.com/hwchase17/langchain/issues/1946
|
1974 |
# FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
|
@@ -1985,14 +2151,17 @@ def get_chain(query=None,
|
|
1985 |
# avoid looking at user_path during similarity search db handling,
|
1986 |
# if already have db and not updating from user_path every query
|
1987 |
# but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
|
1988 |
-
|
|
|
|
|
|
|
1989 |
db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
|
1990 |
hf_embedding_model=hf_embedding_model,
|
1991 |
first_para=first_para, text_limit=text_limit,
|
1992 |
chunk=chunk,
|
1993 |
chunk_size=chunk_size,
|
1994 |
langchain_mode=langchain_mode,
|
1995 |
-
|
1996 |
db_type=db_type,
|
1997 |
load_db_if_exists=load_db_if_exists,
|
1998 |
db=db,
|
@@ -2012,7 +2181,7 @@ def get_chain(query=None,
|
|
2012 |
else:
|
2013 |
extra = ""
|
2014 |
prefix = ""
|
2015 |
-
if langchain_mode in ['Disabled', '
|
2016 |
template_if_no_docs = template = """%s{context}{question}""" % prefix
|
2017 |
else:
|
2018 |
template = """%s
|
@@ -2053,7 +2222,7 @@ def get_chain(query=None,
|
|
2053 |
else:
|
2054 |
use_template = False
|
2055 |
|
2056 |
-
if db and
|
2057 |
base_path = 'locks'
|
2058 |
makedirs(base_path)
|
2059 |
if hasattr(db, '_persist_directory'):
|
@@ -2067,10 +2236,10 @@ def get_chain(query=None,
|
|
2067 |
filter_kwargs = {}
|
2068 |
else:
|
2069 |
assert document_choice is not None, "Document choice was None"
|
2070 |
-
if len(document_choice) >= 1 and document_choice[0] ==
|
2071 |
filter_kwargs = {}
|
2072 |
elif len(document_choice) >= 2:
|
2073 |
-
if document_choice[0] ==
|
2074 |
# remove 'All'
|
2075 |
document_choice = document_choice[1:]
|
2076 |
or_filter = [{"source": {"$eq": x}} for x in document_choice]
|
@@ -2082,18 +2251,18 @@ def get_chain(query=None,
|
|
2082 |
else:
|
2083 |
# shouldn't reach
|
2084 |
filter_kwargs = {}
|
2085 |
-
if langchain_mode in [LangChainMode.LLM.value
|
2086 |
docs = []
|
2087 |
scores = []
|
2088 |
-
elif document_subset ==
|
2089 |
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
2090 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
2091 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
2092 |
for result in zip(db_documents, db_metadatas)]
|
2093 |
|
2094 |
# order documents
|
2095 |
-
doc_hashes = [x
|
2096 |
-
doc_chunk_ids = [x
|
2097 |
docs_with_score = [x for _, _, x in
|
2098 |
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
|
2099 |
]
|
@@ -2173,8 +2342,8 @@ def get_chain(query=None,
|
|
2173 |
docs_with_score.reverse()
|
2174 |
# cut off so no high distance docs/sources considered
|
2175 |
have_any_docs |= len(docs_with_score) > 0 # before cut
|
2176 |
-
docs = [x[0] for x in docs_with_score if x[1] <
|
2177 |
-
scores = [x[1] for x in docs_with_score if x[1] <
|
2178 |
if len(scores) > 0 and verbose:
|
2179 |
print("Distance: min: %s max: %s mean: %s median: %s" %
|
2180 |
(scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
|
@@ -2182,7 +2351,7 @@ def get_chain(query=None,
|
|
2182 |
docs = []
|
2183 |
scores = []
|
2184 |
|
2185 |
-
if not docs and
|
2186 |
# if HF type and have no docs, can bail out
|
2187 |
return docs, None, [], False, have_any_docs
|
2188 |
|
@@ -2205,7 +2374,7 @@ def get_chain(query=None,
|
|
2205 |
|
2206 |
if len(docs) == 0:
|
2207 |
# avoid context == in prompt then
|
2208 |
-
|
2209 |
template = template_if_no_docs
|
2210 |
|
2211 |
if langchain_action == LangChainAction.QUERY.value:
|
@@ -2221,7 +2390,7 @@ def get_chain(query=None,
|
|
2221 |
else:
|
2222 |
# only if use_openai_model = True, unused normally except in testing
|
2223 |
chain = load_qa_with_sources_chain(llm)
|
2224 |
-
if not
|
2225 |
chain_kwargs = dict(input_documents=[], question=query)
|
2226 |
else:
|
2227 |
chain_kwargs = dict(input_documents=docs, question=query)
|
@@ -2248,7 +2417,7 @@ def get_chain(query=None,
|
|
2248 |
else:
|
2249 |
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
2250 |
|
2251 |
-
return docs, target, scores,
|
2252 |
|
2253 |
|
2254 |
def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
|
@@ -2302,6 +2471,7 @@ def clean_doc(docs1):
|
|
2302 |
|
2303 |
def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
|
2304 |
if not chunk:
|
|
|
2305 |
return sources
|
2306 |
if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
|
2307 |
# if just one document
|
@@ -2320,8 +2490,7 @@ def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
|
|
2320 |
source_chunks = splitter.split_documents(sources)
|
2321 |
|
2322 |
# currently in order, but when pull from db won't be, so mark order and document by hash
|
2323 |
-
|
2324 |
-
[x.metadata.update(dict(doc_hash=doc_hash, chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
|
2325 |
|
2326 |
return source_chunks
|
2327 |
|
|
|
21 |
from joblib import delayed
|
22 |
from langchain.callbacks import streaming_stdout
|
23 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
24 |
+
from langchain.schema import LLMResult
|
25 |
from tqdm import tqdm
|
26 |
|
27 |
+
from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
|
28 |
+
LangChainAction, LangChainMode, DocumentChoice
|
29 |
from evaluate_params import gen_hyper
|
30 |
from gen import get_model, SEED
|
31 |
from prompter import non_hf_types, PromptType, Prompter
|
32 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
33 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
|
34 |
+
have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf, set_openai
|
35 |
from utils_langchain import StreamingGradioCallbackHandler
|
36 |
|
37 |
import_matplotlib()
|
|
|
96 |
db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
97 |
hf_embedding_model, verbose=False)
|
98 |
if db is None:
|
99 |
+
from chromadb.config import Settings
|
100 |
+
client_settings = Settings(anonymized_telemetry=False,
|
101 |
+
chroma_db_impl="duckdb+parquet",
|
102 |
+
persist_directory=persist_directory)
|
103 |
db = Chroma.from_documents(documents=sources,
|
104 |
embedding=embedding,
|
105 |
persist_directory=persist_directory,
|
106 |
collection_name=collection_name,
|
107 |
+
client_settings=client_settings)
|
108 |
db.persist()
|
109 |
clear_embedding(db)
|
110 |
save_embed(db, use_openai_embedding, hf_embedding_model)
|
|
|
281 |
|
282 |
from pydantic import Extra, Field, root_validator
|
283 |
|
284 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
from langchain.llms.base import LLM
|
286 |
|
287 |
|
|
|
309 |
sanitize_bot_response: bool = False
|
310 |
|
311 |
prompter: Any = None
|
312 |
+
context: Any = ''
|
313 |
+
iinput: Any = ''
|
314 |
client: Any = None
|
315 |
|
316 |
class Config:
|
|
|
354 |
stream_output = self.stream
|
355 |
gr_client = self.client
|
356 |
client_langchain_mode = 'Disabled'
|
357 |
+
client_add_chat_history_to_context = True
|
358 |
client_langchain_action = LangChainAction.QUERY.value
|
359 |
+
client_langchain_agents = []
|
360 |
top_k_docs = 1
|
361 |
chunk = True
|
362 |
chunk_size = 512
|
363 |
client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
|
364 |
+
iinput=self.iinput if self.chat_client else '', # only for chat=True
|
365 |
+
context=self.context,
|
366 |
# streaming output is supported, loops over and outputs each generation in streaming mode
|
367 |
# but leave stream_output=False for simple input/output mode
|
368 |
stream_output=stream_output,
|
|
|
383 |
chat=self.chat_client,
|
384 |
|
385 |
instruction_nochat=prompt if not self.chat_client else '',
|
386 |
+
iinput_nochat=self.iinput if not self.chat_client else '',
|
387 |
langchain_mode=client_langchain_mode,
|
388 |
+
add_chat_history_to_context=client_add_chat_history_to_context,
|
389 |
langchain_action=client_langchain_action,
|
390 |
+
langchain_agents=client_langchain_agents,
|
391 |
top_k_docs=top_k_docs,
|
392 |
chunk=chunk,
|
393 |
chunk_size=chunk_size,
|
394 |
+
document_subset=DocumentSubset.Relevant.name,
|
395 |
+
document_choice=[DocumentChoice.ALL.value],
|
396 |
)
|
397 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
398 |
if not stream_output:
|
|
|
462 |
stream: bool = False
|
463 |
sanitize_bot_response: bool = False
|
464 |
prompter: Any = None
|
465 |
+
context: Any = ''
|
466 |
+
iinput: Any = ''
|
467 |
tokenizer: Any = None
|
468 |
client: Any = None
|
469 |
|
|
|
505 |
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
506 |
|
507 |
# NOTE: TGI server does not add prompting, so must do here
|
508 |
+
data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
|
509 |
prompt = self.prompter.generate_prompt(data_point)
|
510 |
|
511 |
gen_server_kwargs = dict(do_sample=self.do_sample,
|
|
|
571 |
|
572 |
|
573 |
from langchain.chat_models import ChatOpenAI
|
574 |
+
from langchain.llms import OpenAI
|
575 |
+
from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
|
576 |
+
update_token_usage
|
577 |
+
|
578 |
+
|
579 |
+
class H2OOpenAI(OpenAI):
|
580 |
+
"""
|
581 |
+
New class to handle vLLM's use of OpenAI, no vllm_chat supported, so only need here
|
582 |
+
Handles prompting that OpenAI doesn't need, stopping as well
|
583 |
+
"""
|
584 |
+
stop_sequences: Any = None
|
585 |
+
sanitize_bot_response: bool = False
|
586 |
+
prompter: Any = None
|
587 |
+
context: Any = ''
|
588 |
+
iinput: Any = ''
|
589 |
+
tokenizer: Any = None
|
590 |
+
|
591 |
+
@classmethod
|
592 |
+
def all_required_field_names(cls) -> Set:
|
593 |
+
all_required_field_names = super(OpenAI, cls).all_required_field_names()
|
594 |
+
all_required_field_names.update(
|
595 |
+
{'top_p', 'frequency_penalty', 'presence_penalty', 'stop_sequences', 'sanitize_bot_response', 'prompter',
|
596 |
+
'tokenizer'})
|
597 |
+
return all_required_field_names
|
598 |
+
|
599 |
+
def _generate(
|
600 |
+
self,
|
601 |
+
prompts: List[str],
|
602 |
+
stop: Optional[List[str]] = None,
|
603 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
604 |
+
**kwargs: Any,
|
605 |
+
) -> LLMResult:
|
606 |
+
stop = self.stop_sequences if not stop else self.stop_sequences + stop
|
607 |
+
|
608 |
+
# HF inference server needs control over input tokens
|
609 |
+
assert self.tokenizer is not None
|
610 |
+
from h2oai_pipeline import H2OTextGenerationPipeline
|
611 |
+
for prompti, prompt in enumerate(prompts):
|
612 |
+
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
613 |
+
# NOTE: OpenAI/vLLM server does not add prompting, so must do here
|
614 |
+
data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
|
615 |
+
prompt = self.prompter.generate_prompt(data_point)
|
616 |
+
prompts[prompti] = prompt
|
617 |
+
|
618 |
+
params = self._invocation_params
|
619 |
+
params = {**params, **kwargs}
|
620 |
+
sub_prompts = self.get_sub_prompts(params, prompts, stop)
|
621 |
+
choices = []
|
622 |
+
token_usage: Dict[str, int] = {}
|
623 |
+
# Get the token usage from the response.
|
624 |
+
# Includes prompt, completion, and total tokens used.
|
625 |
+
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
626 |
+
text = ''
|
627 |
+
for _prompts in sub_prompts:
|
628 |
+
if self.streaming:
|
629 |
+
text_with_prompt = ""
|
630 |
+
prompt = _prompts[0]
|
631 |
+
if len(_prompts) > 1:
|
632 |
+
raise ValueError("Cannot stream results with multiple prompts.")
|
633 |
+
params["stream"] = True
|
634 |
+
response = _streaming_response_template()
|
635 |
+
first = True
|
636 |
+
for stream_resp in completion_with_retry(
|
637 |
+
self, prompt=_prompts, **params
|
638 |
+
):
|
639 |
+
if first:
|
640 |
+
stream_resp["choices"][0]["text"] = prompt + stream_resp["choices"][0]["text"]
|
641 |
+
first = False
|
642 |
+
text_chunk = stream_resp["choices"][0]["text"]
|
643 |
+
text_with_prompt += text_chunk
|
644 |
+
text = self.prompter.get_response(text_with_prompt, prompt=prompt,
|
645 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
646 |
+
if run_manager:
|
647 |
+
run_manager.on_llm_new_token(
|
648 |
+
text_chunk,
|
649 |
+
verbose=self.verbose,
|
650 |
+
logprobs=stream_resp["choices"][0]["logprobs"],
|
651 |
+
)
|
652 |
+
_update_response(response, stream_resp)
|
653 |
+
choices.extend(response["choices"])
|
654 |
+
else:
|
655 |
+
response = completion_with_retry(self, prompt=_prompts, **params)
|
656 |
+
choices.extend(response["choices"])
|
657 |
+
if not self.streaming:
|
658 |
+
# Can't update token usage if streaming
|
659 |
+
update_token_usage(_keys, response, token_usage)
|
660 |
+
choices[0]['text'] = text
|
661 |
+
return self.create_llm_result(choices, prompts, token_usage)
|
662 |
|
663 |
|
664 |
class H2OChatOpenAI(ChatOpenAI):
|
|
|
689 |
prompt_type=None,
|
690 |
prompt_dict=None,
|
691 |
prompter=None,
|
692 |
+
context=None,
|
693 |
+
iinput=None,
|
694 |
sanitize_bot_response=False,
|
695 |
verbose=False,
|
696 |
):
|
697 |
+
if inference_server is None:
|
698 |
+
inference_server = ''
|
699 |
+
if use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'):
|
700 |
if use_openai_model and model_name is None:
|
701 |
model_name = "gpt-3.5-turbo"
|
702 |
+
# FIXME: Will later import be ignored? I think so, so should be fine
|
703 |
+
openai, inf_type = set_openai(inference_server)
|
704 |
+
kwargs_extra = {}
|
705 |
+
if inference_server == 'openai_chat' or inf_type == 'vllm_chat':
|
706 |
cls = H2OChatOpenAI
|
707 |
+
# FIXME: Support context, iinput
|
708 |
+
else:
|
709 |
+
cls = H2OOpenAI
|
710 |
+
if inf_type == 'vllm':
|
711 |
+
terminate_response = prompter.terminate_response or []
|
712 |
+
stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
|
713 |
+
stop_sequences = [x for x in stop_sequences if x]
|
714 |
+
kwargs_extra = dict(stop_sequences=stop_sequences,
|
715 |
+
sanitize_bot_response=sanitize_bot_response,
|
716 |
+
prompter=prompter,
|
717 |
+
context=context,
|
718 |
+
iinput=iinput,
|
719 |
+
tokenizer=tokenizer,
|
720 |
+
client=None)
|
721 |
+
|
722 |
callbacks = [StreamingGradioCallbackHandler()]
|
723 |
llm = cls(model_name=model_name,
|
724 |
temperature=temperature if do_sample else 0,
|
|
|
728 |
frequency_penalty=0,
|
729 |
presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
|
730 |
callbacks=callbacks if stream_output else None,
|
731 |
+
openai_api_key=openai.api_key,
|
732 |
+
openai_api_base=openai.api_base,
|
733 |
+
logit_bias=None if inf_type == 'vllm' else {},
|
734 |
+
max_retries=2,
|
735 |
+
streaming=stream_output,
|
736 |
+
**kwargs_extra
|
737 |
)
|
738 |
streamer = callbacks[0] if stream_output else None
|
739 |
if inference_server in ['openai', 'openai_chat']:
|
740 |
prompt_type = inference_server
|
741 |
else:
|
742 |
+
# vllm goes here
|
743 |
prompt_type = prompt_type or 'plain'
|
744 |
elif inference_server:
|
745 |
assert inference_server.startswith(
|
|
|
788 |
callbacks=callbacks if stream_output else None,
|
789 |
stream=stream_output,
|
790 |
prompter=prompter,
|
791 |
+
context=context,
|
792 |
+
iinput=iinput,
|
793 |
client=gr_client,
|
794 |
sanitize_bot_response=sanitize_bot_response,
|
795 |
)
|
|
|
810 |
callbacks=callbacks if stream_output else None,
|
811 |
stream=stream_output,
|
812 |
prompter=prompter,
|
813 |
+
context=context,
|
814 |
+
iinput=iinput,
|
815 |
tokenizer=tokenizer,
|
816 |
client=hf_client,
|
817 |
timeout=max_time,
|
|
|
844 |
verbose=verbose,
|
845 |
streaming=stream_output,
|
846 |
prompter=prompter,
|
847 |
+
context=context,
|
848 |
+
iinput=iinput,
|
849 |
)
|
850 |
else:
|
851 |
if model is None:
|
|
|
888 |
from h2oai_pipeline import H2OTextGenerationPipeline
|
889 |
pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
|
890 |
prompter=prompter,
|
891 |
+
context=context,
|
892 |
+
iinpout=iinput,
|
893 |
prompt_type=prompt_type,
|
894 |
prompt_dict=prompt_dict,
|
895 |
sanitize_bot_response=sanitize_bot_response,
|
|
|
1043 |
return sources
|
1044 |
|
1045 |
|
|
|
1046 |
image_types = ["png", "jpg", "jpeg"]
|
1047 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
1048 |
"md",
|
|
|
1053 |
]
|
1054 |
# "msg", GPL3
|
1055 |
|
1056 |
+
if have_libreoffice or True:
|
1057 |
+
# or True so it tries to load, e.g. on MAC/Windows, even if don't have libreoffice since works without that
|
1058 |
non_image_types.extend(["docx", "doc", "xls", "xlsx"])
|
1059 |
|
1060 |
file_types = non_image_types + image_types
|
|
|
1063 |
def add_meta(docs1, file):
|
1064 |
file_extension = pathlib.Path(file).suffix
|
1065 |
hashid = hash_file(file)
|
1066 |
+
doc_hash = str(uuid.uuid4())[:10]
|
1067 |
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
1068 |
docs1 = [docs1]
|
1069 |
+
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid, doc_hash=doc_hash)) for
|
1070 |
+
x in docs1]
|
1071 |
|
1072 |
|
1073 |
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
|
1075 |
is_url=False, is_txt=False,
|
1076 |
enable_captions=True,
|
1077 |
captions_model=None,
|
1078 |
+
enable_ocr=False, enable_pdf_ocr='auto', caption_loader=None,
|
1079 |
headsize=50):
|
1080 |
if file is None:
|
1081 |
if fail_any_exception:
|
|
|
1092 |
base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
|
1093 |
base_path = os.path.join(dir_name, base_name)
|
1094 |
if is_url:
|
1095 |
+
file = file.strip() # in case accidental spaces in front or at end
|
1096 |
if file.lower().startswith('arxiv:'):
|
1097 |
query = file.lower().split('arxiv:')
|
1098 |
if len(query) == 2 and have_arxiv:
|
|
|
1141 |
add_meta(docs1, file)
|
1142 |
docs1 = clean_doc(docs1)
|
1143 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
|
1144 |
+
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and (have_libreoffice or True):
|
1145 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
1146 |
add_meta(docs1, file)
|
1147 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1148 |
+
elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and (have_libreoffice or True):
|
1149 |
docs1 = UnstructuredExcelLoader(file_path=file).load()
|
1150 |
add_meta(docs1, file)
|
1151 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
|
|
        from dotenv import dotenv_values
        env_kwargs = dotenv_values(env_gpt4all_file)
        pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
        doc1 = []
        handled = False
        if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
            # GPL, only use if installed
            from langchain.document_loaders import PyMuPDFLoader
            # load() still chunks by pages, but every page has title at start to help
            doc1 = PyMuPDFLoader(file).load()
            # remove empty documents
            handled |= len(doc1) > 0
            doc1 = [x for x in doc1 if x.page_content]
            doc1 = clean_doc(doc1)
        if len(doc1) == 0:
            doc1 = UnstructuredPDFLoader(file).load()
            handled |= len(doc1) > 0
            # remove empty documents
            doc1 = [x for x in doc1 if x.page_content]
            # seems to not need cleaning in most cases
        if len(doc1) == 0:
            # open-source fallback
            # load() still chunks by pages, but every page has title at start to help
            doc1 = PyPDFLoader(file).load()
            handled |= len(doc1) > 0
            # remove empty documents
            doc1 = [x for x in doc1 if x.page_content]
            doc1 = clean_doc(doc1)
        if have_pymupdf and len(doc1) == 0:
            # GPL, only use if installed
            from langchain.document_loaders import PyMuPDFLoader
            # load() still chunks by pages, but every page has title at start to help
            doc1 = PyMuPDFLoader(file).load()
            handled |= len(doc1) > 0
            # remove empty documents
            doc1 = [x for x in doc1 if x.page_content]
            doc1 = clean_doc(doc1)
        if len(doc1) == 0 and enable_pdf_ocr == 'auto' or enable_pdf_ocr == 'on':
            # try OCR in end since slowest, but works on pure image pages well
            doc1 = UnstructuredPDFLoader(file, strategy='ocr_only').load()
            handled |= len(doc1) > 0
            # remove empty documents
            doc1 = [x for x in doc1 if x.page_content]
            # seems to not need cleaning in most cases
        # Some PDFs return nothing or junk from PDFMinerLoader
        if len(doc1) == 0:
            # if literally nothing, show failed to parse so user knows, since unlikely nothing in PDF at all.
            if handled:
                raise ValueError("%s had no valid text, but meta data was parsed" % file)
            else:
                raise ValueError("%s had no valid text and no meta data was parsed" % file)
        doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
        add_meta(doc1, file)
    elif file.lower().endswith('.csv'):
is_url=False, is_txt=False,
|
1345 |
enable_captions=True,
|
1346 |
captions_model=None,
|
1347 |
+
enable_ocr=False, enable_pdf_ocr='auto', caption_loader=None):
|
1348 |
if verbose:
|
1349 |
if is_url:
|
1350 |
print("Ingesting URL: %s" % file, flush=True)
|
|
|
1362 |
enable_captions=enable_captions,
|
1363 |
captions_model=captions_model,
|
1364 |
enable_ocr=enable_ocr,
|
1365 |
+
enable_pdf_ocr=enable_pdf_ocr,
|
1366 |
caption_loader=caption_loader)
|
1367 |
except BaseException as e:
|
1368 |
print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
|
|
|
1371 |
else:
|
1372 |
exception_doc = Document(
|
1373 |
page_content='',
|
1374 |
+
metadata={"source": file, "exception": '%s Exception: %s' % (file, str(e)),
|
1375 |
"traceback": traceback.format_exc()})
|
1376 |
res = [exception_doc]
|
1377 |
if return_file:
|
|
|
1392 |
captions_model=None,
|
1393 |
caption_loader=None,
|
1394 |
enable_ocr=False,
|
1395 |
+
enable_pdf_ocr='auto',
|
1396 |
existing_files=[],
|
1397 |
existing_hash_ids={},
|
1398 |
):
|
|
|
1414 |
[globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
|
1415 |
for ftype in non_image_types]
|
1416 |
else:
|
1417 |
+
if isinstance(path_or_paths, str):
|
1418 |
+
if os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths):
|
1419 |
+
path_or_paths = [path_or_paths]
|
1420 |
+
else:
|
1421 |
+
# path was deleted etc.
|
1422 |
+
return []
|
1423 |
# list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
|
1424 |
+
assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), \
|
1425 |
+
"Wrong type for path_or_paths: %s %s" % (path_or_paths, type(path_or_paths))
|
1426 |
# reform out of allowed types
|
1427 |
globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
|
1428 |
# could do below:
|
|
|
1474 |
captions_model=captions_model,
|
1475 |
caption_loader=caption_loader,
|
1476 |
enable_ocr=enable_ocr,
|
1477 |
+
enable_pdf_ocr=enable_pdf_ocr,
|
1478 |
)
|
1479 |
|
1480 |
if n_jobs != 1 and len(globs_non_image_types) > 1:
|
|
|
            with open(fil, 'rb') as f:
                documents.extend(pickle.load(f))
            # remove temp pickle
            remove(fil)
    else:
        documents = reduce(concat, documents)
    return documents


def prep_langchain(persist_directory,
                   load_db_if_exists,
                   db_type, use_openai_embedding, langchain_mode, langchain_mode_paths,
                   hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
    """
    do prep first time, involving downloads

    assert langchain_mode not in ['MyData'], "Should not prep scratch data"

    db_dir_exists = os.path.isdir(persist_directory)
    user_path = langchain_mode_paths.get(langchain_mode)

    if db_dir_exists and user_path is None:
        print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)

            langchain_kwargs[k] = defaults_db[k]
    # final check for missing
    missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
    assert not missing_kwargs, "Missing kwargs for make_db: %s" % missing_kwargs
    # only keep actual used
    langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
    return _make_db(**langchain_kwargs)

             first_para=False, text_limit=None,
             chunk=True, chunk_size=512,
             langchain_mode=None,
             langchain_mode_paths=None,
             db_type='faiss',
             load_db_if_exists=True,
             db=None,
             n_jobs=-1,
             verbose=False):
    persist_directory = get_persist_directory(langchain_mode)
    user_path = langchain_mode_paths.get(langchain_mode)
    # see if can get persistent chroma db
    db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                               hf_embedding_model, verbose=verbose)

        db = db_trial

    sources = []
    if not db:
        if langchain_mode in ['wiki_full']:
            from read_wiki_full import get_all_documents
            small_test = None
            print("Generating new wiki", flush=True)

            sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
            print("Chunked new wiki", flush=True)
            sources.extend(sources1)
        elif langchain_mode in ['wiki']:
            sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
            if chunk:
                sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
            sources.extend(sources1)
        elif langchain_mode in ['github h2oGPT']:
            # sources = get_github_docs("dagster-io", "dagster")
            sources1 = get_github_docs("h2oai", "h2ogpt")
            # FIXME: always chunk for now
            sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
            sources.extend(sources1)
        elif langchain_mode in ['DriverlessAI docs']:
            sources1 = get_dai_docs(from_hf=True)
            if chunk and False:  # FIXME: DAI docs are already chunked well, should only chunk more if over limit
                sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
            sources.extend(sources1)
        if user_path:
            # UserData or custom, which has to be from user's disk
            if db is not None:
                # NOTE: Ignore file names for now, only go by hash ids
                # existing_files = get_existing_files(db)
                existing_files = []
                existing_hash_ids = get_existing_hash_ids(db)
            else:
                # pretend no existing files so won't filter
                existing_files = []
                existing_hash_ids = []
            # chunk internally for speed over multiple docs
            # FIXME: If first had old Hash=None and switch embeddings,
            # then re-embed, and then hit here and reload so have hash, and then re-embed.
            sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
                                    existing_files=existing_files, existing_hash_ids=existing_hash_ids)
            new_metadata_sources = set([x.metadata['source'] for x in sources1])
            if new_metadata_sources:
                print("Loaded %s new files as sources to add to %s" % (len(new_metadata_sources), langchain_mode),
                      flush=True)
                if verbose:
                    print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
            sources.extend(sources1)
        print("Loaded %s sources for potentially adding to %s" % (len(sources), langchain_mode), flush=True)

    # see if got sources
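The langchain_mode_paths dict introduced by this commit maps a collection name to an optional on-disk path; only collections with a non-empty path get globbed for user files. A rough sketch of that lookup, not part of the commit, with made-up collection names and paths:

# Sketch: collection name -> user path; None means "no disk ingestion for this collection".
langchain_mode_paths = {
    'UserData': '/data/user_docs',   # hypothetical path
    'MyData': None,                  # scratch, per-user uploads only
}


def maybe_gather_sources(langchain_mode, gather_fn):
    user_path = langchain_mode_paths.get(langchain_mode)
    if user_path:
        return gather_fn(user_path)
    return []


print(maybe_gather_sources('MyData', lambda p: [p]))    # []
print(maybe_gather_sources('UserData', lambda p: [p]))  # ['/data/user_docs']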
    if not sources:
        if verbose:
            if db is not None:

        else:
            print("Did not generate db since no sources", flush=True)
        new_sources_metadata = [x.metadata for x in sources]
    elif user_path is not None:
        print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
        db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
                                                              use_openai_embedding=use_openai_embedding,

    kwargs['answer_with_sources'] = True
    kwargs['show_rank'] = False
    missing_kwargs = [x for x in func_names if x not in kwargs]
    assert not missing_kwargs, "Missing kwargs for run_qa_db: %s" % missing_kwargs
    # only keep actual used
    kwargs = {k: v for k, v in kwargs.items() if k in func_names}
    try:
               context=None,
               use_openai_model=False, use_openai_embedding=False,
               first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
               langchain_mode_paths={},
               detect_user_path_changes_every_query=False,
               db_type='faiss',
               model_name=None, model=None, tokenizer=None, inference_server=None,

               prompt_type=None,
               prompt_dict=None,
               answer_with_sources=True,
               cut_distance=1.64,
               add_chat_history_to_context=True,
               sanitize_bot_response=False,
               show_rank=False,
               use_llm_if_no_docs=False,
               load_db_if_exists=False,
               db=None,
               do_sample=False,

               num_return_sequences=1,
               langchain_mode=None,
               langchain_action=None,
               langchain_agents=None,
               document_subset=DocumentSubset.Relevant.name,
               document_choice=[DocumentChoice.ALL.value],
               n_jobs=-1,
               verbose=False,
               cli=False,

    :param top_k_docs:
    :param chunk:
    :param chunk_size:
    :param langchain_mode_paths: dict of langchain_mode -> user path to glob recursively from
    :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
    :param model_name: model name, used to switch behaviors
    :param model: pre-initialized model, else will make new one

    :param answer_with_sources
    :return:
    """
    assert langchain_mode_paths is not None
    if model is not None:
        assert model_name is not None  # require so can make decisions
    assert query is not None
    else:
        prompt_dict = ''
    assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
    # pass in context to LLM directly, since already has prompt_type structure
    # can't pass through langchain in get_chain() to LLM: https://github.com/hwchase17/langchain/issues/6638
    llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
                                                         model=model,
                                                         tokenizer=tokenizer,

                                                         prompt_type=prompt_type,
                                                         prompt_dict=prompt_dict,
                                                         prompter=prompter,
                                                         context=context if add_chat_history_to_context else '',
                                                         iinput=iinput if add_chat_history_to_context else '',
                                                         sanitize_bot_response=sanitize_bot_response,
                                                         verbose=verbose,
                                                         )

    use_docs_planned = False
    scores = []
    chain = None
    sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
    missing_kwargs = [x for x in func_names if x not in sim_kwargs]
    assert not missing_kwargs, "Missing: %s" % missing_kwargs
    docs, chain, scores, use_docs_planned, have_any_docs = get_chain(**sim_kwargs)
    if document_subset in non_query_commands:
        formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
        if not formatted_doc_chunks and not use_llm_if_no_docs:
            yield "No sources", ''
            return
        # if no sources, outside gpt_langchain, LLM will be used with '' input
        yield formatted_doc_chunks, ''
        return
    if not use_llm_if_no_docs:
        if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
                                             LangChainAction.SUMMARIZE_ALL.value,
                                             LangChainAction.SUMMARIZE_REFINE.value]:
            ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.'
            extra = ''
            yield ret, extra
            return
        if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
                                               LangChainMode.LLM.value]:
            ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
            extra = ''
            yield ret, extra
            return

    if chain is None and model_name not in non_hf_types:
        # here if no docs at all and not HF type
        else:
            answer = chain()

    if not use_docs_planned:
        ret = answer['output_text']
        extra = ''
        yield ret, extra


def get_chain(query=None,
              iinput=None,
              context=None,  # FIXME: https://github.com/hwchase17/langchain/issues/6638
              use_openai_model=False, use_openai_embedding=False,
              first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
              langchain_mode_paths=None,
              detect_user_path_changes_every_query=False,
              db_type='faiss',
              model_name=None,

              hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
              prompt_type=None,
              prompt_dict=None,
              cut_distance=1.1,
              add_chat_history_to_context=True,  # FIXME: https://github.com/hwchase17/langchain/issues/6638
              load_db_if_exists=False,
              db=None,
              langchain_mode=None,
              langchain_action=None,
              langchain_agents=None,
              document_subset=DocumentSubset.Relevant.name,
              document_choice=[DocumentChoice.ALL.value],
              n_jobs=-1,
              # beyond run_db_query:
              llm=None,
              auto_reduce_chunks=True,
              max_chunks=100,
              ):
    assert langchain_agents is not None  # should be at least []
    # determine whether use of context out of docs is planned
    if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
        if langchain_mode in ['Disabled', 'LLM']:
            use_docs_planned = False
        else:
            use_docs_planned = True
    else:
        use_docs_planned = True

    # https://github.com/hwchase17/langchain/issues/1946
    # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid

            # avoid looking at user_path during similarity search db handling,
            # if already have db and not updating from user_path every query
            # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
            if langchain_mode_paths is None:
                langchain_mode_paths = {}
            langchain_mode_paths = langchain_mode_paths.copy()
            langchain_mode_paths[langchain_mode] = None
        db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
                                                            hf_embedding_model=hf_embedding_model,
                                                            first_para=first_para, text_limit=text_limit,
                                                            chunk=chunk,
                                                            chunk_size=chunk_size,
                                                            langchain_mode=langchain_mode,
                                                            langchain_mode_paths=langchain_mode_paths,
                                                            db_type=db_type,
                                                            load_db_if_exists=load_db_if_exists,
                                                            db=db,
    else:
        extra = ""
        prefix = ""
    if langchain_mode in ['Disabled', 'LLM'] or not use_docs_planned:
        template_if_no_docs = template = """%s{context}{question}""" % prefix
    else:
        template = """%s

    else:
        use_template = False

    if db and use_docs_planned:
        base_path = 'locks'
        makedirs(base_path)
        if hasattr(db, '_persist_directory'):

            filter_kwargs = {}
        else:
            assert document_choice is not None, "Document choice was None"
            if len(document_choice) >= 1 and document_choice[0] == DocumentChoice.ALL.value:
                filter_kwargs = {}
            elif len(document_choice) >= 2:
                if document_choice[0] == DocumentChoice.ALL.value:
                    # remove 'All'
                    document_choice = document_choice[1:]
                or_filter = [{"source": {"$eq": x}} for x in document_choice]
            else:
                # shouldn't reach
                filter_kwargs = {}
        if langchain_mode in [LangChainMode.LLM.value]:
            docs = []
            scores = []
        elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
            db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
            # similar to langchain's chroma's _results_to_docs_and_scores
            docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
                               for result in zip(db_documents, db_metadatas)]

            # order documents
            doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
            doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
            docs_with_score = [x for _, _, x in
                               sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
                               ]
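Since a vector store returns chunks in arbitrary order, the code above re-sorts them by the (doc_hash, chunk_id) metadata written at ingest time. A small illustrative example of that sort key, not taken from the commit (values are made up):

# Sketch: restore per-document chunk order using metadata written at ingest time.
chunks = [
    dict(text="B-2", meta=dict(doc_hash="bbb", chunk_id=2)),
    dict(text="A-1", meta=dict(doc_hash="aaa", chunk_id=1)),
    dict(text="A-0", meta=dict(doc_hash="aaa", chunk_id=0)),
    dict(text="B-0", meta=dict(doc_hash="bbb", chunk_id=0)),
]

ordered = sorted(chunks, key=lambda c: (c['meta'].get('doc_hash', 'None'),
                                        c['meta'].get('chunk_id', 0)))
print([c['text'] for c in ordered])  # ['A-0', 'A-1', 'B-0', 'B-2']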
            docs_with_score.reverse()
        # cut off so no high distance docs/sources considered
        have_any_docs |= len(docs_with_score) > 0  # before cut
        docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
        scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
        if len(scores) > 0 and verbose:
            print("Distance: min: %s max: %s mean: %s median: %s" %
                  (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)

        docs = []
        scores = []

    if not docs and use_docs_planned and model_name not in non_hf_types:
        # if HF type and have no docs, can bail out
        return docs, None, [], False, have_any_docs
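cut_distance acts as a hard ceiling on embedding distance: anything retrieved farther than the threshold is dropped before prompting (1.64 for queries, 1.1 inside get_chain's defaults). A minimal sketch of the same filter with made-up values, not part of the commit:

# Sketch: drop retrieved chunks whose similarity-search distance exceeds cut_distance.
docs_with_score = [("relevant chunk", 0.35), ("borderline chunk", 1.05), ("noise", 2.3)]
cut_distance = 1.64

docs = [d for d, s in docs_with_score if s < cut_distance]
scores = [s for d, s in docs_with_score if s < cut_distance]
print(docs, scores)  # ['relevant chunk', 'borderline chunk'] [0.35, 1.05]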
    if len(docs) == 0:
        # avoid context == in prompt then
        use_docs_planned = False
        template = template_if_no_docs

    if langchain_action == LangChainAction.QUERY.value:

        else:
            # only if use_openai_model = True, unused normally except in testing
            chain = load_qa_with_sources_chain(llm)
        if not use_docs_planned:
            chain_kwargs = dict(input_documents=[], question=query)
        else:
            chain_kwargs = dict(input_documents=docs, question=query)

    else:
        raise RuntimeError("No such langchain_action=%s" % langchain_action)

    return docs, target, scores, use_docs_planned, have_any_docs


def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
    if not chunk:
        [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(sources)]
        return sources
    if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
        # if just one document

        source_chunks = splitter.split_documents(sources)

    # currently in order, but when pull from db won't be, so mark order and document by hash
    [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]

    return source_chunks
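chunk_sources() now stamps a running chunk_id on every piece it emits, even when chunking is disabled, which is what makes the later re-sort at retrieval time possible. A stand-alone sketch of that numbering (using a plain fixed-size splitter, not langchain's), included for illustration only:

# Sketch: number chunks at split time so their order can be recovered later.
def chunk_text(text, chunk_size=20):
    pieces = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    return [dict(text=p, metadata=dict(chunk_id=i)) for i, p in enumerate(pieces)]


for c in chunk_text("A fairly long document that gets split into fixed-size pieces."):
    print(c['metadata']['chunk_id'], repr(c['text']))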
gradio_runner.py
CHANGED
@@ -50,16 +50,20 @@ def fix_pydantic_duplicate_validators_error():

fix_pydantic_duplicate_validators_error()

from enums import DocumentSubset, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode, \
    DocumentChoice, langchain_modes_intrinsic
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
    text_xsm
from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
    get_prompt
from utils import flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
    ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip, \
    save_collection_names
from gen import get_model, languages_covered, evaluate, score_qa, inputs_kwargs_list, scratch_base_dir, \
    get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list, \
    update_langchain
from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults, \
    input_args_list

from apscheduler.schedulers.background import BackgroundScheduler

@@ -94,13 +98,11 @@ def go_gradio(**kwargs):
    memory_restriction_level = kwargs['memory_restriction_level']
    n_gpus = kwargs['n_gpus']
    admin_pass = kwargs['admin_pass']
    model_states = kwargs['model_states']
    dbs = kwargs['dbs']
    db_type = kwargs['db_type']
    visible_langchain_actions = kwargs['visible_langchain_actions']
    visible_langchain_agents = kwargs['visible_langchain_agents']
    allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
    allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
    enable_sources_list = kwargs['enable_sources_list']

@@ -111,8 +113,19 @@ def go_gradio(**kwargs):
    enable_captions = kwargs['enable_captions']
    captions_model = kwargs['captions_model']
    enable_ocr = kwargs['enable_ocr']
    enable_pdf_ocr = kwargs['enable_pdf_ocr']
    caption_loader = kwargs['caption_loader']

    # for dynamic state per user session in gradio
    model_state0 = kwargs['model_state0']
    score_model_state0 = kwargs['score_model_state0']
    my_db_state0 = kwargs['my_db_state0']
    selection_docs_state0 = kwargs['selection_docs_state0']
    # for evaluate defaults
    langchain_modes0 = kwargs['langchain_modes']
    visible_langchain_modes0 = kwargs['visible_langchain_modes']
    langchain_mode_paths0 = kwargs['langchain_mode_paths']

    # easy update of kwargs needed for evaluate() etc.
    queue = True
    allow_upload = allow_upload_to_user_data or allow_upload_to_my_data

@@ -132,25 +145,11 @@ def go_gradio(**kwargs):
                        " use Enter for multiple input lines)"

    title = 'h2oGPT'
    description = """<iframe src="https://ghbtns.com/github-btn.html?user=h2oai&repo=h2ogpt&type=star&count=true&size=small" frameborder="0" scrolling="0" width="250" height="20" title="GitHub"></iframe><small><a href="https://github.com/h2oai/h2ogpt">h2oGPT</a> <a href="https://github.com/h2oai/h2o-llmstudio">H2O LLM Studio</a><br><a href="https://huggingface.co/h2oai">🤗 Models</a>"""
    description_bottom = "If this host is busy, try<br>[Multi-Model](https://gpt.h2o.ai)<br>[Falcon 40B](https://falcon.h2o.ai)<br>[Vicuna 33B](https://wizardvicuna.h2o.ai)<br>[MPT 30B-Chat](https://mpt.h2o.ai)<br>[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)<br>[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
    if is_hf:
        description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''

    css_code = get_css(kwargs)

    if kwargs['gradio_offline_level'] >= 0:
@@ -180,9 +179,9 @@ def go_gradio(**kwargs):
    demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
    callback = gr.CSVLogger()

-   if kwargs['base_model'].strip() not in
    lora_options = kwargs['extra_lora_options']
    if kwargs['lora_weights'].strip() not in lora_options:
        lora_options = [kwargs['lora_weights'].strip()] + lora_options

@@ -197,7 +196,7 @@ def go_gradio(**kwargs):
    # always add in no lora case
    # add fake space so doesn't go away in gradio dropdown
    lora_options = [no_lora_str] + lora_options
    server_options = [no_server_str] + server_options
    # always add in no model case so can free memory

@@ -251,6 +250,14 @@ def go_gradio(**kwargs):
        # else gets input_list at time of submit that is old, and shows up as truncated in chatbot
        return x

    with demo:
        # avoid actual model/tokenizer here or anything that would be bad to deepcopy
        # https://github.com/gradio-app/gradio/issues/3558

@@ -264,18 +271,32 @@ def go_gradio(**kwargs):
                         prompt_dict=kwargs['prompt_dict'],
                         )
                     )
        model_state2 = gr.State(kwargs['model_state_none'].copy())
-       model_options_state = gr.State([
        lora_options_state = gr.State([lora_options])
        server_options_state = gr.State([server_options])
-       my_db_state = gr.State(
        chat_state = gr.State({})
-       docs_state00 = kwargs['document_choice'] + [
        docs_state0 = []
        [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
        docs_state = gr.State(docs_state0)
        viewable_docs_state0 = []
        viewable_docs_state = gr.State(viewable_docs_state0)

        gr.Markdown(f"""
            {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
            """)
@@ -289,7 +310,7 @@ def go_gradio(**kwargs):
            'model_lock'] else "Response Scores: %s" % nas

        if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
-           extra_prompt_form = ". For summarization,
        else:
            extra_prompt_form = ""
        if kwargs['input_lines'] > 1:

@@ -297,6 +318,34 @@ def go_gradio(**kwargs):
        else:
            instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form

        normal_block = gr.Row(visible=not base_wanted, equal_height=False)
        with normal_block:
            side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)

@@ -317,6 +366,7 @@ def go_gradio(**kwargs):
                                      scale=1,
                                      min_width=0,
                                      elem_id="warning", elem_classes="feedback")
                url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
                url_label = 'URL/ArXiv' if have_arxiv else 'URL'
                url_text = gr.Textbox(label=url_label,

@@ -330,29 +380,20 @@ def go_gradio(**kwargs):
                                      visible=text_visible)
                github_textbox = gr.Textbox(label="Github URL", visible=False)  # FIXME WIP
                database_visible = kwargs['langchain_mode'] != 'Disabled'
-               with gr.Accordion("
-               # don't show 'wiki' since only usually useful for internal testing at moment
-               no_show_modes = ['Disabled', 'wiki']
-               else:
-               no_show_modes = ['Disabled']
-               allowed_modes = visible_langchain_modes.copy()
-               allowed_modes = [x for x in allowed_modes if x in dbs]
-               allowed_modes += ['ChatLLM', 'LLM']
-               if allow_upload_to_my_data and 'MyData' not in allowed_modes:
-               allowed_modes += ['MyData']
-               if allow_upload_to_user_data and 'UserData' not in allowed_modes:
-               allowed_modes += ['UserData']
                langchain_mode = gr.Radio(
                    value=kwargs['langchain_mode'],
                    label="Collections",
                    show_label=True,
                    visible=kwargs['langchain_mode'] != 'Disabled',
                    min_width=100)
                    label="Subset",
-                   value=
                    interactive=True,
                )
                allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
@@ -361,6 +402,14 @@ def go_gradio(**kwargs):
                    value=allowed_actions[0] if len(allowed_actions) > 0 else None,
                    label="Action",
                    visible=True)
            col_tabs = gr.Column(elem_id="col_container", scale=10)
            with (col_tabs, gr.Tabs()):
                with gr.TabItem("Chat"):

@@ -408,9 +457,9 @@ def go_gradio(**kwargs):
                        mw1 = 50
                        mw2 = 50
                        with gr.Column(min_width=mw1):
-                           submit = gr.Button(value='Submit', variant='primary',
                                               min_width=mw1)
-                           stop_btn = gr.Button(value="Stop", variant='secondary',
                                                min_width=mw1)
                            save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
                        with gr.Column(min_width=mw2):

@@ -431,20 +480,50 @@ def go_gradio(**kwargs):
                with gr.TabItem("Document Selection"):
                    document_choice = gr.Dropdown(docs_state0,
                                                  label="Select Subset of Document(s) %s" % file_types_str,
-                                                 value=
                                                  interactive=True,
                                                  multiselect=True,
                                                  visible=kwargs['langchain_mode'] != 'Disabled',
                                                  )
                    sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
                    with gr.Row():

                    sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
                                         equal_height=False)

@@ -469,6 +548,7 @@ def go_gradio(**kwargs):
                                                value=None,
                                                interactive=True,
                                                multiselect=False,
                                                )
                        with gr.Column(scale=4):
                            pass
@@ -713,19 +793,20 @@ def go_gradio(**kwargs):
                        side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
                        submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
                        col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
-                       text_outputs_height = gr.Slider(minimum=100, maximum=
-                       step=
                        dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
                    with gr.Column(scale=4):
                        pass
                admin_row = gr.Row()
                with admin_row:
                    with gr.Column(scale=1):
-                       admin_pass_textbox = gr.Textbox(label="Admin Password", type='password',
                    with gr.Column(scale=4):
                        pass
-               system_row = gr.Row(visible=
                with system_row:
                    with gr.Column():
                        with gr.Row():

@@ -789,23 +870,24 @@ def go_gradio(**kwargs):
            else:
                return tuple([gr.update(interactive=True)] * len(args))

-           # Add to UserData
            update_db_func = functools.partial(update_user_db,
                                               dbs=dbs,
                                               db_type=db_type,
                                               use_openai_embedding=use_openai_embedding,
                                               hf_embedding_model=hf_embedding_model,
-                                              enable_captions=enable_captions,
                                               captions_model=captions_model,
                                               caption_loader=caption_loader,
                                               verbose=kwargs['verbose'],
-                                              user_path=kwargs['user_path'],
                                               n_jobs=kwargs['n_jobs'],
                                               )
            add_file_outputs = [fileup_output, langchain_mode]
            add_file_kwargs = dict(fn=update_db_func,
-                                  inputs=[fileup_output, my_db_state, chunk, chunk_size,
                                   outputs=add_file_outputs + [sources_text, doc_exception_text],
                                   queue=queue,
                                   api_name='add_file' if allow_api and allow_upload_to_user_data else None)

@@ -817,6 +899,15 @@ def go_gradio(**kwargs):
            eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
                                      show_progress='minimal')

            # note for update_user_db_func output is ignored for db

            def clear_textbox():
@@ -826,7 +917,8 @@ def go_gradio(**kwargs):
            add_url_outputs = [url_text, langchain_mode]
            add_url_kwargs = dict(fn=update_user_db_url_func,
-                                 inputs=[url_text, my_db_state, chunk, chunk_size,
                                  outputs=add_url_outputs + [sources_text, doc_exception_text],
                                  queue=queue,
                                  api_name='add_url' if allow_api and allow_upload_to_user_data else None)

@@ -843,7 +935,8 @@ def go_gradio(**kwargs):
            update_user_db_txt_func = functools.partial(update_db_func, is_txt=True)
            add_text_outputs = [user_text_text, langchain_mode]
            add_text_kwargs = dict(fn=update_user_db_txt_func,
-                                  inputs=[user_text_text, my_db_state, chunk, chunk_size,
                                   outputs=add_text_outputs + [sources_text, doc_exception_text],
                                   queue=queue,
                                   api_name='add_text' if allow_api and allow_upload_to_user_data else None

@@ -855,7 +948,7 @@ def go_gradio(**kwargs):
            eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full')
            eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
                                      show_progress='minimal')
-           db_events = [eventdb1a, eventdb1, eventdb1b,
                         eventdb2a, eventdb2, eventdb2b, eventdb2c,
                         eventdb3a, eventdb3b, eventdb3, eventdb3c]

@@ -863,14 +956,14 @@ def go_gradio(**kwargs):
            # if change collection source, must clear doc selections from it to avoid inconsistency
            def clear_doc_choice():
-               return gr.Dropdown.update(choices=docs_state0, value=

            langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False)

            def resize_col_tabs(x):
                return gr.Dropdown.update(scale=x)

-           col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs)

            def resize_chatbots(x, num_model_lock=0):
                if num_model_lock == 0:
@@ -881,7 +974,7 @@ def go_gradio(**kwargs):
            resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs))
            text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height,
-                                      outputs=[text_output, text_output2] + text_outputs)

            def update_dropdown(x):
                return gr.Dropdown.update(choices=x, value=[docs_state0[0]])

@@ -972,7 +1065,8 @@ def go_gradio(**kwargs):
                if file.startswith('http') or file.startswith('https'):
                    # if file is online, then might as well use google(?)
                    document1 = file
-                   return gr.update(visible=True,
                    </iframe>
                    """), dummy1, dummy1, dummy1
                else:

@@ -995,9 +1089,11 @@ def go_gradio(**kwargs):
            refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
                                                 **get_kwargs(update_and_get_source_files_given_langchain_mode,
-                                                             exclude_names=['
                                                              **all_kwargs))
-           eventdb9 = refresh_sources_btn.click(fn=refresh_sources1,
                                                 outputs=sources_text,
                                                 api_name='refresh_sources' if allow_api else None)
@@ -1007,9 +1103,153 @@ def go_gradio(**kwargs):
            def close_admin(x):
                return gr.update(visible=not (x == admin_pass))

                .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)

        inputs_list, inputs_dict = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=1)
        inputs_list2, inputs_dict2 = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=2)
        from functools import partial
@@ -1021,11 +1261,11 @@ def go_gradio(**kwargs):
        def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
            args_list = list(args1)
            if str_api:
-               user_kwargs = args_list[
                assert isinstance(user_kwargs, str)
                user_kwargs = ast.literal_eval(user_kwargs)
            else:
-               user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[
            # only used for submit_nochat_api
            user_kwargs['chat'] = False
            if 'stream_output' not in user_kwargs:

@@ -1035,6 +1275,8 @@ def go_gradio(**kwargs):
                user_kwargs['langchain_mode'] = 'Disabled'
            if 'langchain_action' not in user_kwargs:
                user_kwargs['langchain_action'] = LangChainAction.QUERY.value

            set1 = set(list(default_kwargs1.keys()))
            set2 = set(eval_func_param_names)

@@ -1042,10 +1284,11 @@ def go_gradio(**kwargs):
            # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
            model_state1 = args_list[0]
            my_db_state1 = args_list[1]
            args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
                         in eval_func_param_names]
            assert len(args_list) == len(eval_func_param_names)
-           args_list = [model_state1, my_db_state1] + args_list

            try:
                for res_dict in evaluate(*tuple(args_list), **kwargs1):
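evaluate_nochat() above rebuilds the positional argument list for evaluate() by walking eval_func_param_names in order and taking the user's value when present, else the server default. A toy version of that merge, not from the commit (parameter names here are illustrative):

# Sketch: merge user-supplied kwargs over defaults in a fixed parameter order.
eval_param_names = ['instruction_nochat', 'max_new_tokens', 'langchain_mode']
defaults = dict(instruction_nochat='', max_new_tokens=256, langchain_mode='Disabled')
user_kwargs = dict(instruction_nochat='Who are you?', max_new_tokens=None)

args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else defaults[k]
             for k in eval_param_names]
print(args_list)  # ['Who are you?', 256, 'Disabled']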
@@ -1216,6 +1459,7 @@ def go_gradio(**kwargs):
            prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
            langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
            langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
            document_subset1 = args_list[eval_func_param_names.index('document_subset')]
            document_choice1 = args_list[eval_func_param_names.index('document_choice')]
            if not prompt_type1:

@@ -1248,10 +1492,7 @@ def go_gradio(**kwargs):
                history[-1][1] = None
                return history
            if user_message1 in ['', None, '\n']:
-               if DocumentChoices.All.name != document_subset1 \
-                       or \
-                       langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
                # reject non-retry submit/enter
                return history
            user_message1 = fix_text_for_gradio(user_message1)

@@ -1298,10 +1539,12 @@ def go_gradio(**kwargs):
            API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
            :return: last element is True if should run bot, False if should just yield history
            """
            # don't deepcopy, can contain model itself
            args_list = list(args).copy()
-           model_state1 = args_list[-
-           my_db_state1 = args_list[-
            history = args_list[-1]
            prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
            prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]

@@ -1309,9 +1552,11 @@ def go_gradio(**kwargs):
            if model_state1['model'] is None or model_state1['model'] == no_model_str:
                return history, None, None, None

-           args_list = args_list[:-
            langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
            langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
            document_subset1 = args_list[eval_func_param_names.index('document_subset')]
            document_choice1 = args_list[eval_func_param_names.index('document_choice')]
            if not history:

@@ -1324,10 +1569,7 @@ def go_gradio(**kwargs):
                instruction1 = history[-1][0]
                history[-1][1] = None
            elif not instruction1:
-               if DocumentChoices.All.name != document_choice1 \
-                       or \
-                       langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
                # if not retrying, then reject empty query
                return history, None, None, None
            elif len(history) > 0 and history[-1][1] not in [None, '']:
@@ -1344,7 +1586,9 @@ def go_gradio(**kwargs):
            chat1 = args_list[eval_func_param_names.index('chat')]
            model_max_length1 = get_model_max_length(model_state1)
-           context1 = history_to_context(history, langchain_mode1,
                                          model_max_length1, memory_restriction_level,
                                          kwargs['keep_sources_in_context'])
            args_list[0] = instruction1  # override original instruction with history from user
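history_to_context() turns prior chat turns into a context string while respecting the model's maximum length. A rough sketch of that idea, not the project's implementation, counting characters instead of tokens for simplicity:

# Sketch: build a context string from newest-to-oldest turns until a budget is hit.
def history_to_context_sketch(history, max_chars=200):
    parts = []
    used = 0
    for user_msg, bot_msg in reversed(history):
        turn = "<human>: %s\n<bot>: %s\n" % (user_msg, bot_msg or '')
        if used + len(turn) > max_chars:
            break
        parts.append(turn)
        used += len(turn)
    return ''.join(reversed(parts))


history = [["Hi", "Hello!"], ["Summarize the doc", "It covers ingestion and querying."]]
print(history_to_context_sketch(history))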
@@ -1353,6 +1597,7 @@ def go_gradio(**kwargs):
            fun1 = partial(evaluate,
                           model_state1,
                           my_db_state1,
                           *tuple(args_list),
                           **kwargs_evaluate)

@@ -1398,24 +1643,26 @@ def go_gradio(**kwargs):
                clear_torch_cache()
                return

-           def clear_embeddings(langchain_mode1,
                # clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache
-               if db_type == 'chroma' and langchain_mode1 not in ['
                    from gpt_langchain import clear_embedding
                    db = dbs.get('langchain_mode1')
                    if db is not None and not isinstance(db, str):
                        clear_embedding(db)
-               if

            def bot(*args, retry=False):
-               history, fun1, langchain_mode1,
                try:
                    for res in get_response(fun1, history):
                        yield res
                finally:
                    clear_torch_cache()
-                   clear_embeddings(langchain_mode1,

            def all_bot(*args, retry=False, model_states1=None):
                args_list = list(args).copy()

@@ -1425,12 +1672,14 @@ def go_gradio(**kwargs):
                stream_output1 = args_list[eval_func_param_names.index('stream_output')]
                max_time1 = args_list[eval_func_param_names.index('max_time')]
                langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
                try:
                    gen_list = []
                    for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
                        args_list1 = args_list0.copy()
-                       args_list1.insert(-
                        # if at start, have None in response still, replace with '' so client etc. acts like normal
                        # assumes other parts of code treat '' and None as if no response yet from bot
                        # can't do this later in bot code as racy with threaded generators
@@ -1440,8 +1689,8 @@ def go_gradio(**kwargs):
                        # so consistent with prep_bot()
                        # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
                        # langchain_mode1 and my_db_state1 should be same for every bot
-                       history, fun1, langchain_mode1,
                        gen1 = get_response(fun1, history)
                        if stream_output1:
                            gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)

@@ -1487,7 +1736,7 @@ def go_gradio(**kwargs):
                        print("Generate exceptions: %s" % exceptions, flush=True)
                finally:
                    clear_torch_cache()
-                   clear_embeddings(langchain_mode1,

            # NORMAL MODEL
            user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),

@@ -1495,11 +1744,11 @@ def go_gradio(**kwargs):
                             outputs=text_output,
                             )
            bot_args = dict(fn=bot,
-                           inputs=inputs_list + [model_state, my_db_state] + [text_output],
                            outputs=[text_output, chat_exception_text],
                            )
            retry_bot_args = dict(fn=functools.partial(bot, retry=True),
-                                 inputs=inputs_list + [model_state, my_db_state] + [text_output],
                                  outputs=[text_output, chat_exception_text],
                                  )
            retry_user_args = dict(fn=functools.partial(user, retry=True),

@@ -1517,11 +1766,11 @@ def go_gradio(**kwargs):
                              outputs=text_output2,
                              )
            bot_args2 = dict(fn=bot,
-                            inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
                             outputs=[text_output2, chat_exception_text],
                             )
            retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
-                                  inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
                                   outputs=[text_output2, chat_exception_text],
                                   )
            retry_user_args2 = dict(fn=functools.partial(user, retry=True),

@@ -1542,11 +1791,11 @@ def go_gradio(**kwargs):
                                 outputs=text_outputs,
                                 )
            all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
-                               inputs=inputs_list + [my_db_state] + text_outputs,
                                outputs=text_outputs + [chat_exception_text],
                                )
            all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
-                                     inputs=inputs_list + [my_db_state] + text_outputs,
                                      outputs=text_outputs + [chat_exception_text],
                                      )
            all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
@@ -1708,6 +1957,11 @@ def go_gradio(**kwargs):
        def get_short_chat(x, short_chats, short_len=20, words=4):
            if x and len(x[0]) == 2 and x[0][0] is not None:
                short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
                short_chat = dedup(short_chat, short_chats)
            else:
                short_chat = None

@@ -1775,14 +2029,12 @@ def go_gradio(**kwargs):
            already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists])
            if not already_exists:
                chat_state1[short_chat] = chat_list.copy()
-           ret_list = [chat_list] + [chat_state1]
-           return tuple(ret_list)

        def switch_chat(chat_key, chat_state1, num_model_lock=0):
            chosen_chat = chat_state1[chat_key]

@@ -1813,7 +2065,7 @@ def go_gradio(**kwargs):
        remove_chat_event = remove_chat_btn.click(remove_chat,
                                                  inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state],
-                                                 queue=False)

        def get_chats1(chat_state1):
            base = 'chats'

@@ -1844,7 +2096,7 @@ def go_gradio(**kwargs):
                new_chats = json.loads(f.read())
                for chat1_k, chat1_v in new_chats.items():
                    # ignore chat1_k, regenerate and de-dup to avoid loss
            except BaseException as e:
                t, v, tb = sys.exc_info()
                ex = ''.join(traceback.format_exception(t, v, tb))
@@ -1870,24 +2122,17 @@ def go_gradio(**kwargs):
            .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
            .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])

-       def update_radio_chats(chat_state1):
-           # reverse so newest at top
-           choices = list(chat_state1.keys()).copy()
-           choices.reverse()
-           return gr.update(choices=choices, value=None)

        clear_event = save_chat_btn.click(save_chat,
                                          inputs=[text_output, text_output2] + text_outputs + [chat_state],
-                                         outputs=[
-                                         api_name='save_chat' if allow_api else None)
-           .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])

        # NOTE: clear of instruction/iinput for nochat has to come after score,
        # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
        no_chat_args = dict(fn=fun,
-                           inputs=[model_state, my_db_state] + inputs_list,
                            outputs=text_output_nochat,
                            queue=queue,
                            )

@@ -1906,7 +2151,8 @@ def go_gradio(**kwargs):
            .then(clear_torch_cache)

        submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str,
-                                                         inputs=[model_state, my_db_state,
                                                          outputs=text_output_nochat_api,
                                                          queue=True,  # required for generator
                                                          api_name='submit_nochat_api' if allow_api else None) \

@@ -2156,6 +2402,8 @@ def go_gradio(**kwargs):
                print("Exception: %s" % str(e), flush=True)
            return json.dumps(sys_dict)

        get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs)

        system_dict_event = system_btn2.click(get_system_info_dict_func,
@@ -2185,12 +2433,15 @@ def go_gradio(**kwargs):
        else:
            tokenizer = None
        if tokenizer is not None:
-           langchain_mode1 = '
            # fake user message to mimic bot()
            chat1 = copy.deepcopy(chat1)
            chat1 = chat1 + [['user_message1', None]]
            model_max_length1 = tokenizer.model_max_length
-           context1 = history_to_context(chat1, langchain_mode1,
                                          model_max_length1,
                                          memory_restriction_level1, keep_sources_in_context1)
            return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])

@@ -2220,7 +2471,7 @@ def go_gradio(**kwargs):
                     ,
                     queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)

-   demo.load(None, None, None, _js=get_dark_js() if kwargs['

    demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
    favicon_path = "h2o-logo.svg"

@@ -2235,7 +2486,8 @@ def go_gradio(**kwargs):
        # FIXME: disable for gptj, langchain or gpt4all modify print itself
        # FIXME: and any multi-threaded/async print will enter model output!
        scheduler.add_job(func=ping, trigger="interval", seconds=60)
    scheduler.start()

    # import control

@@ -2254,9 +2506,6 @@ def go_gradio(**kwargs):
    demo.block_thread()


-   input_args_list = ['model_state', 'my_db_state']
def get_inputs_list(inputs_dict, model_lower, model_id=1):
    """
    map gradio objects in locals() to inputs for evaluate().
@@ -2290,8 +2539,9 @@ def get_inputs_list(inputs_dict, model_lower, model_id=1):
    return inputs_list, inputs_dict_out


-def get_sources(

    if langchain_mode in ['ChatLLM', 'LLM']:
        source_files_added = "NA"

@@ -2300,7 +2550,8 @@ def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
        source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
                             " Ask jon.mckinney@h2o.ai for file if required."
        source_list = []
-   elif langchain_mode
        from gpt_langchain import get_metadatas
        metadatas = get_metadatas(db1[0])
        source_list = sorted(set([x['source'] for x in metadatas]))

@@ -2331,14 +2582,13 @@ def set_userid(db1):
        db1[1] = str(uuid.uuid4())


-def update_user_db(file,
    if file is None:
        raise RuntimeError("Don't use change, use input")

    try:
-       return _update_user_db(file,
                               langchain_mode=langchain_mode, dbs=dbs,
                               **kwargs)
    except BaseException as e:
@@ -2369,25 +2619,30 @@ def get_lock_file(db1, langchain_mode):
    user_id = db1[1]
    base_path = 'locks'
    makedirs(base_path)
-   lock_file = "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id)
    return lock_file


def _update_user_db(file,
                    chunk=None, chunk_size=None,
-                   dbs=None, db_type=None,
                    use_openai_embedding=None,
                    hf_embedding_model=None,
                    caption_loader=None,
                    enable_captions=None,
                    captions_model=None,
                    enable_ocr=None,
                    verbose=None,
                    is_url=None, is_txt=None,
-   assert
    assert chunk is not None
    assert chunk_size is not None
    assert use_openai_embedding is not None

@@ -2396,10 +2651,9 @@ def _update_user_db(file,
    assert enable_captions is not None
    assert captions_model is not None
    assert enable_ocr is not None
    assert verbose is not None

-   set_userid(db1)
    if dbs is None:
        dbs = {}
    assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))

@@ -2417,17 +2671,22 @@ def _update_user_db(file,
    if langchain_mode == LangChainMode.DISABLED.value:
        return None, langchain_mode, get_source_files(), ""

-   if langchain_mode in [LangChainMode.
        # then switch to MyData, so langchain_mode also becomes way to select where upload goes
        # but default to mydata if nothing chosen, since safest
    # move temp files from gradio upload to stable location
    for fili, fil in enumerate(file):
-       if isinstance(fil, str):
        if os.path.isfile(new_fil):
            remove(new_fil)
        try:
@@ -2447,15 +2706,22 @@ def _update_user_db(file,
                           enable_captions=enable_captions,
                           captions_model=captions_model,
                           enable_ocr=enable_ocr,
                           caption_loader=caption_loader,
                           )
    exceptions = [x for x in sources if x.metadata.get('exception')]
    exceptions_strs = [x.metadata['exception'] for x in exceptions]
    sources = [x for x in sources if 'exception' not in x.metadata]

    with filelock.FileLock(lock_file):
-       if langchain_mode
        if db1[0] is not None:
            # then add
            db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type,
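_update_user_db() serializes writes to a user's collection with a file lock whose name encodes the collection and a per-session user id. A minimal sketch of that locking scheme, not taken from the commit (uses the filelock package; paths and ids are illustrative):

# Sketch: one lock file per (collection, user) so concurrent uploads don't race.
import os
import uuid
import filelock


def get_lock_file_sketch(langchain_mode, user_id, base_path='locks'):
    os.makedirs(base_path, exist_ok=True)
    return os.path.join(base_path, "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id))


user_id = str(uuid.uuid4())
with filelock.FileLock(get_lock_file_sketch('MyData', user_id)):
    pass  # add sources to this user's db here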
@@ -2465,7 +2731,8 @@ def _update_user_db(file,
            # in testing expect:
            # assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
            # for production hit, when user gets clicky:
-           assert len(db1) == 2, "Bad
            # then create
            # if added has to original state and didn't change, then would be shared db for all users
            persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))

@@ -2487,7 +2754,7 @@ def _update_user_db(file,
                                 use_openai_embedding=use_openai_embedding,
                                 hf_embedding_model=hf_embedding_model)
    else:
-       # then create
        db = get_db(sources, use_openai_embedding=use_openai_embedding,
                    db_type=db_type,
                    persist_directory=persist_directory,

@@ -2501,14 +2768,15 @@ def _update_user_db(file,
    return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)


-def get_db(

    with filelock.FileLock(lock_file):
        if langchain_mode in ['wiki_full']:
            # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
            db = None
-       elif langchain_mode
            db = db1[0]
        elif dbs is not None and langchain_mode in dbs and dbs[langchain_mode] is not None:
            db = dbs[langchain_mode]

@@ -2517,8 +2785,8 @@ def get_db(db1, langchain_mode, dbs=None):
    return db


-def get_source_files_given_langchain_mode(
-   db = get_db(
    if langchain_mode in ['ChatLLM', 'LLM'] or db is None:
        return "Sources: N/A"
    return get_source_files(db=db, exceptions=None)

@@ -2617,11 +2885,19 @@ def get_source_files(db=None, exceptions=None, metadatas=None):
    return source_files_added


-def update_and_get_source_files_given_langchain_mode(
                                                      n_jobs=None, verbose=None):

    from gpt_langchain import make_db
    db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
@@ -2630,11 +2906,27 @@ def update_and_get_source_files_given_langchain_mode(db1, langchain_mode, dbs=No
                                                        chunk=chunk,
                                                        chunk_size=chunk_size,
                                                        langchain_mode=langchain_mode,
                                                        db_type=db_type,
                                                        load_db_if_exists=load_db_if_exists,
                                                        db=db,
                                                        n_jobs=n_jobs,
                                                        verbose=verbose)
    # return only new sources with text saying such
    return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
fix_pydantic_duplicate_validators_error()
|
52 |
|
53 |
+
from enums import DocumentSubset, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode, \
|
54 |
+
DocumentChoice, langchain_modes_intrinsic
|
55 |
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
|
56 |
text_xsm
|
57 |
from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
|
58 |
get_prompt
|
59 |
+
from utils import flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
60 |
+
ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip, \
|
61 |
+
save_collection_names
|
62 |
+
from gen import get_model, languages_covered, evaluate, score_qa, inputs_kwargs_list, scratch_base_dir, \
|
63 |
+
get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list, \
|
64 |
+
update_langchain
|
65 |
+
from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults, \
|
66 |
+
input_args_list
|
67 |
|
68 |
from apscheduler.schedulers.background import BackgroundScheduler
|
69 |
|
|
|
98 |
memory_restriction_level = kwargs['memory_restriction_level']
|
99 |
n_gpus = kwargs['n_gpus']
|
100 |
admin_pass = kwargs['admin_pass']
|
|
|
101 |
model_states = kwargs['model_states']
|
|
|
102 |
dbs = kwargs['dbs']
|
103 |
db_type = kwargs['db_type']
|
|
|
104 |
visible_langchain_actions = kwargs['visible_langchain_actions']
|
105 |
+
visible_langchain_agents = kwargs['visible_langchain_agents']
|
106 |
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
|
107 |
allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
|
108 |
enable_sources_list = kwargs['enable_sources_list']
|
|
|
113 |
enable_captions = kwargs['enable_captions']
|
114 |
captions_model = kwargs['captions_model']
|
115 |
enable_ocr = kwargs['enable_ocr']
|
116 |
+
enable_pdf_ocr = kwargs['enable_pdf_ocr']
|
117 |
caption_loader = kwargs['caption_loader']
|
118 |
|
119 |
+
# for dynamic state per user session in gradio
|
120 |
+
model_state0 = kwargs['model_state0']
|
121 |
+
score_model_state0 = kwargs['score_model_state0']
|
122 |
+
my_db_state0 = kwargs['my_db_state0']
|
123 |
+
selection_docs_state0 = kwargs['selection_docs_state0']
|
124 |
+
# for evaluate defaults
|
125 |
+
langchain_modes0 = kwargs['langchain_modes']
|
126 |
+
visible_langchain_modes0 = kwargs['visible_langchain_modes']
|
127 |
+
langchain_mode_paths0 = kwargs['langchain_mode_paths']
|
128 |
+
|
129 |
# easy update of kwargs needed for evaluate() etc.
|
130 |
queue = True
|
131 |
allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
|
|
|
145 |
" use Enter for multiple input lines)"
|
146 |
|
147 |
title = 'h2oGPT'
|
148 |
+
description = """<iframe src="https://ghbtns.com/github-btn.html?user=h2oai&repo=h2ogpt&type=star&count=true&size=small" frameborder="0" scrolling="0" width="250" height="20" title="GitHub"></iframe><small><a href="https://github.com/h2oai/h2ogpt">h2oGPT</a> <a href="https://github.com/h2oai/h2o-llmstudio">H2O LLM Studio</a><br><a href="https://huggingface.co/h2oai">🤗 Models</a>"""
|
149 |
+
description_bottom = "If this host is busy, try<br>[Multi-Model](https://gpt.h2o.ai)<br>[Falcon 40B](https://falcon.h2o.ai)<br>[Vicuna 33B](https://wizardvicuna.h2o.ai)<br>[MPT 30B-Chat](https://mpt.h2o.ai)<br>[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)<br>[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
|
|
150 |
if is_hf:
|
151 |
description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
|
152 |
+
task_info_md = ''
|
|
153 |
css_code = get_css(kwargs)
|
154 |
|
155 |
if kwargs['gradio_offline_level'] >= 0:
|
|
|
179 |
demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
|
180 |
callback = gr.CSVLogger()
|
181 |
|
182 |
+
model_options0 = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
|
183 |
+
if kwargs['base_model'].strip() not in model_options0:
|
184 |
+
model_options0 = [kwargs['base_model'].strip()] + model_options0
|
185 |
lora_options = kwargs['extra_lora_options']
|
186 |
if kwargs['lora_weights'].strip() not in lora_options:
|
187 |
lora_options = [kwargs['lora_weights'].strip()] + lora_options
|
|
|
196 |
|
197 |
# always add in no lora case
|
198 |
# add fake space so doesn't go away in gradio dropdown
|
199 |
+
model_options0 = [no_model_str] + model_options0
|
200 |
lora_options = [no_lora_str] + lora_options
|
201 |
server_options = [no_server_str] + server_options
|
202 |
# always add in no model case so can free memory
|
|
|
250 |
# else gets input_list at time of submit that is old, and shows up as truncated in chatbot
|
251 |
return x
|
252 |
|
253 |
+
def allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
|
254 |
+
allow = False
|
255 |
+
allow |= langchain_action1 not in LangChainAction.QUERY.value
|
256 |
+
allow |= document_subset1 in DocumentSubset.TopKSources.name
|
257 |
+
if langchain_mode1 in [LangChainMode.LLM.value]:
|
258 |
+
allow = False
|
259 |
+
return allow
|
260 |
+
|
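The helper above decides when an empty prompt is still actionable, for example a summarize action over selected documents or a TopKSources listing, while a plain LLM chat always needs a prompt. A tiny standalone restatement of that rule, using plain strings instead of the project's enums (the literals below are illustrative, not the exact enum values):

# Hedged restatement of allow_empty_instruction() above, with plain strings in place of
# the LangChainMode / DocumentSubset / LangChainAction enum values (assumption for brevity).
def allow_empty(langchain_mode, document_subset, langchain_action):
    allow = langchain_action != 'Query' or document_subset == 'TopKSources'
    if langchain_mode == 'LLM':
        allow = False
    return allow

print(allow_empty('UserData', 'Relevant', 'Summarize'))   # True: summarization needs no query
print(allow_empty('UserData', 'TopKSources', 'Query'))    # True: just listing top sources
print(allow_empty('LLM', 'Relevant', 'Summarize'))        # False: plain LLM chat needs a prompt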
261 |
with demo:
|
262 |
# avoid actual model/tokenizer here or anything that would be bad to deepcopy
|
263 |
# https://github.com/gradio-app/gradio/issues/3558
|
|
|
271 |
prompt_dict=kwargs['prompt_dict'],
|
272 |
)
|
273 |
)
|
274 |
+
|
275 |
+
def update_langchain_mode_paths(db1s, selection_docs_state1):
|
276 |
+
if allow_upload_to_my_data:
|
277 |
+
selection_docs_state1['langchain_mode_paths'].update({k: None for k in db1s})
|
278 |
+
dup = selection_docs_state1['langchain_mode_paths'].copy()
|
279 |
+
for k, v in dup.items():
|
280 |
+
if k not in selection_docs_state1['visible_langchain_modes']:
|
281 |
+
selection_docs_state1['langchain_mode_paths'].pop(k)
|
282 |
+
return selection_docs_state1
|
283 |
+
|
284 |
+
# Setup some gradio states for per-user dynamic state
|
285 |
model_state2 = gr.State(kwargs['model_state_none'].copy())
|
286 |
+
model_options_state = gr.State([model_options0])
|
287 |
lora_options_state = gr.State([lora_options])
|
288 |
server_options_state = gr.State([server_options])
|
289 |
+
my_db_state = gr.State(my_db_state0)
|
290 |
chat_state = gr.State({})
|
291 |
+
docs_state00 = kwargs['document_choice'] + [DocumentChoice.ALL.value]
|
292 |
docs_state0 = []
|
293 |
[docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
|
294 |
docs_state = gr.State(docs_state0)
|
295 |
viewable_docs_state0 = []
|
296 |
viewable_docs_state = gr.State(viewable_docs_state0)
|
297 |
+
selection_docs_state0 = update_langchain_mode_paths(my_db_state0, selection_docs_state0)
|
298 |
+
selection_docs_state = gr.State(selection_docs_state0)
|
299 |
+
|
300 |
gr.Markdown(f"""
|
301 |
{get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
|
302 |
""")
|
|
|
310 |
'model_lock'] else "Response Scores: %s" % nas
|
311 |
|
312 |
if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
|
313 |
+
extra_prompt_form = ". For summarization, no query required, just click submit"
|
314 |
else:
|
315 |
extra_prompt_form = ""
|
316 |
if kwargs['input_lines'] > 1:
|
|
|
318 |
else:
|
319 |
instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
|
320 |
|
321 |
+
def get_langchain_choices(selection_docs_state1):
|
322 |
+
langchain_modes = selection_docs_state1['langchain_modes']
|
323 |
+
visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
|
324 |
+
|
325 |
+
if is_hf:
|
326 |
+
# don't show 'wiki' since only usually useful for internal testing at moment
|
327 |
+
no_show_modes = ['Disabled', 'wiki']
|
328 |
+
else:
|
329 |
+
no_show_modes = ['Disabled']
|
330 |
+
allowed_modes = visible_langchain_modes.copy()
|
331 |
+
# allowed_modes = [x for x in allowed_modes if x in dbs]
|
332 |
+
allowed_modes += ['LLM']
|
333 |
+
if allow_upload_to_my_data and 'MyData' not in allowed_modes:
|
334 |
+
allowed_modes += ['MyData']
|
335 |
+
if allow_upload_to_user_data and 'UserData' not in allowed_modes:
|
336 |
+
allowed_modes += ['UserData']
|
337 |
+
choices = [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes]
|
338 |
+
return choices
|
339 |
+
|
340 |
+
def get_df_langchain_mode_paths(selection_docs_state1):
|
341 |
+
langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
|
342 |
+
if langchain_mode_paths:
|
343 |
+
df = pd.DataFrame.from_dict(langchain_mode_paths.items(), orient='columns')
|
344 |
+
df.columns = ['Collection', 'Path']
|
345 |
+
else:
|
346 |
+
df = pd.DataFrame(None)
|
347 |
+
return df
|
348 |
+
|
349 |
normal_block = gr.Row(visible=not base_wanted, equal_height=False)
|
350 |
with normal_block:
|
351 |
side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)
|
|
|
366 |
scale=1,
|
367 |
min_width=0,
|
368 |
elem_id="warning", elem_classes="feedback")
|
369 |
+
fileup_output_text = gr.Textbox(visible=False)
|
370 |
url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
|
371 |
url_label = 'URL/ArXiv' if have_arxiv else 'URL'
|
372 |
url_text = gr.Textbox(label=url_label,
|
|
|
380 |
visible=text_visible)
|
381 |
github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
|
382 |
database_visible = kwargs['langchain_mode'] != 'Disabled'
|
383 |
+
with gr.Accordion("Resources", open=False, visible=database_visible):
|
384 |
+
langchain_choices0 = get_langchain_choices(selection_docs_state0)
|
|
385 |
langchain_mode = gr.Radio(
|
386 |
+
langchain_choices0,
|
387 |
value=kwargs['langchain_mode'],
|
388 |
label="Collections",
|
389 |
show_label=True,
|
390 |
visible=kwargs['langchain_mode'] != 'Disabled',
|
391 |
min_width=100)
|
392 |
+
add_chat_history_to_context = gr.Checkbox(label="Chat History",
|
393 |
+
value=kwargs['add_chat_history_to_context'])
|
394 |
+
document_subset = gr.Radio([x.name for x in DocumentSubset],
|
395 |
label="Subset",
|
396 |
+
value=DocumentSubset.Relevant.name,
|
397 |
interactive=True,
|
398 |
)
|
399 |
allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
|
|
|
402 |
value=allowed_actions[0] if len(allowed_actions) > 0 else None,
|
403 |
label="Action",
|
404 |
visible=True)
|
405 |
+
allowed_agents = [x for x in langchain_agents_list if x in visible_langchain_agents]
|
406 |
+
langchain_agents = gr.Dropdown(
|
407 |
+
langchain_agents_list,
|
408 |
+
value=kwargs['langchain_agents'],
|
409 |
+
label="Agents",
|
410 |
+
multiselect=True,
|
411 |
+
interactive=True,
|
412 |
+
visible=False) # WIP
|
413 |
col_tabs = gr.Column(elem_id="col_container", scale=10)
|
414 |
with (col_tabs, gr.Tabs()):
|
415 |
with gr.TabItem("Chat"):
|
|
|
457 |
mw1 = 50
|
458 |
mw2 = 50
|
459 |
with gr.Column(min_width=mw1):
|
460 |
+
submit = gr.Button(value='Submit', variant='primary', size='sm',
|
461 |
min_width=mw1)
|
462 |
+
stop_btn = gr.Button(value="Stop", variant='secondary', size='sm',
|
463 |
min_width=mw1)
|
464 |
save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
|
465 |
with gr.Column(min_width=mw2):
|
|
|
480 |
with gr.TabItem("Document Selection"):
|
481 |
document_choice = gr.Dropdown(docs_state0,
|
482 |
label="Select Subset of Document(s) %s" % file_types_str,
|
483 |
+
value=[DocumentChoice.ALL.value],
|
484 |
interactive=True,
|
485 |
multiselect=True,
|
486 |
visible=kwargs['langchain_mode'] != 'Disabled',
|
487 |
)
|
488 |
sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
|
489 |
with gr.Row():
|
490 |
+
with gr.Column(scale=1):
|
491 |
+
get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
|
492 |
+
visible=sources_visible)
|
493 |
+
show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
|
494 |
+
visible=sources_visible)
|
495 |
+
refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
|
496 |
+
size='sm',
|
497 |
+
visible=sources_visible and allow_upload_to_user_data)
|
498 |
+
with gr.Column(scale=4):
|
499 |
+
pass
|
500 |
+
with gr.Row():
|
501 |
+
with gr.Column(scale=1):
|
502 |
+
add_placeholder = "e.g. UserData2, user_path2 (optional)" \
|
503 |
+
if not is_public else "e.g. MyData2"
|
504 |
+
remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"
|
505 |
+
new_langchain_mode_text = gr.Textbox(value="", visible=allow_upload_to_user_data or
|
506 |
+
allow_upload_to_my_data,
|
507 |
+
label='Add Collection',
|
508 |
+
placeholder=add_placeholder,
|
509 |
+
interactive=True)
|
510 |
+
remove_langchain_mode_text = gr.Textbox(value="", visible=allow_upload_to_user_data or
|
511 |
+
allow_upload_to_my_data,
|
512 |
+
label='Remove Collection',
|
513 |
+
placeholder=remove_placeholder,
|
514 |
+
interactive=True)
|
515 |
+
load_langchain = gr.Button(value="Load LangChain State", scale=0, size='sm',
|
516 |
+
visible=allow_upload_to_user_data)
|
517 |
+
with gr.Column(scale=1):
|
518 |
+
df0 = get_df_langchain_mode_paths(selection_docs_state0)
|
519 |
+
langchain_mode_path_text = gr.Dataframe(value=df0,
|
520 |
+
visible=allow_upload_to_user_data or
|
521 |
+
allow_upload_to_my_data,
|
522 |
+
label='LangChain Mode-Path',
|
523 |
+
show_label=False,
|
524 |
+
interactive=False)
|
525 |
+
with gr.Column(scale=4):
|
526 |
+
pass
|
527 |
|
528 |
sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
|
529 |
equal_height=False)
|
|
|
548 |
value=None,
|
549 |
interactive=True,
|
550 |
multiselect=False,
|
551 |
+
visible=True,
|
552 |
)
|
553 |
with gr.Column(scale=4):
|
554 |
pass
|
|
|
793 |
side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
|
794 |
submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
|
795 |
col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
|
796 |
+
text_outputs_height = gr.Slider(minimum=100, maximum=2000, value=kwargs['height'] or 400,
|
797 |
+
step=50, label='Chat Height')
|
798 |
dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
|
799 |
with gr.Column(scale=4):
|
800 |
pass
|
801 |
+
system_visible0 = not is_public and not admin_pass
|
802 |
admin_row = gr.Row()
|
803 |
with admin_row:
|
804 |
with gr.Column(scale=1):
|
805 |
+
admin_pass_textbox = gr.Textbox(label="Admin Password", type='password',
|
806 |
+
visible=not system_visible0)
|
807 |
with gr.Column(scale=4):
|
808 |
pass
|
809 |
+
system_row = gr.Row(visible=system_visible0)
|
810 |
with system_row:
|
811 |
with gr.Column():
|
812 |
with gr.Row():
|
|
|
870 |
else:
|
871 |
return tuple([gr.update(interactive=True)] * len(args))
|
872 |
|
873 |
+
# Add to UserData or custom user db
|
874 |
update_db_func = functools.partial(update_user_db,
|
875 |
dbs=dbs,
|
876 |
db_type=db_type,
|
877 |
use_openai_embedding=use_openai_embedding,
|
878 |
hf_embedding_model=hf_embedding_model,
|
|
|
879 |
captions_model=captions_model,
|
880 |
+
enable_captions=enable_captions,
|
881 |
caption_loader=caption_loader,
|
882 |
+
enable_ocr=enable_ocr,
|
883 |
+
enable_pdf_ocr=enable_pdf_ocr,
|
884 |
verbose=kwargs['verbose'],
|
|
|
885 |
n_jobs=kwargs['n_jobs'],
|
886 |
)
|
887 |
add_file_outputs = [fileup_output, langchain_mode]
|
888 |
add_file_kwargs = dict(fn=update_db_func,
|
889 |
+
inputs=[fileup_output, my_db_state, selection_docs_state, chunk, chunk_size,
|
890 |
+
langchain_mode],
|
891 |
outputs=add_file_outputs + [sources_text, doc_exception_text],
|
892 |
queue=queue,
|
893 |
api_name='add_file' if allow_api and allow_upload_to_user_data else None)
|
|
|
899 |
eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
|
900 |
show_progress='minimal')
|
901 |
|
902 |
+
# deal with challenge to have fileup_output itself as input
|
903 |
+
add_file_kwargs2 = dict(fn=update_db_func,
|
904 |
+
inputs=[fileup_output_text, my_db_state, selection_docs_state, chunk, chunk_size,
|
905 |
+
langchain_mode],
|
906 |
+
outputs=add_file_outputs + [sources_text, doc_exception_text],
|
907 |
+
queue=queue,
|
908 |
+
api_name='add_file_api' if allow_api and allow_upload_to_user_data else None)
|
909 |
+
eventdb1_api = fileup_output_text.submit(**add_file_kwargs2, show_progress='full')
|
910 |
+
|
911 |
# note for update_user_db_func output is ignored for db
|
912 |
|
913 |
def clear_textbox():
|
|
|
917 |
|
918 |
add_url_outputs = [url_text, langchain_mode]
|
919 |
add_url_kwargs = dict(fn=update_user_db_url_func,
|
920 |
+
inputs=[url_text, my_db_state, selection_docs_state, chunk, chunk_size,
|
921 |
+
langchain_mode],
|
922 |
outputs=add_url_outputs + [sources_text, doc_exception_text],
|
923 |
queue=queue,
|
924 |
api_name='add_url' if allow_api and allow_upload_to_user_data else None)
|
|
|
935 |
update_user_db_txt_func = functools.partial(update_db_func, is_txt=True)
|
936 |
add_text_outputs = [user_text_text, langchain_mode]
|
937 |
add_text_kwargs = dict(fn=update_user_db_txt_func,
|
938 |
+
inputs=[user_text_text, my_db_state, selection_docs_state, chunk, chunk_size,
|
939 |
+
langchain_mode],
|
940 |
outputs=add_text_outputs + [sources_text, doc_exception_text],
|
941 |
queue=queue,
|
942 |
api_name='add_text' if allow_api and allow_upload_to_user_data else None
|
|
|
948 |
eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full')
|
949 |
eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
|
950 |
show_progress='minimal')
|
951 |
+
db_events = [eventdb1a, eventdb1, eventdb1b, eventdb1_api,
|
952 |
eventdb2a, eventdb2, eventdb2b, eventdb2c,
|
953 |
eventdb3a, eventdb3b, eventdb3, eventdb3c]
|
954 |
|
|
|
956 |
|
957 |
# if change collection source, must clear doc selections from it to avoid inconsistency
|
958 |
def clear_doc_choice():
|
959 |
+
return gr.Dropdown.update(choices=docs_state0, value=DocumentChoice.ALL.value)
|
960 |
|
961 |
langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False)
|
962 |
|
963 |
def resize_col_tabs(x):
|
964 |
return gr.Dropdown.update(scale=x)
|
965 |
|
966 |
+
col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs, queue=False)
|
967 |
|
968 |
def resize_chatbots(x, num_model_lock=0):
|
969 |
if num_model_lock == 0:
|
|
|
974 |
|
975 |
resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs))
|
976 |
text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height,
|
977 |
+
outputs=[text_output, text_output2] + text_outputs, queue=False)
|
978 |
|
979 |
def update_dropdown(x):
|
980 |
return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
|
|
|
1065 |
if file.startswith('http') or file.startswith('https'):
|
1066 |
# if file is online, then might as well use google(?)
|
1067 |
document1 = file
|
1068 |
+
return gr.update(visible=True,
|
1069 |
+
value=f"""<iframe width="1000" height="800" src="https://docs.google.com/viewerng/viewer?url={document1}&embedded=true" frameborder="0" height="100%" width="100%">
|
1070 |
</iframe>
|
1071 |
"""), dummy1, dummy1, dummy1
|
1072 |
else:
|
|
|
1089 |
|
1090 |
refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
|
1091 |
**get_kwargs(update_and_get_source_files_given_langchain_mode,
|
1092 |
+
exclude_names=['db1s', 'langchain_mode', 'chunk',
|
1093 |
+
'chunk_size'],
|
1094 |
**all_kwargs))
|
1095 |
+
eventdb9 = refresh_sources_btn.click(fn=refresh_sources1,
|
1096 |
+
inputs=[my_db_state, langchain_mode, chunk, chunk_size],
|
1097 |
outputs=sources_text,
|
1098 |
api_name='refresh_sources' if allow_api else None)
|
1099 |
|
|
|
1103 |
def close_admin(x):
|
1104 |
return gr.update(visible=not (x == admin_pass))
|
1105 |
|
1106 |
+
admin_pass_textbox.submit(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \
|
1107 |
.then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
|
1108 |
|
1109 |
+
def add_langchain_mode(db1s, selection_docs_state1, langchain_mode1, y):
|
1110 |
+
for k in db1s:
|
1111 |
+
set_userid(db1s[k])
|
1112 |
+
langchain_modes = selection_docs_state1['langchain_modes']
|
1113 |
+
langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
|
1114 |
+
visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
|
1115 |
+
|
1116 |
+
user_path = None
|
1117 |
+
valid = True
|
1118 |
+
y2 = y.strip().replace(' ', '').split(',')
|
1119 |
+
if len(y2) >= 1:
|
1120 |
+
langchain_mode2 = y2[0]
|
1121 |
+
if len(langchain_mode2) >= 3 and langchain_mode2.isalnum():
|
1122 |
+
# real restriction is:
|
1123 |
+
# ValueError: Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address, got me
|
1124 |
+
# but just make simpler
|
1125 |
+
user_path = y2[1] if len(y2) > 1 else None # assume scratch if don't have user_path
|
1126 |
+
if user_path in ['', "''"]:
|
1127 |
+
# for scratch spaces
|
1128 |
+
user_path = None
|
1129 |
+
if langchain_mode2 in langchain_modes_intrinsic:
|
1130 |
+
user_path = None
|
1131 |
+
textbox = "Invalid access to use internal name: %s" % langchain_mode2
|
1132 |
+
valid = False
|
1133 |
+
langchain_mode2 = langchain_mode1
|
1134 |
+
elif user_path and allow_upload_to_user_data or not user_path and allow_upload_to_my_data:
|
1135 |
+
langchain_mode_paths.update({langchain_mode2: user_path})
|
1136 |
+
if langchain_mode2 not in visible_langchain_modes:
|
1137 |
+
visible_langchain_modes.append(langchain_mode2)
|
1138 |
+
if langchain_mode2 not in langchain_modes:
|
1139 |
+
langchain_modes.append(langchain_mode2)
|
1140 |
+
textbox = ''
|
1141 |
+
if user_path:
|
1142 |
+
makedirs(user_path, exist_ok=True)
|
1143 |
+
else:
|
1144 |
+
valid = False
|
1145 |
+
langchain_mode2 = langchain_mode1
|
1146 |
+
textbox = "Invalid access. user allowed: %s " \
|
1147 |
+
"scratch allowed: %s" % (allow_upload_to_user_data, allow_upload_to_my_data)
|
1148 |
+
else:
|
1149 |
+
valid = False
|
1150 |
+
langchain_mode2 = langchain_mode1
|
1151 |
+
textbox = "Invalid, collection must be >=3 characters and alphanumeric"
|
1152 |
+
else:
|
1153 |
+
valid = False
|
1154 |
+
langchain_mode2 = langchain_mode1
|
1155 |
+
textbox = "Invalid, must be like UserData2, user_path2"
|
1156 |
+
selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
|
1157 |
+
df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
|
1158 |
+
choices = get_langchain_choices(selection_docs_state1)
|
1159 |
+
|
1160 |
+
if valid and not user_path:
|
1161 |
+
# needs to have key for it to make it known different from userdata case in _update_user_db()
|
1162 |
+
db1s[langchain_mode2] = [None, None]
|
1163 |
+
if valid:
|
1164 |
+
save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode,
|
1165 |
+
db1s)
|
1166 |
+
|
1167 |
+
return db1s, selection_docs_state1, gr.update(choices=choices,
|
1168 |
+
value=langchain_mode2), textbox, df_langchain_mode_paths1
|
1169 |
+
|
1170 |
+
def remove_langchain_mode(db1s, selection_docs_state1, langchain_mode1, langchain_mode2, dbsu=None):
|
1171 |
+
for k in db1s:
|
1172 |
+
set_userid(db1s[k])
|
1173 |
+
assert dbsu is not None
|
1174 |
+
langchain_modes = selection_docs_state1['langchain_modes']
|
1175 |
+
langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
|
1176 |
+
visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
|
1177 |
+
|
1178 |
+
if langchain_mode2 in db1s and not allow_upload_to_my_data or \
|
1179 |
+
dbsu is not None and langchain_mode2 in dbsu and not allow_upload_to_user_data or \
|
1180 |
+
langchain_mode2 in langchain_modes_intrinsic:
|
1181 |
+
# NOTE: Doesn't fail if remove MyData, but didn't debug odd behavior seen with upload after gone
|
1182 |
+
textbox = "Invalid access, cannot remove %s" % langchain_mode2
|
1183 |
+
df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
|
1184 |
+
else:
|
1185 |
+
# change global variables
|
1186 |
+
if langchain_mode2 in visible_langchain_modes:
|
1187 |
+
visible_langchain_modes.remove(langchain_mode2)
|
1188 |
+
textbox = ""
|
1189 |
+
else:
|
1190 |
+
textbox = "%s was not visible" % langchain_mode2
|
1191 |
+
if langchain_mode2 in langchain_modes:
|
1192 |
+
langchain_modes.remove(langchain_mode2)
|
1193 |
+
if langchain_mode2 in langchain_mode_paths:
|
1194 |
+
langchain_mode_paths.pop(langchain_mode2)
|
1195 |
+
if langchain_mode2 in db1s:
|
1196 |
+
# remove db entirely, so not in list, else need to manage visible list in update_langchain_mode_paths()
|
1197 |
+
# FIXME: Remove location?
|
1198 |
+
if langchain_mode2 != LangChainMode.MY_DATA.value:
|
1199 |
+
# don't remove last MyData, used as user hash
|
1200 |
+
db1s.pop(langchain_mode2)
|
1201 |
+
# only show
|
1202 |
+
selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
|
1203 |
+
df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
|
1204 |
+
|
1205 |
+
save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode,
|
1206 |
+
db1s)
|
1207 |
+
|
1208 |
+
return db1s, selection_docs_state1, \
|
1209 |
+
gr.update(choices=get_langchain_choices(selection_docs_state1),
|
1210 |
+
value=langchain_mode2), textbox, df_langchain_mode_paths1
|
1211 |
+
|
1212 |
+
new_langchain_mode_text.submit(fn=add_langchain_mode,
|
1213 |
+
inputs=[my_db_state, selection_docs_state, langchain_mode,
|
1214 |
+
new_langchain_mode_text],
|
1215 |
+
outputs=[my_db_state, selection_docs_state, langchain_mode,
|
1216 |
+
new_langchain_mode_text,
|
1217 |
+
langchain_mode_path_text],
|
1218 |
+
api_name='new_langchain_mode_text' if allow_api and allow_upload_to_user_data else None)
|
1219 |
+
remove_langchain_mode_func = functools.partial(remove_langchain_mode, dbsu=dbs)
|
1220 |
+
remove_langchain_mode_text.submit(fn=remove_langchain_mode_func,
|
1221 |
+
inputs=[my_db_state, selection_docs_state, langchain_mode,
|
1222 |
+
remove_langchain_mode_text],
|
1223 |
+
outputs=[my_db_state, selection_docs_state, langchain_mode,
|
1224 |
+
remove_langchain_mode_text,
|
1225 |
+
langchain_mode_path_text],
|
1226 |
+
api_name='remove_langchain_mode_text' if allow_api and allow_upload_to_user_data else None)
|
1227 |
+
|
1228 |
+
def update_langchain_gr(db1s, selection_docs_state1, langchain_mode1):
|
1229 |
+
for k in db1s:
|
1230 |
+
set_userid(db1s[k])
|
1231 |
+
langchain_modes = selection_docs_state1['langchain_modes']
|
1232 |
+
langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
|
1233 |
+
visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
|
1234 |
+
# in-place
|
1235 |
+
|
1236 |
+
# update user collaborative collections
|
1237 |
+
update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, '')
|
1238 |
+
# update scratch single-user collections
|
1239 |
+
user_hash = db1s.get(LangChainMode.MY_DATA.value, '')[1]
|
1240 |
+
update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, user_hash)
|
1241 |
+
|
1242 |
+
selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
|
1243 |
+
df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
|
1244 |
+
return selection_docs_state1, \
|
1245 |
+
gr.update(choices=get_langchain_choices(selection_docs_state1),
|
1246 |
+
value=langchain_mode1), df_langchain_mode_paths1
|
1247 |
+
|
1248 |
+
load_langchain.click(fn=update_langchain_gr,
|
1249 |
+
inputs=[my_db_state, selection_docs_state, langchain_mode],
|
1250 |
+
outputs=[selection_docs_state, langchain_mode, langchain_mode_path_text],
|
1251 |
+
api_name='load_langchain' if allow_api and allow_upload_to_user_data else None)
|
1252 |
+
|
1253 |
inputs_list, inputs_dict = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=1)
|
1254 |
inputs_list2, inputs_dict2 = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=2)
|
1255 |
from functools import partial
|
|
|
1261 |
def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
|
1262 |
args_list = list(args1)
|
1263 |
if str_api:
|
1264 |
+
user_kwargs = args_list[len(input_args_list)]
|
1265 |
assert isinstance(user_kwargs, str)
|
1266 |
user_kwargs = ast.literal_eval(user_kwargs)
|
1267 |
else:
|
1268 |
+
user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[len(input_args_list):])}
|
1269 |
# only used for submit_nochat_api
|
1270 |
user_kwargs['chat'] = False
|
1271 |
if 'stream_output' not in user_kwargs:
|
|
|
1275 |
user_kwargs['langchain_mode'] = 'Disabled'
|
1276 |
if 'langchain_action' not in user_kwargs:
|
1277 |
user_kwargs['langchain_action'] = LangChainAction.QUERY.value
|
1278 |
+
if 'langchain_agents' not in user_kwargs:
|
1279 |
+
user_kwargs['langchain_agents'] = []
|
1280 |
|
1281 |
set1 = set(list(default_kwargs1.keys()))
|
1282 |
set2 = set(eval_func_param_names)
|
|
|
1284 |
# correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
|
1285 |
model_state1 = args_list[0]
|
1286 |
my_db_state1 = args_list[1]
|
1287 |
+
selection_docs_state1 = args_list[2]
|
1288 |
args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
|
1289 |
in eval_func_param_names]
|
1290 |
assert len(args_list) == len(eval_func_param_names)
|
1291 |
+
args_list = [model_state1, my_db_state1, selection_docs_state1] + args_list
|
1292 |
|
1293 |
try:
|
1294 |
for res_dict in evaluate(*tuple(args_list), **kwargs1):
|
|
|
1459 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1460 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1461 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1462 |
+
langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
|
1463 |
document_subset1 = args_list[eval_func_param_names.index('document_subset')]
|
1464 |
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1465 |
if not prompt_type1:
|
|
|
1492 |
history[-1][1] = None
|
1493 |
return history
|
1494 |
if user_message1 in ['', None, '\n']:
|
1495 |
+
if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
|
|
1496 |
# reject non-retry submit/enter
|
1497 |
return history
|
1498 |
user_message1 = fix_text_for_gradio(user_message1)
|
|
|
1539 |
API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
|
1540 |
:return: last element is True if should run bot, False if should just yield history
|
1541 |
"""
|
1542 |
+
isize = len(input_args_list) + 1 # states + chat history
|
1543 |
# don't deepcopy, can contain model itself
|
1544 |
args_list = list(args).copy()
|
1545 |
+
model_state1 = args_list[-isize]
|
1546 |
+
my_db_state1 = args_list[-isize + 1]
|
1547 |
+
selection_docs_state1 = args_list[-isize + 2]
|
1548 |
history = args_list[-1]
|
1549 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1550 |
prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
|
|
|
1552 |
if model_state1['model'] is None or model_state1['model'] == no_model_str:
|
1553 |
return history, None, None, None
|
1554 |
|
1555 |
+
args_list = args_list[:-isize] # only keep rest needed for evaluate()
|
1556 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1557 |
+
add_chat_history_to_context1 = args_list[eval_func_param_names.index('add_chat_history_to_context')]
|
1558 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1559 |
+
langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
|
1560 |
document_subset1 = args_list[eval_func_param_names.index('document_subset')]
|
1561 |
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1562 |
if not history:
|
|
|
1569 |
instruction1 = history[-1][0]
|
1570 |
history[-1][1] = None
|
1571 |
elif not instruction1:
|
1572 |
+
if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
|
|
1573 |
# if not retrying, then reject empty query
|
1574 |
return history, None, None, None
|
1575 |
elif len(history) > 0 and history[-1][1] not in [None, '']:
|
|
|
1586 |
|
1587 |
chat1 = args_list[eval_func_param_names.index('chat')]
|
1588 |
model_max_length1 = get_model_max_length(model_state1)
|
1589 |
+
context1 = history_to_context(history, langchain_mode1,
|
1590 |
+
add_chat_history_to_context1,
|
1591 |
+
prompt_type1, prompt_dict1, chat1,
|
1592 |
model_max_length1, memory_restriction_level,
|
1593 |
kwargs['keep_sources_in_context'])
|
1594 |
args_list[0] = instruction1 # override original instruction with history from user
|
|
|
1597 |
fun1 = partial(evaluate,
|
1598 |
model_state1,
|
1599 |
my_db_state1,
|
1600 |
+
selection_docs_state1,
|
1601 |
*tuple(args_list),
|
1602 |
**kwargs_evaluate)
|
1603 |
|
|
|
1643 |
clear_torch_cache()
|
1644 |
return
|
1645 |
|
1646 |
+
def clear_embeddings(langchain_mode1, db1s):
|
1647 |
# clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache
|
1648 |
+
if db_type == 'chroma' and langchain_mode1 not in ['LLM', 'Disabled', None, '']:
|
1649 |
from gpt_langchain import clear_embedding
|
1650 |
db = dbs.get('langchain_mode1')
|
1651 |
if db is not None and not isinstance(db, str):
|
1652 |
clear_embedding(db)
|
1653 |
+
if db1s is not None and langchain_mode1 in db1s:
|
1654 |
+
db1 = db1s[langchain_mode1]
|
1655 |
+
if len(db1) == 2:
|
1656 |
+
clear_embedding(db1[0])
|
1657 |
|
1658 |
def bot(*args, retry=False):
|
1659 |
+
history, fun1, langchain_mode1, db1 = prep_bot(*args, retry=retry)
|
1660 |
try:
|
1661 |
for res in get_response(fun1, history):
|
1662 |
yield res
|
1663 |
finally:
|
1664 |
clear_torch_cache()
|
1665 |
+
clear_embeddings(langchain_mode1, db1)
|
1666 |
|
1667 |
def all_bot(*args, retry=False, model_states1=None):
|
1668 |
args_list = list(args).copy()
|
|
|
1672 |
stream_output1 = args_list[eval_func_param_names.index('stream_output')]
|
1673 |
max_time1 = args_list[eval_func_param_names.index('max_time')]
|
1674 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1675 |
+
isize = len(input_args_list) + 1 # states + chat history
|
1676 |
+
db1s = None
|
1677 |
try:
|
1678 |
gen_list = []
|
1679 |
for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
|
1680 |
args_list1 = args_list0.copy()
|
1681 |
+
args_list1.insert(-isize + 2,
|
1682 |
+
model_state1) # insert at -2 so is at -3, and after chatbot1 added, at -4
|
1683 |
# if at start, have None in response still, replace with '' so client etc. acts like normal
|
1684 |
# assumes other parts of code treat '' and None as if no response yet from bot
|
1685 |
# can't do this later in bot code as racy with threaded generators
|
|
|
1689 |
# so consistent with prep_bot()
|
1690 |
# with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
|
1691 |
# langchain_mode1 and my_db_state1 should be same for every bot
|
1692 |
+
history, fun1, langchain_mode1, db1s = prep_bot(*tuple(args_list1), retry=retry,
|
1693 |
+
which_model=chatboti)
|
1694 |
gen1 = get_response(fun1, history)
|
1695 |
if stream_output1:
|
1696 |
gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
|
|
|
1736 |
print("Generate exceptions: %s" % exceptions, flush=True)
|
1737 |
finally:
|
1738 |
clear_torch_cache()
|
1739 |
+
clear_embeddings(langchain_mode1, db1s)
|
1740 |
|
1741 |
# NORMAL MODEL
|
1742 |
user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
|
|
|
1744 |
outputs=text_output,
|
1745 |
)
|
1746 |
bot_args = dict(fn=bot,
|
1747 |
+
inputs=inputs_list + [model_state, my_db_state, selection_docs_state] + [text_output],
|
1748 |
outputs=[text_output, chat_exception_text],
|
1749 |
)
|
1750 |
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
1751 |
+
inputs=inputs_list + [model_state, my_db_state, selection_docs_state] + [text_output],
|
1752 |
outputs=[text_output, chat_exception_text],
|
1753 |
)
|
1754 |
retry_user_args = dict(fn=functools.partial(user, retry=True),
|
|
|
1766 |
outputs=text_output2,
|
1767 |
)
|
1768 |
bot_args2 = dict(fn=bot,
|
1769 |
+
inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state] + [text_output2],
|
1770 |
outputs=[text_output2, chat_exception_text],
|
1771 |
)
|
1772 |
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
|
1773 |
+
inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state] + [text_output2],
|
1774 |
outputs=[text_output2, chat_exception_text],
|
1775 |
)
|
1776 |
retry_user_args2 = dict(fn=functools.partial(user, retry=True),
|
|
|
1791 |
outputs=text_outputs,
|
1792 |
)
|
1793 |
all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
|
1794 |
+
inputs=inputs_list + [my_db_state, selection_docs_state] + text_outputs,
|
1795 |
outputs=text_outputs + [chat_exception_text],
|
1796 |
)
|
1797 |
all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
|
1798 |
+
inputs=inputs_list + [my_db_state, selection_docs_state] + text_outputs,
|
1799 |
outputs=text_outputs + [chat_exception_text],
|
1800 |
)
|
1801 |
all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
|
|
|
1957 |
def get_short_chat(x, short_chats, short_len=20, words=4):
|
1958 |
if x and len(x[0]) == 2 and x[0][0] is not None:
|
1959 |
short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
|
1960 |
+
if not short_chat:
|
1961 |
+
# e.g.summarization, try using answer
|
1962 |
+
short_chat = ' '.join(x[0][1][:short_len].split(' ')[:words]).strip()
|
1963 |
+
if not short_chat:
|
1964 |
+
short_chat = 'Unk'
|
1965 |
short_chat = dedup(short_chat, short_chats)
|
1966 |
else:
|
1967 |
short_chat = None
|
|
|
2029 |
already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists])
|
2030 |
if not already_exists:
|
2031 |
chat_state1[short_chat] = chat_list.copy()
|
2032 |
+
|
2033 |
+
# reverse so newest at top
|
2034 |
+
choices = list(chat_state1.keys()).copy()
|
2035 |
+
choices.reverse()
|
2036 |
+
|
2037 |
+
return chat_state1, gr.update(choices=choices, value=None)
|
|
2038 |
|
2039 |
def switch_chat(chat_key, chat_state1, num_model_lock=0):
|
2040 |
chosen_chat = chat_state1[chat_key]
|
|
|
2065 |
|
2066 |
remove_chat_event = remove_chat_btn.click(remove_chat,
|
2067 |
inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state],
|
2068 |
+
queue=False, api_name='remove_chat')
|
2069 |
|
2070 |
def get_chats1(chat_state1):
|
2071 |
base = 'chats'
|
|
|
2096 |
new_chats = json.loads(f.read())
|
2097 |
for chat1_k, chat1_v in new_chats.items():
|
2098 |
# ignore chat1_k, regenerate and de-dup to avoid loss
|
2099 |
+
chat_state1, _ = save_chat(chat1_v, chat_state1, chat_is_list=True)
|
2100 |
except BaseException as e:
|
2101 |
t, v, tb = sys.exc_info()
|
2102 |
ex = ''.join(traceback.format_exception(t, v, tb))
|
|
|
2122 |
.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
|
2123 |
.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
2124 |
|
|
2125 |
clear_event = save_chat_btn.click(save_chat,
|
2126 |
inputs=[text_output, text_output2] + text_outputs + [chat_state],
|
2127 |
+
outputs=[chat_state, radio_chats],
|
2128 |
+
api_name='save_chat' if allow_api else None)
|
2129 |
+
if kwargs['score_model']:
|
2130 |
+
clear_event2 = clear_event.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
|
|
2131 |
|
2132 |
# NOTE: clear of instruction/iinput for nochat has to come after score,
|
2133 |
# because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
|
2134 |
no_chat_args = dict(fn=fun,
|
2135 |
+
inputs=[model_state, my_db_state, selection_docs_state] + inputs_list,
|
2136 |
outputs=text_output_nochat,
|
2137 |
queue=queue,
|
2138 |
)
|
|
|
2151 |
.then(clear_torch_cache)
|
2152 |
|
2153 |
submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str,
|
2154 |
+
inputs=[model_state, my_db_state, selection_docs_state,
|
2155 |
+
inputs_dict_str],
|
2156 |
outputs=text_output_nochat_api,
|
2157 |
queue=True, # required for generator
|
2158 |
api_name='submit_nochat_api' if allow_api else None) \
|
|
|
2402 |
print("Exception: %s" % str(e), flush=True)
|
2403 |
return json.dumps(sys_dict)
|
2404 |
|
2405 |
+
system_kwargs = all_kwargs.copy()
|
2406 |
+
system_kwargs.update(dict(command=str(' '.join(sys.argv))))
|
2407 |
get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs)
|
2408 |
|
2409 |
system_dict_event = system_btn2.click(get_system_info_dict_func,
|
|
|
2433 |
else:
|
2434 |
tokenizer = None
|
2435 |
if tokenizer is not None:
|
2436 |
+
langchain_mode1 = 'LLM'
|
2437 |
+
add_chat_history_to_context1 = True
|
2438 |
# fake user message to mimic bot()
|
2439 |
chat1 = copy.deepcopy(chat1)
|
2440 |
chat1 = chat1 + [['user_message1', None]]
|
2441 |
model_max_length1 = tokenizer.model_max_length
|
2442 |
+
context1 = history_to_context(chat1, langchain_mode1,
|
2443 |
+
add_chat_history_to_context1,
|
2444 |
+
prompt_type1, prompt_dict1, chat1,
|
2445 |
model_max_length1,
|
2446 |
memory_restriction_level1, keep_sources_in_context1)
|
2447 |
return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])
|
|
|
2471 |
,
|
2472 |
queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
|
2473 |
|
2474 |
+
demo.load(None, None, None, _js=get_dark_js() if kwargs['dark'] else None)
|
2475 |
|
2476 |
demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
|
2477 |
favicon_path = "h2o-logo.svg"
|
|
|
2486 |
# FIXME: disable for gptj, langchain or gpt4all modify print itself
|
2487 |
# FIXME: and any multi-threaded/async print will enter model output!
|
2488 |
scheduler.add_job(func=ping, trigger="interval", seconds=60)
|
2489 |
+
if is_public or os.getenv('PING_GPU'):
|
2490 |
+
scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10)
|
2491 |
scheduler.start()
|
2492 |
|
2493 |
# import control
|
|
|
2506 |
demo.block_thread()
|
2507 |
|
2508 |
|
|
2509 |
def get_inputs_list(inputs_dict, model_lower, model_id=1):
|
2510 |
"""
|
2511 |
map gradio objects in locals() to inputs for evaluate().
|
|
|
2539 |
return inputs_list, inputs_dict_out
|
2540 |
|
2541 |
|
2542 |
+
def get_sources(db1s, langchain_mode, dbs=None, docs_state0=None):
|
2543 |
+
for k in db1s:
|
2544 |
+
set_userid(db1s[k])
|
2545 |
|
2546 |
if langchain_mode in ['ChatLLM', 'LLM']:
|
2547 |
source_files_added = "NA"
|
|
|
2550 |
source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
|
2551 |
" Ask jon.mckinney@h2o.ai for file if required."
|
2552 |
source_list = []
|
2553 |
+
elif langchain_mode in db1s and len(db1s[langchain_mode]) == 2 and db1s[langchain_mode][0] is not None:
|
2554 |
+
db1 = db1s[langchain_mode]
|
2555 |
from gpt_langchain import get_metadatas
|
2556 |
metadatas = get_metadatas(db1[0])
|
2557 |
source_list = sorted(set([x['source'] for x in metadatas]))
|
|
|
2582 |
db1[1] = str(uuid.uuid4())
|
2583 |
|
2584 |
|
2585 |
+
def update_user_db(file, db1s, selection_docs_state1, chunk, chunk_size, langchain_mode, dbs=None, **kwargs):
|
2586 |
+
kwargs.update(selection_docs_state1)
|
|
|
2587 |
if file is None:
|
2588 |
raise RuntimeError("Don't use change, use input")
|
2589 |
|
2590 |
try:
|
2591 |
+
return _update_user_db(file, db1s=db1s, chunk=chunk, chunk_size=chunk_size,
|
2592 |
langchain_mode=langchain_mode, dbs=dbs,
|
2593 |
**kwargs)
|
2594 |
except BaseException as e:
|
|
|
2619 |
user_id = db1[1]
|
2620 |
base_path = 'locks'
|
2621 |
makedirs(base_path)
|
2622 |
+
lock_file = os.path.join(base_path, "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id))
|
2623 |
return lock_file
|
2624 |
|
2625 |
|
2626 |
def _update_user_db(file,
|
2627 |
+
db1s=None,
|
2628 |
chunk=None, chunk_size=None,
|
2629 |
+
dbs=None, db_type=None,
|
2630 |
+
langchain_mode='UserData',
|
2631 |
+
langchain_modes=None, # unused but required as part of selection_docs_state1
|
2632 |
+
langchain_mode_paths=None,
|
2633 |
+
visible_langchain_modes=None,
|
2634 |
use_openai_embedding=None,
|
2635 |
hf_embedding_model=None,
|
2636 |
caption_loader=None,
|
2637 |
enable_captions=None,
|
2638 |
captions_model=None,
|
2639 |
enable_ocr=None,
|
2640 |
+
enable_pdf_ocr=None,
|
2641 |
verbose=None,
|
2642 |
+
n_jobs=-1,
|
2643 |
is_url=None, is_txt=None,
|
2644 |
+
):
|
2645 |
+
assert db1s is not None
|
2646 |
assert chunk is not None
|
2647 |
assert chunk_size is not None
|
2648 |
assert use_openai_embedding is not None
|
|
|
2651 |
assert enable_captions is not None
|
2652 |
assert captions_model is not None
|
2653 |
assert enable_ocr is not None
|
2654 |
+
assert enable_pdf_ocr is not None
|
2655 |
assert verbose is not None
|
2656 |
|
|
|
|
|
2657 |
if dbs is None:
|
2658 |
dbs = {}
|
2659 |
assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
|
|
|
2671 |
if langchain_mode == LangChainMode.DISABLED.value:
|
2672 |
return None, langchain_mode, get_source_files(), ""
|
2673 |
|
2674 |
+
if langchain_mode in [LangChainMode.LLM.value]:
|
2675 |
# then switch to MyData, so langchain_mode also becomes way to select where upload goes
|
2676 |
# but default to mydata if nothing chosen, since safest
|
2677 |
+
if LangChainMode.MY_DATA.value in visible_langchain_modes:
|
2678 |
+
langchain_mode = LangChainMode.MY_DATA.value
|
2679 |
+
|
2680 |
+
if langchain_mode_paths is None:
|
2681 |
+
langchain_mode_paths = {}
|
2682 |
+
user_path = langchain_mode_paths.get(langchain_mode)
|
2683 |
+
# UserData or custom, which has to be from user's disk
|
2684 |
+
if user_path is not None:
|
2685 |
# move temp files from gradio upload to stable location
|
2686 |
for fili, fil in enumerate(file):
|
2687 |
+
if isinstance(fil, str) and os.path.isfile(fil): # not url, text
|
2688 |
+
new_fil = os.path.normpath(os.path.join(user_path, os.path.basename(fil)))
|
2689 |
+
if os.path.normpath(os.path.abspath(fil)) != os.path.normpath(os.path.abspath(new_fil)):
|
2690 |
if os.path.isfile(new_fil):
|
2691 |
remove(new_fil)
|
2692 |
try:
|
|
|
2706 |
enable_captions=enable_captions,
|
2707 |
captions_model=captions_model,
|
2708 |
enable_ocr=enable_ocr,
|
2709 |
+
enable_pdf_ocr=enable_pdf_ocr,
|
2710 |
caption_loader=caption_loader,
|
2711 |
)
|
2712 |
exceptions = [x for x in sources if x.metadata.get('exception')]
|
2713 |
exceptions_strs = [x.metadata['exception'] for x in exceptions]
|
2714 |
sources = [x for x in sources if 'exception' not in x.metadata]
|
2715 |
|
2716 |
+
# below must at least come after langchain_mode is modified in case was LLM -> MyData,
|
2717 |
+
# so original langchain mode changed
|
2718 |
+
for k in db1s:
|
2719 |
+
set_userid(db1s[k])
|
2720 |
+
db1 = get_db1(db1s, langchain_mode)
|
2721 |
+
|
2722 |
+
lock_file = get_lock_file(db1s[LangChainMode.MY_DATA.value], langchain_mode) # user-level lock, not db-level lock
|
2723 |
with filelock.FileLock(lock_file):
|
2724 |
+
if langchain_mode in db1s:
|
2725 |
if db1[0] is not None:
|
2726 |
# then add
|
2727 |
db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type,
|
|
|
2731 |
# in testing expect:
|
2732 |
# assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
|
2733 |
# for production hit, when user gets clicky:
|
2734 |
+
assert len(db1) == 2, "Bad %s db: %s" % (langchain_mode, db1)
|
2735 |
+
assert db1[1] is not None, "db hash was None, not allowed"
|
2736 |
# then create
|
2737 |
# if added has to original state and didn't change, then would be shared db for all users
|
2738 |
persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
|
|
|
2754 |
use_openai_embedding=use_openai_embedding,
|
2755 |
hf_embedding_model=hf_embedding_model)
|
2756 |
else:
|
2757 |
+
# then create. Or might just be that dbs is unfilled, then it will fill, then add
|
2758 |
db = get_db(sources, use_openai_embedding=use_openai_embedding,
|
2759 |
db_type=db_type,
|
2760 |
persist_directory=persist_directory,
|
|
|
2768 |
return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
|
2769 |
|
2770 |
|
2771 |
+
def get_db(db1s, langchain_mode, dbs=None):
|
2772 |
+
db1 = get_db1(db1s, langchain_mode)
|
2773 |
+
lock_file = get_lock_file(db1s[LangChainMode.MY_DATA.value], langchain_mode)
|
2774 |
|
2775 |
with filelock.FileLock(lock_file):
|
2776 |
if langchain_mode in ['wiki_full']:
|
2777 |
# NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
|
2778 |
db = None
|
2779 |
+
elif langchain_mode in db1s and len(db1) == 2 and db1[0] is not None:
|
2780 |
db = db1[0]
|
2781 |
elif dbs is not None and langchain_mode in dbs and dbs[langchain_mode] is not None:
|
2782 |
db = dbs[langchain_mode]
|
|
|
2785 |
return db
|
2786 |
|
2787 |
|
2788 |
+
def get_source_files_given_langchain_mode(db1s, langchain_mode='UserData', dbs=None):
|
2789 |
+
db = get_db(db1s, langchain_mode, dbs=dbs)
|
2790 |
if langchain_mode in ['ChatLLM', 'LLM'] or db is None:
|
2791 |
return "Sources: N/A"
|
2792 |
return get_source_files(db=db, exceptions=None)
|
|
|
2885 |
return source_files_added
|
2886 |
|
2887 |
|
2888 |
+
def update_and_get_source_files_given_langchain_mode(db1s, langchain_mode, chunk, chunk_size,
|
2889 |
+
dbs=None, first_para=None,
|
2890 |
+
text_limit=None,
|
2891 |
+
langchain_mode_paths=None, db_type=None, load_db_if_exists=None,
|
2892 |
n_jobs=None, verbose=None):
|
2893 |
+
has_path = {k: v for k, v in langchain_mode_paths.items() if v}
|
2894 |
+
if langchain_mode in [LangChainMode.LLM.value, LangChainMode.MY_DATA.value]:
|
2895 |
+
# then assume user really meant UserData, to avoid extra clicks in UI,
|
2896 |
+
# since others can't be on disk, except custom user modes, which they should then select to query it
|
2897 |
+
if LangChainMode.USER_DATA.value in has_path:
|
2898 |
+
langchain_mode = LangChainMode.USER_DATA.value
|
2899 |
+
|
2900 |
+
db = get_db(db1s, langchain_mode, dbs=dbs)
|
2901 |
|
2902 |
from gpt_langchain import make_db
|
2903 |
db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
|
|
|
2906 |
chunk=chunk,
|
2907 |
chunk_size=chunk_size,
|
2908 |
langchain_mode=langchain_mode,
|
2909 |
+
langchain_mode_paths=langchain_mode_paths,
|
2910 |
db_type=db_type,
|
2911 |
load_db_if_exists=load_db_if_exists,
|
2912 |
db=db,
|
2913 |
n_jobs=n_jobs,
|
2914 |
verbose=verbose)
|
2915 |
+
# during refreshing, might have "created" new db since not in dbs[] yet, so insert back just in case
|
2916 |
+
# so even if persisted, not kept up-to-date with dbs memory
|
2917 |
+
if langchain_mode in db1s:
|
2918 |
+
db1s[langchain_mode][0] = db
|
2919 |
+
else:
|
2920 |
+
dbs[langchain_mode] = db
|
2921 |
+
|
2922 |
# return only new sources with text saying such
|
2923 |
return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
|
2924 |
+
|
2925 |
+
|
2926 |
+
def get_db1(db1s, langchain_mode1):
|
2927 |
+
if langchain_mode1 in db1s:
|
2928 |
+
db1 = db1s[langchain_mode1]
|
2929 |
+
else:
|
2930 |
+
# indicates to code that not scratch database
|
2931 |
+
db1 = [None, None]
|
2932 |
+
return db1
|
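The per-user scratch state that get_db1() and get_db() juggle is just a dict mapping a collection name to a two-element list [db, user_hash], while shared collections live in the server-wide dbs dict. A minimal sketch of that lookup order (the string values below are stand-ins for real vector databases, purely for illustration):

# Minimal sketch of the db1s / dbs lookup order used by get_db() above; the "db" values
# here are plain strings standing in for real Chroma/FAISS databases (assumption).
def pick_db(db1s, langchain_mode, dbs):
    db1 = db1s.get(langchain_mode, [None, None])  # scratch, per-user: [db, user_hash]
    if db1[0] is not None:
        return db1[0]
    return dbs.get(langchain_mode)  # fall back to shared, server-wide collections

db1s = {'MyData': ['my_scratch_db', 'user-1234']}
dbs = {'UserData': 'shared_user_data_db'}
print(pick_db(db1s, 'MyData', dbs))    # my_scratch_db
print(pick_db(db1s, 'UserData', dbs))  # shared_user_data_db
print(pick_db(db1s, 'LLM', dbs))       # None -> no collection, plain LLM chat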
gradio_utils/__init__.py
ADDED
File without changes
|
gradio_utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (134 Bytes)
|
|
gradio_utils/__pycache__/css.cpython-310.pyc
CHANGED
Binary files a/gradio_utils/__pycache__/css.cpython-310.pyc and b/gradio_utils/__pycache__/css.cpython-310.pyc differ
|
|
gradio_utils/css.py
CHANGED
@@ -53,4 +53,8 @@ def make_css_base() -> str:
       margin-bottom: 2.5rem;
     }
     .chatsmall chatbot {font-size: 10px !important}
+
+    .gradio-container {
+        max-width: none !important;
+    }
     """
h2oai_pipeline.py
CHANGED
@@ -11,6 +11,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
     def __init__(self, *args, debug=False, chat=False, stream_output=False,
                  sanitize_bot_response=False,
                  use_prompter=True, prompter=None,
+                 context='', iinput='',
                  prompt_type=None, prompt_dict=None,
                  max_input_tokens=2048 - 256, **kwargs):
         """
@@ -34,6 +35,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
         self.prompt_type = prompt_type
         self.prompt_dict = prompt_dict
         self.prompter = prompter
+        self.context = context
+        self.iinput = iinput
         if self.use_prompter:
             if self.prompter is not None:
                 assert self.prompter.prompt_type is not None
@@ -113,7 +116,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
     def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
         prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)

-        data_point = dict(context=
+        data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput)
         if self.prompter is not None:
             prompt_text = self.prompter.generate_prompt(data_point)
         self.prompt_text = prompt_text
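The effect of the new context/iinput arguments is that every request through the pipeline is expanded into the three-field data point the prompter already consumes, instead of an instruction-only one. A standalone sketch of that expansion; the template string below is a placeholder, the real layout comes from the configured prompt_type:

# Sketch of how preprocess() above now feeds the prompter; the template is illustrative,
# not the actual prompt_type template used by Prompter.generate_prompt().
def build_prompt(prompt_text, context='', iinput=''):
    data_point = dict(context=context, instruction=prompt_text, input=iinput)
    template = "{context}\n<human>: {instruction}\n{input}\n<bot>:"
    return template.format(**data_point).replace('\n\n', '\n')

print(build_prompt("Summarize the document.", context="You are concise.", iinput=""))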
iterators/__pycache__/timeout_iterator.cpython-310.pyc
CHANGED
Binary files a/iterators/__pycache__/timeout_iterator.cpython-310.pyc and b/iterators/__pycache__/timeout_iterator.cpython-310.pyc differ
|
|
iterators/timeout_iterator.py
CHANGED
@@ -48,7 +48,7 @@ class TimeoutIterator:
     def interrupt(self):
         """
         interrupt and stop the underlying thread.
-        the thread
+        the thread actually dies only after interrupt has been set and
         the underlying iterator yields a value after that.
         """
         self._interrupt = True
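For reference, gradio_runner.py wraps each streaming generator in this iterator so one slow model cannot stall the multi-model loop: on timeout the iterator yields the sentinel instead of blocking. A minimal usage sketch, assuming only the constructor arguments seen at that call site (timeout, sentinel, raise_on_exception) and this repo's module path:

# Minimal usage sketch of TimeoutIterator as used in all_bot() in gradio_runner.py.
# The import path is assumed from this repo's layout.
import time
from iterators.timeout_iterator import TimeoutIterator

def slow_tokens():
    for tok in ["Hello", " world", "!"]:
        time.sleep(0.5)
        yield tok

it = TimeoutIterator(slow_tokens(), timeout=0.1, sentinel=None, raise_on_exception=False)
for item in it:
    if item is None:
        # sentinel: no token arrived within 0.1s, caller can refresh other outputs and retry
        continue
    print(item)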
prompter.py
CHANGED
@@ -77,6 +77,12 @@ prompt_type_to_model_name = {
    "mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
    "vicuna11": ['lmsys/vicuna-33b-v1.3'],
    "falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
+   "llama2": [
+       'meta-llama/Llama-2-7b-chat-hf',
+       'meta-llama/Llama-2-13b-chat-hf',
+       'meta-llama/Llama-2-34b-chat-hf',
+       'meta-llama/Llama-2-70b-chat-hf',
+   ],
    # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
}
if os.getenv('OPENAI_API_KEY'):
@@ -582,6 +588,42 @@ ASSISTANT:
        # if add space here, non-unique tokenization will often make LLM produce wrong output
        PreResponse = PreResponse
        # generates_leading_space = True
+   elif prompt_type in [PromptType.guanaco.value, str(PromptType.guanaco.value),
+                        PromptType.guanaco.name]:
+       # https://huggingface.co/TheBloke/guanaco-65B-GPTQ
+       promptA = promptB = "" if not (chat and reduced) else ''
+
+       PreInstruct = """### Human: """
+
+       PreInput = None
+
+       PreResponse = """### Assistant:"""
+       terminate_response = ['### Human:']  # but only allow terminate after prompt is found correctly, else can't terminate
+       chat_turn_sep = chat_sep = '\n'
+       humanstr = PreInstruct
+       botstr = PreResponse
+   elif prompt_type in [PromptType.llama2.value, str(PromptType.llama2.value),
+                        PromptType.llama2.name]:
+       PreInstruct = ""
+       llama2_sys = "<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
+       prompt = "<s>[INST] "
+       enable_sys = False  # too much safety, hurts accuracy
+       if not (chat and reduced):
+           if enable_sys:
+               promptA = promptB = prompt + llama2_sys
+           else:
+               promptA = promptB = prompt
+       else:
+           promptA = promptB = ''
+       PreInput = None
+       PreResponse = ""
+       terminate_response = ["[INST]", "</s>"]
+       chat_sep = ' [/INST]'
+       chat_turn_sep = ' </s><s>[INST] '
+       humanstr = PreInstruct
+       botstr = PreResponse
+       if making_context:
+           PreResponse += " "
    else:
        raise RuntimeError("No such prompt_type=%s" % prompt_type)

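As a hedged illustration only (not output captured from the repo), the llama2 branch above implies a turn layout along these lines; the example messages and the assembled strings are assumed from the constants defined in that branch:

# Hypothetical rendering of llama2-style turns from the constants above
# (promptA = "<s>[INST] ", chat_sep = " [/INST]", chat_turn_sep = " </s><s>[INST] ", enable_sys=False)
instruction = 'What is 2+2?'                                        # assumed user message
turn = "<s>[INST] " + instruction + " [/INST]"
follow_up = turn + " 4 </s><s>[INST] " + 'And 3+3?' + " [/INST]"    # next turn joined by chat_turn_sep
print(turn)
print(follow_up)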
@@ -810,9 +852,20 @@ class Prompter(object):
            if oi > 0:
                # post fix outputs with seperator
                output += '\n'
+           output = self.fix_text(self.prompt_type, output)
            outputs[oi] = output
        # join all outputs, only one extra new line between outputs
        output = '\n'.join(outputs)
        if self.debug:
            print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
        return output
+
+   @staticmethod
+   def fix_text(prompt_type1, text1):
+       if prompt_type1 == 'human_bot':
+           # hack bug in vLLM with stopping, stops right, but doesn't return last token
+           hfix = '<human'
+           if text1.endswith(hfix):
+               text1 = text1[:-len(hfix)]
+       return text1
+
requirements.txt
CHANGED
@@ -1,8 +1,8 @@
 # for generate (gradio server) and finetune
 datasets==2.13.0
 sentencepiece==0.1.99
-gradio==3.
-huggingface_hub==0.
+gradio==3.37.0
+huggingface_hub==0.16.4
 appdirs==1.4.4
 fire==0.5.0
 docutils==0.20.1
@@ -19,7 +19,7 @@ matplotlib==3.7.1
 loralib==0.1.1
 bitsandbytes==0.39.0
 accelerate==0.20.3
-
+peft==0.4.0
 transformers==4.30.2
 tokenizers==0.13.3
 APScheduler==3.10.1
@@ -35,7 +35,7 @@ tensorboard==2.13.0
 neptune==1.2.0

 # for gradio client
-gradio_client==0.2.
+gradio_client==0.2.10
 beautifulsoup4==4.12.2
 markdown==3.4.3

@@ -64,8 +64,8 @@ tiktoken==0.4.0
 # optional: for OpenAI endpoint or embeddings (requires key)
 openai==0.27.8
 # optional for chat with PDF
-langchain==0.0.
-pypdf==3.
+langchain==0.0.235
+pypdf==3.12.2
 # avoid textract, requires old six
 #textract==1.6.5

@@ -78,10 +78,10 @@ chromadb==0.3.25
 #pymilvus==2.2.8

 # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
-# unstructured==0.
+# unstructured==0.8.1

 # strong support for images
-# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
+# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
 unstructured[local-inference]==0.7.4
 #pdf2image==1.16.3
 #pytesseract==0.3.10
@@ -104,10 +104,10 @@ tabulate==0.9.0
 pip-licenses==4.3.0

 # weaviate vector db
-weaviate-client==3.
+weaviate-client==3.22.1
 # optional for chat with PDF
-langchain==0.0.
-pypdf==3.
+langchain==0.0.235
+pypdf==3.12.2
 # avoid textract, requires old six
 #textract==1.6.5

@@ -120,10 +120,10 @@ chromadb==0.3.25
 #pymilvus==2.2.8

 # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
-# unstructured==0.
+# unstructured==0.8.1

 # strong support for images
-# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
+# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
 unstructured[local-inference]==0.7.4
 #pdf2image==1.16.3
 #pytesseract==0.3.10
@@ -146,8 +146,8 @@ tabulate==0.9.0
 pip-licenses==4.3.0

 # weaviate vector db
-weaviate-client==3.
+weaviate-client==3.22.1
 faiss-gpu==1.7.2
-arxiv==1.4.
-pymupdf==1.22.
+arxiv==1.4.8
+pymupdf==1.22.5  # AGPL license
 # extract-msg==0.41.1 # GPL3
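As a hedged sanity check only (not part of the commit), the bumped pins can be confirmed after installation; the expected strings below simply restate the versions pinned above:

# Hypothetical post-install check of the updated pins
import gradio, gradio_client, langchain
assert gradio.__version__ == "3.37.0"
assert gradio_client.__version__ == "0.2.10"
assert langchain.__version__ == "0.0.235"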
utils.py
CHANGED
@@ -5,6 +5,7 @@ import inspect
 import os
 import gc
 import pathlib
+import pickle
 import random
 import shutil
 import subprocess
@@ -111,12 +112,15 @@ def system_info():
    system = {}
    # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
    # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
-
-
-   coretemp
-
-
-
+   try:
+       temps = psutil.sensors_temperatures(fahrenheit=False)
+       if 'coretemp' in temps:
+           coretemp = temps['coretemp']
+           temp_dict = {k.label: k.current for k in coretemp}
+           for k, v in temp_dict.items():
+               system['CPU_C/%s' % k] = v
+   except AttributeError:
+       pass

    # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
    try:
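For orientation, a hedged sketch of the psutil structure the new system_info() block consumes; the label names and temperatures are made-up examples:

# Hypothetical sketch of psutil.sensors_temperatures() as used above
import psutil
try:
    temps = psutil.sensors_temperatures(fahrenheit=False)
    # e.g. {'coretemp': [shwtemp(label='Core 0', current=41.0, high=82.0, critical=92.0), ...]}
    print({t.label: t.current for t in temps.get('coretemp', [])})
except AttributeError:
    # platforms without sensor support lack this psutil API, hence the new except AttributeError guard
    pass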
@@ -779,6 +783,9 @@ def _traced_func(func, *args, **kwargs):


 def call_subprocess_onetask(func, args=None, kwargs=None):
+   import platform
+   if platform.system() in ['Darwin', 'Windows']:
+       return func(*args, **kwargs)
    if isinstance(args, list):
        args = tuple(args)
    if args is None:
@@ -950,7 +957,6 @@ try:
 except (pkg_resources.DistributionNotFound, AssertionError):
    have_langchain = False

-
 import distutils.spawn

 have_tesseract = distutils.spawn.find_executable("tesseract")
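A hedged sketch of the behavioural change to call_subprocess_onetask; the example callable and arguments are assumed, and the pre-existing subprocess path on Linux is untouched by this commit:

# Hypothetical sketch: on macOS/Windows the helper now runs the function in-process
# instead of spawning a one-task subprocess.
result = call_subprocess_onetask(sorted, args=([3, 1, 2],), kwargs=dict(reverse=True))
# on Darwin/Windows this returns [3, 2, 1] directly via func(*args, **kwargs)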
@@ -985,3 +991,90 @@ except (pkg_resources.DistributionNotFound, AssertionError):

 # disable, hangs too often
 have_playwright = False
+
+
+def set_openai(inference_server):
+    if inference_server.startswith('vllm'):
+        import openai_vllm
+        openai_vllm.api_key = "EMPTY"
+        inf_type = inference_server.split(':')[0]
+        ip_vllm = inference_server.split(':')[1]
+        port_vllm = inference_server.split(':')[2]
+        openai_vllm.api_base = f"http://{ip_vllm}:{port_vllm}/v1"
+        return openai_vllm, inf_type
+    else:
+        import openai
+        openai.api_key = os.getenv("OPENAI_API_KEY")
+        openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
+        inf_type = inference_server
+        return openai, inf_type
+
+
+visible_langchain_modes_file = 'visible_langchain_modes.pkl'
+
+
+def save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode, db1s):
+    """
+    extra controls if UserData type of MyData type
+    """
+    # use first default MyData hash as general user hash to maintain file
+    # if user moves MyData from langchain modes, db will still survive, so can still use hash
+    scratch_collection_names = list(db1s.keys())
+    user_hash = db1s.get(LangChainMode.MY_DATA.value, '')[1]
+
+    llms = ['ChatLLM', 'LLM', 'Disabled']
+
+    scratch_langchain_modes = [x for x in langchain_modes if x in scratch_collection_names]
+    scratch_visible_langchain_modes = [x for x in visible_langchain_modes if x in scratch_collection_names]
+    scratch_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
+                                    k in scratch_collection_names and k not in llms}
+
+    user_langchain_modes = [x for x in langchain_modes if x not in scratch_collection_names]
+    user_visible_langchain_modes = [x for x in visible_langchain_modes if x not in scratch_collection_names]
+    user_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
+                                 k not in scratch_collection_names and k not in llms}
+
+    base_path = 'locks'
+    makedirs(base_path)
+
+    # user
+    extra = ''
+    file = "%s%s" % (visible_langchain_modes_file, extra)
+    with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
+        with open(file, 'wb') as f:
+            pickle.dump((user_langchain_modes, user_visible_langchain_modes, user_langchain_mode_paths), f)
+
+    # scratch
+    extra = user_hash
+    file = "%s%s" % (visible_langchain_modes_file, extra)
+    with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
+        with open(file, 'wb') as f:
+            pickle.dump((scratch_langchain_modes, scratch_visible_langchain_modes, scratch_langchain_mode_paths), f)
+
+
+def load_collection_enum(extra):
+    """
+    extra controls if UserData type of MyData type
+    """
+    file = "%s%s" % (visible_langchain_modes_file, extra)
+    langchain_modes_from_file = []
+    visible_langchain_modes_from_file = []
+    langchain_mode_paths_from_file = {}
+    if os.path.isfile(visible_langchain_modes_file):
+        try:
+            with filelock.FileLock("%s.lock" % file):
+                with open(file, 'rb') as f:
+                    langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = pickle.load(f)
+        except BaseException as e:
+            print("Cannot load %s, ignoring error: %s" % (file, str(e)), flush=True)
+    for k, v in langchain_mode_paths_from_file.items():
+        if v is not None and not os.path.isdir(v) and isinstance(v, str):
+            # assume was deleted, but need to make again to avoid extra code elsewhere
+            makedirs(v)
+    return langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file
+
+
+def remove_collection_enum():
+    remove(visible_langchain_modes_file)