Spaces:
Runtime error
Runtime error
pseudotensor
committed on
Commit
•
9bcca78
1
Parent(s):
63ce3fa
Update with h2oGPT hash 1b295baace42908075b47f31a84b359d8c6b1e52
Browse files
- client_test.py +5 -1
- finetune.py +1 -1
- generate.py +6 -3
- gpt_langchain.py +21 -4
- gradio_runner.py +144 -87
- prompter.py +2 -0
- utils.py +1 -1
client_test.py
CHANGED
@@ -3,7 +3,7 @@ Client test.
|
|
3 |
|
4 |
Run server:
|
5 |
|
6 |
-
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-
|
7 |
|
8 |
NOTE: For private models, add --use-auth_token=True
|
9 |
|
@@ -39,6 +39,7 @@ Loaded as API: https://gpt.h2o.ai ✔
|
|
39 |
import time
|
40 |
import os
|
41 |
import markdown # pip install markdown
|
|
|
42 |
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
43 |
|
44 |
debug = False
|
@@ -79,6 +80,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
|
|
79 |
instruction_nochat=prompt if not chat else '',
|
80 |
iinput_nochat='', # only for chat=False
|
81 |
langchain_mode='Disabled',
|
|
|
82 |
)
|
83 |
if chat:
|
84 |
# add chatbot output on end. Assumes serialize=False
|
@@ -87,6 +89,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
|
|
87 |
return kwargs, list(kwargs.values())
|
88 |
|
89 |
|
|
|
90 |
def test_client_basic():
|
91 |
return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
|
92 |
|
@@ -106,6 +109,7 @@ def run_client_nochat(prompt, prompt_type, max_new_tokens):
|
|
106 |
return res_dict
|
107 |
|
108 |
|
|
|
109 |
def test_client_chat():
|
110 |
return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50)
|
111 |
|
|
|
3 |
|
4 |
Run server:
|
5 |
|
6 |
+
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
|
7 |
|
8 |
NOTE: For private models, add --use-auth_token=True
|
9 |
|
|
|
39 |
import time
|
40 |
import os
|
41 |
import markdown # pip install markdown
|
42 |
+
import pytest
|
43 |
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
44 |
|
45 |
debug = False
|
|
|
80 |
instruction_nochat=prompt if not chat else '',
|
81 |
iinput_nochat='', # only for chat=False
|
82 |
langchain_mode='Disabled',
|
83 |
+
document_choice=['All'],
|
84 |
)
|
85 |
if chat:
|
86 |
# add chatbot output on end. Assumes serialize=False
|
|
|
89 |
return kwargs, list(kwargs.values())
|
90 |
|
91 |
|
92 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
93 |
def test_client_basic():
|
94 |
return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
|
95 |
|
|
|
109 |
return res_dict
|
110 |
|
111 |
|
112 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
113 |
def test_client_chat():
|
114 |
return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50)
|
115 |
|
finetune.py
CHANGED
@@ -26,7 +26,7 @@ def train(
|
|
26 |
save_code: bool = False,
|
27 |
run_id: int = None,
|
28 |
|
29 |
-
base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-
|
30 |
# base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
|
31 |
# base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
|
32 |
# base_model: str = 'EleutherAI/gpt-neox-20b',
|
|
|
26 |
save_code: bool = False,
|
27 |
run_id: int = None,
|
28 |
|
29 |
+
base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
|
30 |
# base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
|
31 |
# base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
|
32 |
# base_model: str = 'EleutherAI/gpt-neox-20b',
|
generate.py
CHANGED
@@ -297,7 +297,7 @@ def main(
|
|
297 |
if psutil.virtual_memory().available < 94 * 1024 ** 3:
|
298 |
# 12B uses ~94GB
|
299 |
# 6.9B uses ~47GB
|
300 |
-
base_model = 'h2oai/h2ogpt-oig-oasst1-512-
|
301 |
|
302 |
# get defaults
|
303 |
model_lower = base_model.lower()
|
@@ -864,6 +864,7 @@ eval_func_param_names = ['instruction',
|
|
864 |
'instruction_nochat',
|
865 |
'iinput_nochat',
|
866 |
'langchain_mode',
|
|
|
867 |
]
|
868 |
|
869 |
|
@@ -891,6 +892,7 @@ def evaluate(
|
|
891 |
instruction_nochat,
|
892 |
iinput_nochat,
|
893 |
langchain_mode,
|
|
|
894 |
# END NOTE: Examples must have same order of parameters
|
895 |
src_lang=None,
|
896 |
tgt_lang=None,
|
@@ -1010,6 +1012,7 @@ def evaluate(
|
|
1010 |
chunk=chunk,
|
1011 |
chunk_size=chunk_size,
|
1012 |
langchain_mode=langchain_mode,
|
|
|
1013 |
db_type=db_type,
|
1014 |
k=k,
|
1015 |
temperature=temperature,
|
@@ -1446,7 +1449,7 @@ y = np.random.randint(0, 1, 100)
|
|
1446 |
|
1447 |
# move to correct position
|
1448 |
for example in examples:
|
1449 |
-
example += [chat, '', '', 'Disabled']
|
1450 |
# adjust examples if non-chat mode
|
1451 |
if not chat:
|
1452 |
example[eval_func_param_names.index('instruction_nochat')] = example[
|
@@ -1546,6 +1549,6 @@ if __name__ == "__main__":
|
|
1546 |
can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
|
1547 |
python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
|
1548 |
|
1549 |
-
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-
|
1550 |
"""
|
1551 |
fire.Fire(main)
|
|
|
297 |
if psutil.virtual_memory().available < 94 * 1024 ** 3:
|
298 |
# 12B uses ~94GB
|
299 |
# 6.9B uses ~47GB
|
300 |
+
base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model
|
301 |
|
302 |
# get defaults
|
303 |
model_lower = base_model.lower()
|
|
|
864 |
'instruction_nochat',
|
865 |
'iinput_nochat',
|
866 |
'langchain_mode',
|
867 |
+
'document_choice',
|
868 |
]
|
869 |
|
870 |
|
|
|
892 |
instruction_nochat,
|
893 |
iinput_nochat,
|
894 |
langchain_mode,
|
895 |
+
document_choice,
|
896 |
# END NOTE: Examples must have same order of parameters
|
897 |
src_lang=None,
|
898 |
tgt_lang=None,
|
|
|
1012 |
chunk=chunk,
|
1013 |
chunk_size=chunk_size,
|
1014 |
langchain_mode=langchain_mode,
|
1015 |
+
document_choice=document_choice,
|
1016 |
db_type=db_type,
|
1017 |
k=k,
|
1018 |
temperature=temperature,
|
|
|
1449 |
|
1450 |
# move to correct position
|
1451 |
for example in examples:
|
1452 |
+
example += [chat, '', '', 'Disabled', ['All']]
|
1453 |
# adjust examples if non-chat mode
|
1454 |
if not chat:
|
1455 |
example[eval_func_param_names.index('instruction_nochat')] = example[
|
|
|
1549 |
can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
|
1550 |
python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
|
1551 |
|
1552 |
+
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
|
1553 |
"""
|
1554 |
fire.Fire(main)
|
gpt_langchain.py
CHANGED
@@ -150,7 +150,7 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
|
|
150 |
assert model_name is None
|
151 |
assert tokenizer is None
|
152 |
model_name = 'h2oai/h2ogpt-oasst1-512-12b'
|
153 |
-
# model_name = 'h2oai/h2ogpt-oig-oasst1-512-
|
154 |
# model_name = 'h2oai/h2ogpt-oasst1-512-20b'
|
155 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
156 |
device, torch_dtype, context_class = get_device_dtype()
|
@@ -593,7 +593,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
593 |
):
|
594 |
globs_image_types = []
|
595 |
globs_non_image_types = []
|
596 |
-
if path_or_paths
|
597 |
return []
|
598 |
elif url:
|
599 |
globs_non_image_types = [url]
|
@@ -846,6 +846,7 @@ def _run_qa_db(query=None,
|
|
846 |
top_k=40,
|
847 |
top_p=0.7,
|
848 |
langchain_mode=None,
|
|
|
849 |
n_jobs=-1):
|
850 |
"""
|
851 |
|
@@ -917,7 +918,23 @@ def _run_qa_db(query=None,
|
|
917 |
k_db = 1000 if db_type == 'chroma' else k # k=100 works ok too for
|
918 |
|
919 |
if db and use_context:
|
920 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
921 |
# cut off so no high distance docs/sources considered
|
922 |
docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
|
923 |
scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
|
@@ -939,7 +956,7 @@ def _run_qa_db(query=None,
|
|
939 |
reduced_query_words = reduced_query.split(' ')
|
940 |
set_common = set(df['Lemma'].values.tolist())
|
941 |
num_common = len([x.lower() in set_common for x in reduced_query_words])
|
942 |
-
frac_common = num_common / len(reduced_query)
|
943 |
# FIXME: report to user bad query that uses too many common words
|
944 |
print("frac_common: %s" % frac_common, flush=True)
|
945 |
|
|
|
150 |
assert model_name is None
|
151 |
assert tokenizer is None
|
152 |
model_name = 'h2oai/h2ogpt-oasst1-512-12b'
|
153 |
+
# model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
|
154 |
# model_name = 'h2oai/h2ogpt-oasst1-512-20b'
|
155 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
156 |
device, torch_dtype, context_class = get_device_dtype()
|
|
|
593 |
):
|
594 |
globs_image_types = []
|
595 |
globs_non_image_types = []
|
596 |
+
if not path_or_paths and not url and not text:
|
597 |
return []
|
598 |
elif url:
|
599 |
globs_non_image_types = [url]
|
|
|
846 |
top_k=40,
|
847 |
top_p=0.7,
|
848 |
langchain_mode=None,
|
849 |
+
document_choice=['All'],
|
850 |
n_jobs=-1):
|
851 |
"""
|
852 |
|
|
|
918 |
k_db = 1000 if db_type == 'chroma' else k # k=100 works ok too for
|
919 |
|
920 |
if db and use_context:
|
921 |
+
if isinstance(document_choice, str):
|
922 |
+
# support string as well
|
923 |
+
document_choice = [document_choice]
|
924 |
+
if not isinstance(db, Chroma) or len(document_choice) <= 1 and document_choice[0].lower() == 'all':
|
925 |
+
# treat empty list as All for now, not 'None'
|
926 |
+
filter_kwargs = {}
|
927 |
+
else:
|
928 |
+
if len(document_choice) >= 2:
|
929 |
+
or_filter = [{"source": {"$eq": x}} for x in document_choice]
|
930 |
+
filter_kwargs = dict(filter={"$or": or_filter})
|
931 |
+
else:
|
932 |
+
one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
|
933 |
+
filter_kwargs = dict(filter=one_filter)
|
934 |
+
if len(document_choice) == 1 and document_choice[0].lower() == 'none':
|
935 |
+
k_db = 1
|
936 |
+
k = 0
|
937 |
+
docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:k]
|
938 |
# cut off so no high distance docs/sources considered
|
939 |
docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
|
940 |
scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
|
|
|
956 |
reduced_query_words = reduced_query.split(' ')
|
957 |
set_common = set(df['Lemma'].values.tolist())
|
958 |
num_common = len([x.lower() in set_common for x in reduced_query_words])
|
959 |
+
frac_common = num_common / len(reduced_query) if reduced_query else 0
|
960 |
# FIXME: report to user bad query that uses too many common words
|
961 |
print("frac_common: %s" % frac_common, flush=True)
|
962 |
|
gradio_runner.py
CHANGED
@@ -96,7 +96,13 @@ def go_gradio(**kwargs):
|
|
96 |
css_code = """footer {visibility: hidden}"""
|
97 |
css_code += """
|
98 |
body.dark{#warning {background-color: #555555};}
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
if kwargs['gradio_avoid_processing_markdown']:
|
102 |
from gradio_client import utils as client_utils
|
@@ -167,6 +173,7 @@ body.dark{#warning {background-color: #555555};}
|
|
167 |
lora_options_state = gr.State([lora_options])
|
168 |
my_db_state = gr.State([None, None])
|
169 |
chat_state = gr.State({})
|
|
|
170 |
gr.Markdown(f"""
|
171 |
{get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
|
172 |
|
@@ -175,7 +182,7 @@ body.dark{#warning {background-color: #555555};}
|
|
175 |
""")
|
176 |
if is_hf:
|
177 |
gr.HTML(
|
178 |
-
|
179 |
|
180 |
# go button visible if
|
181 |
base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
|
@@ -220,7 +227,7 @@ body.dark{#warning {background-color: #555555};}
|
|
220 |
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
|
221 |
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
|
222 |
with gr.Row():
|
223 |
-
clear = gr.Button("Save
|
224 |
flag_btn = gr.Button("Flag")
|
225 |
if not kwargs['auto_score']: # FIXME: For checkbox model2
|
226 |
with gr.Column(visible=kwargs['score_model']):
|
@@ -251,19 +258,16 @@ body.dark{#warning {background-color: #555555};}
|
|
251 |
radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
|
252 |
type='value')
|
253 |
with gr.Row():
|
254 |
-
remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
|
255 |
clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
with chats_row2:
|
262 |
chatsup_output = gr.File(label="Upload Chat File(s)",
|
263 |
file_types=['.json'],
|
264 |
file_count='multiple',
|
265 |
elem_id="warning", elem_classes="feedback")
|
266 |
-
add_to_chats_btn = gr.Button("Add File(s) to Chats")
|
267 |
with gr.TabItem("Data Source"):
|
268 |
langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
|
269 |
from_str=True)
|
@@ -275,8 +279,8 @@ body.dark{#warning {background-color: #555555};}
|
|
275 |
<p>
|
276 |
For more options see: {langchain_readme}""",
|
277 |
visible=kwargs['langchain_mode'] == 'Disabled', interactive=False)
|
278 |
-
|
279 |
-
with
|
280 |
if is_hf:
|
281 |
# don't show 'wiki' since only usually useful for internal testing at moment
|
282 |
no_show_modes = ['Disabled', 'wiki']
|
@@ -292,77 +296,92 @@ body.dark{#warning {background-color: #555555};}
|
|
292 |
langchain_mode = gr.Radio(
|
293 |
[x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
|
294 |
value=kwargs['langchain_mode'],
|
295 |
-
label="Data
|
296 |
visible=kwargs['langchain_mode'] != 'Disabled')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
|
298 |
-
def upload_file(files, x):
|
299 |
-
file_paths = [file.name for file in files]
|
300 |
-
return files, file_paths
|
301 |
-
|
302 |
-
upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style(
|
303 |
-
equal_height=False)
|
304 |
# import control
|
305 |
if kwargs['langchain_mode'] != 'Disabled':
|
306 |
from gpt_langchain import file_types, have_arxiv
|
307 |
else:
|
308 |
have_arxiv = False
|
309 |
file_types = []
|
310 |
-
|
311 |
-
|
312 |
-
fileup_output = gr.File(label=f'Upload {file_types_str}',
|
313 |
-
file_types=file_types,
|
314 |
-
file_count="multiple",
|
315 |
-
elem_id="warning", elem_classes="feedback")
|
316 |
-
with gr.Row():
|
317 |
-
upload_button = gr.UploadButton("Upload %s" % file_types_str,
|
318 |
-
file_types=file_types,
|
319 |
-
file_count="multiple",
|
320 |
-
visible=False,
|
321 |
-
)
|
322 |
-
# add not visible until upload something
|
323 |
-
with gr.Column():
|
324 |
-
add_to_shared_db_btn = gr.Button("Add File(s) to Shared UserData DB",
|
325 |
-
visible=allow_upload_to_user_data) # and False)
|
326 |
-
add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData DB",
|
327 |
-
visible=allow_upload_to_my_data) # and False)
|
328 |
-
url_row = gr.Row(
|
329 |
-
visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload).style(
|
330 |
equal_height=False)
|
331 |
-
with
|
332 |
-
url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
|
333 |
-
url_text = gr.Textbox(label=url_label, interactive=True)
|
334 |
with gr.Column():
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
equal_height=False)
|
342 |
-
with
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
visible=allow_upload_to_my_data)
|
349 |
-
# WIP:
|
350 |
-
with gr.Row(visible=False).style(equal_height=False):
|
351 |
-
github_textbox = gr.Textbox(label="Github URL")
|
352 |
-
with gr.Row(visible=True):
|
353 |
-
github_shared_btn = gr.Button(value="Add Github to Shared UserData DB",
|
354 |
-
visible=allow_upload_to_user_data)
|
355 |
-
github_my_btn = gr.Button(value="Add Github to Scratch MyData DB",
|
356 |
-
visible=allow_upload_to_my_data)
|
357 |
sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
|
358 |
equal_height=False)
|
359 |
with sources_row:
|
360 |
sources_text = gr.HTML(label='Sources Added', interactive=False)
|
361 |
-
sources_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
|
362 |
-
equal_height=False)
|
363 |
-
with sources_row2:
|
364 |
-
get_sources_btn = gr.Button(value="Get Sources List for Selected DB")
|
365 |
-
file_source = gr.File(interactive=False, label="Download File with list of Sources")
|
366 |
|
367 |
with gr.TabItem("Expert"):
|
368 |
with gr.Row():
|
@@ -545,14 +564,6 @@ body.dark{#warning {background-color: #555555};}
|
|
545 |
def make_visible():
|
546 |
return gr.update(visible=True)
|
547 |
|
548 |
-
# add itself to output to ensure shows working and can't click again
|
549 |
-
upload_button.upload(upload_file, inputs=[upload_button, fileup_output],
|
550 |
-
outputs=[upload_button, fileup_output], queue=queue,
|
551 |
-
api_name='upload_file' if allow_api else None) \
|
552 |
-
.then(make_add_visible, fileup_output, add_to_shared_db_btn, queue=queue) \
|
553 |
-
.then(make_add_visible, fileup_output, add_to_my_db_btn, queue=queue) \
|
554 |
-
.then(make_invisible, outputs=upload_button, queue=queue)
|
555 |
-
|
556 |
# Add to UserData
|
557 |
update_user_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='UserData',
|
558 |
use_openai_embedding=use_openai_embedding,
|
@@ -623,8 +634,23 @@ body.dark{#warning {background-color: #555555};}
|
|
623 |
.then(clear_textbox, outputs=user_text_text, queue=queue)
|
624 |
|
625 |
get_sources1 = functools.partial(get_sources, dbs=dbs)
|
626 |
-
|
627 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
628 |
|
629 |
def check_admin_pass(x):
|
630 |
return gr.update(visible=x == admin_pass)
|
@@ -818,6 +844,11 @@ body.dark{#warning {background-color: #555555};}
|
|
818 |
my_db_state1 = args_list[-2]
|
819 |
history = args_list[-1]
|
820 |
|
|
|
|
|
|
|
|
|
|
|
821 |
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
822 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
823 |
if retry and history:
|
@@ -827,13 +858,19 @@ body.dark{#warning {background-color: #555555};}
|
|
827 |
args_list[eval_func_param_names.index('do_sample')] = True
|
828 |
if not history:
|
829 |
print("No history", flush=True)
|
830 |
-
history = [
|
831 |
yield history, ''
|
832 |
return
|
833 |
# ensure output will be unique to models
|
834 |
_, _, _, max_prompt_length = get_cutoffs(is_low_mem, for_context=True)
|
835 |
history = copy.deepcopy(history)
|
836 |
instruction1 = history[-1][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
837 |
context1 = ''
|
838 |
if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
|
839 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
@@ -867,10 +904,6 @@ body.dark{#warning {background-color: #555555};}
|
|
867 |
context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
|
868 |
args_list[0] = instruction1 # override original instruction with history from user
|
869 |
args_list[2] = context1
|
870 |
-
if model_state1[0] is None or model_state1[0] == no_model_str:
|
871 |
-
history = [['', None]]
|
872 |
-
yield history, ''
|
873 |
-
return
|
874 |
fun1 = partial(evaluate,
|
875 |
model_state1,
|
876 |
my_db_state1,
|
@@ -1086,10 +1119,14 @@ body.dark{#warning {background-color: #555555};}
|
|
1086 |
api_name='export_chats' if allow_api else None)
|
1087 |
|
1088 |
def add_chats_from_file(file, chat_state1, add_btn):
|
|
|
|
|
1089 |
if isinstance(file, str):
|
1090 |
files = [file]
|
1091 |
else:
|
1092 |
files = file
|
|
|
|
|
1093 |
for file1 in files:
|
1094 |
try:
|
1095 |
if hasattr(file1, 'name'):
|
@@ -1350,22 +1387,28 @@ def get_inputs_list(inputs_dict, model_lower):
|
|
1350 |
def get_sources(db1, langchain_mode, dbs=None):
|
1351 |
if langchain_mode in ['ChatLLM', 'LLM']:
|
1352 |
source_files_added = "NA"
|
|
|
1353 |
elif langchain_mode in ['wiki_full']:
|
1354 |
source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
|
1355 |
" Ask jon.mckinney@h2o.ai for file if required."
|
|
|
1356 |
elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
|
1357 |
db_get = db1[0].get()
|
1358 |
-
|
|
|
1359 |
elif langchain_mode in dbs and dbs[langchain_mode] is not None:
|
1360 |
db1 = dbs[langchain_mode]
|
1361 |
db_get = db1.get()
|
1362 |
-
|
|
|
1363 |
else:
|
|
|
1364 |
source_files_added = "None"
|
1365 |
sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
|
1366 |
with open(sources_file, "wt") as f:
|
1367 |
f.write(source_files_added)
|
1368 |
-
|
|
|
1369 |
|
1370 |
|
1371 |
def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData', **kwargs):
|
@@ -1465,6 +1508,20 @@ def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='Use
|
|
1465 |
return x, y, source_files_added
|
1466 |
|
1467 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1468 |
def get_source_files(db, exceptions=None):
|
1469 |
if exceptions is None:
|
1470 |
exceptions = []
|
|
|
96 |
css_code = """footer {visibility: hidden}"""
|
97 |
css_code += """
|
98 |
body.dark{#warning {background-color: #555555};}
|
99 |
+
#small_btn {
|
100 |
+
margin: 0.6em 0em 0.55em 0;
|
101 |
+
max-width: 20em;
|
102 |
+
min-width: 5em !important;
|
103 |
+
height: 5em;
|
104 |
+
font-size: 14px !important
|
105 |
+
}"""
|
106 |
|
107 |
if kwargs['gradio_avoid_processing_markdown']:
|
108 |
from gradio_client import utils as client_utils
|
|
|
173 |
lora_options_state = gr.State([lora_options])
|
174 |
my_db_state = gr.State([None, None])
|
175 |
chat_state = gr.State({})
|
176 |
+
docs_state = gr.State(['All'])
|
177 |
gr.Markdown(f"""
|
178 |
{get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
|
179 |
|
|
|
182 |
""")
|
183 |
if is_hf:
|
184 |
gr.HTML(
|
185 |
+
)
|
186 |
|
187 |
# go button visible if
|
188 |
base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
|
|
|
227 |
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
|
228 |
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
|
229 |
with gr.Row():
|
230 |
+
clear = gr.Button("Save Chat / New Chat")
|
231 |
flag_btn = gr.Button("Flag")
|
232 |
if not kwargs['auto_score']: # FIXME: For checkbox model2
|
233 |
with gr.Column(visible=kwargs['score_model']):
|
|
|
258 |
radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
|
259 |
type='value')
|
260 |
with gr.Row():
|
|
|
261 |
clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
|
262 |
+
export_chats_btn = gr.Button(value="Export Chats to Download")
|
263 |
+
remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
|
264 |
+
add_to_chats_btn = gr.Button("Import Chats from Upload")
|
265 |
+
with gr.Row():
|
266 |
+
chats_file = gr.File(interactive=False, label="Download Exported Chats")
|
|
|
267 |
chatsup_output = gr.File(label="Upload Chat File(s)",
|
268 |
file_types=['.json'],
|
269 |
file_count='multiple',
|
270 |
elem_id="warning", elem_classes="feedback")
|
|
|
271 |
with gr.TabItem("Data Source"):
|
272 |
langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
|
273 |
from_str=True)
|
|
|
279 |
<p>
|
280 |
For more options see: {langchain_readme}""",
|
281 |
visible=kwargs['langchain_mode'] == 'Disabled', interactive=False)
|
282 |
+
data_row1 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
|
283 |
+
with data_row1:
|
284 |
if is_hf:
|
285 |
# don't show 'wiki' since only usually useful for internal testing at moment
|
286 |
no_show_modes = ['Disabled', 'wiki']
|
|
|
296 |
langchain_mode = gr.Radio(
|
297 |
[x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
|
298 |
value=kwargs['langchain_mode'],
|
299 |
+
label="Data Collection of Sources",
|
300 |
visible=kwargs['langchain_mode'] != 'Disabled')
|
301 |
+
data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
|
302 |
+
with data_row2:
|
303 |
+
with gr.Column(scale=50):
|
304 |
+
document_choice = gr.Dropdown(docs_state.value,
|
305 |
+
label="Choose Subset of Doc(s) in Collection [click get to update]",
|
306 |
+
value=docs_state.value[0],
|
307 |
+
interactive=True,
|
308 |
+
multiselect=True,
|
309 |
+
)
|
310 |
+
with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list):
|
311 |
+
get_sources_btn = gr.Button(value="Get Sources",
|
312 |
+
).style(full_width=False, size='sm')
|
313 |
+
show_sources_btn = gr.Button(value="Show Sources",
|
314 |
+
).style(full_width=False, size='sm')
|
315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
# import control
|
317 |
if kwargs['langchain_mode'] != 'Disabled':
|
318 |
from gpt_langchain import file_types, have_arxiv
|
319 |
else:
|
320 |
have_arxiv = False
|
321 |
file_types = []
|
322 |
+
|
323 |
+
upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
equal_height=False)
|
325 |
+
with upload_row:
|
|
|
|
|
326 |
with gr.Column():
|
327 |
+
file_types_str = '[' + ' '.join(file_types) + ']'
|
328 |
+
fileup_output = gr.File(label=f'Upload {file_types_str}',
|
329 |
+
file_types=file_types,
|
330 |
+
file_count="multiple",
|
331 |
+
elem_id="warning", elem_classes="feedback")
|
332 |
+
with gr.Row():
|
333 |
+
add_to_shared_db_btn = gr.Button("Add File(s) to UserData",
|
334 |
+
visible=allow_upload_to_user_data, elem_id='small_btn')
|
335 |
+
add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData",
|
336 |
+
visible=allow_upload_to_my_data,
|
337 |
+
elem_id='small_btn' if allow_upload_to_user_data else None,
|
338 |
+
).style(
|
339 |
+
size='sm' if not allow_upload_to_user_data else None)
|
340 |
+
with gr.Column(
|
341 |
+
visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload):
|
342 |
+
url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
|
343 |
+
url_text = gr.Textbox(label=url_label, interactive=True)
|
344 |
+
with gr.Row():
|
345 |
+
url_user_btn = gr.Button(value='Add URL content to Shared UserData',
|
346 |
+
visible=allow_upload_to_user_data, elem_id='small_btn')
|
347 |
+
url_my_btn = gr.Button(value='Add URL content to Scratch MyData',
|
348 |
+
visible=allow_upload_to_my_data,
|
349 |
+
elem_id='small_btn' if allow_upload_to_user_data else None,
|
350 |
+
).style(size='sm' if not allow_upload_to_user_data else None)
|
351 |
+
with gr.Column(
|
352 |
+
visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload):
|
353 |
+
user_text_text = gr.Textbox(label='Paste Text [Shift-Enter more lines]', interactive=True)
|
354 |
+
with gr.Row():
|
355 |
+
user_text_user_btn = gr.Button(value='Add Text to Shared UserData',
|
356 |
+
visible=allow_upload_to_user_data,
|
357 |
+
elem_id='small_btn')
|
358 |
+
user_text_my_btn = gr.Button(value='Add Text to Scratch MyData',
|
359 |
+
visible=allow_upload_to_my_data,
|
360 |
+
elem_id='small_btn' if allow_upload_to_user_data else None,
|
361 |
+
).style(
|
362 |
+
size='sm' if not allow_upload_to_user_data else None)
|
363 |
+
with gr.Column(visible=False):
|
364 |
+
# WIP:
|
365 |
+
with gr.Row(visible=False).style(equal_height=False):
|
366 |
+
github_textbox = gr.Textbox(label="Github URL")
|
367 |
+
with gr.Row(visible=True):
|
368 |
+
github_shared_btn = gr.Button(value="Add Github to Shared UserData",
|
369 |
+
visible=allow_upload_to_user_data,
|
370 |
+
elem_id='small_btn')
|
371 |
+
github_my_btn = gr.Button(value="Add Github to Scratch MyData",
|
372 |
+
visible=allow_upload_to_my_data, elem_id='small_btn')
|
373 |
+
sources_row3 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
|
374 |
equal_height=False)
|
375 |
+
with sources_row3:
|
376 |
+
with gr.Column(scale=1):
|
377 |
+
file_source = gr.File(interactive=False,
|
378 |
+
label="Download File with Sources [click get to make file]")
|
379 |
+
with gr.Column(scale=2):
|
380 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
|
382 |
equal_height=False)
|
383 |
with sources_row:
|
384 |
sources_text = gr.HTML(label='Sources Added', interactive=False)
|
|
|
|
|
|
|
|
|
|
|
385 |
|
386 |
with gr.TabItem("Expert"):
|
387 |
with gr.Row():
|
|
|
564 |
def make_visible():
|
565 |
return gr.update(visible=True)
|
566 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
567 |
# Add to UserData
|
568 |
update_user_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='UserData',
|
569 |
use_openai_embedding=use_openai_embedding,
|
|
|
634 |
.then(clear_textbox, outputs=user_text_text, queue=queue)
|
635 |
|
636 |
get_sources1 = functools.partial(get_sources, dbs=dbs)
|
637 |
+
|
638 |
+
# if change collection source, must clear doc selections from it to avoid inconsistency
|
639 |
+
def clear_doc_choice():
|
640 |
+
return gr.Dropdown.update(choices=['All'], value=['All'])
|
641 |
+
|
642 |
+
langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice)
|
643 |
+
|
644 |
+
def update_dropdown(x):
|
645 |
+
return gr.Dropdown.update(choices=x, value='All')
|
646 |
+
|
647 |
+
show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
|
648 |
+
get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=[file_source, docs_state],
|
649 |
+
queue=queue,
|
650 |
+
api_name='get_sources' if allow_api else None) \
|
651 |
+
.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
|
652 |
+
# show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe
|
653 |
+
show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text)
|
654 |
|
655 |
def check_admin_pass(x):
|
656 |
return gr.update(visible=x == admin_pass)
|
|
|
844 |
my_db_state1 = args_list[-2]
|
845 |
history = args_list[-1]
|
846 |
|
847 |
+
if model_state1[0] is None or model_state1[0] == no_model_str:
|
848 |
+
history = []
|
849 |
+
yield history, ''
|
850 |
+
return
|
851 |
+
|
852 |
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
853 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
854 |
if retry and history:
|
|
|
858 |
args_list[eval_func_param_names.index('do_sample')] = True
|
859 |
if not history:
|
860 |
print("No history", flush=True)
|
861 |
+
history = []
|
862 |
yield history, ''
|
863 |
return
|
864 |
# ensure output will be unique to models
|
865 |
_, _, _, max_prompt_length = get_cutoffs(is_low_mem, for_context=True)
|
866 |
history = copy.deepcopy(history)
|
867 |
instruction1 = history[-1][0]
|
868 |
+
if not instruction1:
|
869 |
+
# reject empty query, can sometimes go nuts
|
870 |
+
history = []
|
871 |
+
yield history, ''
|
872 |
+
return
|
873 |
+
|
874 |
context1 = ''
|
875 |
if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
|
876 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
|
|
904 |
context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
|
905 |
args_list[0] = instruction1 # override original instruction with history from user
|
906 |
args_list[2] = context1
|
|
|
|
|
|
|
|
|
907 |
fun1 = partial(evaluate,
|
908 |
model_state1,
|
909 |
my_db_state1,
|
|
|
1119 |
api_name='export_chats' if allow_api else None)
|
1120 |
|
1121 |
def add_chats_from_file(file, chat_state1, add_btn):
|
1122 |
+
if not file:
|
1123 |
+
return chat_state1, add_btn
|
1124 |
if isinstance(file, str):
|
1125 |
files = [file]
|
1126 |
else:
|
1127 |
files = file
|
1128 |
+
if not files:
|
1129 |
+
return chat_state1, add_btn
|
1130 |
for file1 in files:
|
1131 |
try:
|
1132 |
if hasattr(file1, 'name'):
|
|
|
1387 |
def get_sources(db1, langchain_mode, dbs=None):
    """Write the sources known for ``langchain_mode`` to a file and list them.

    :param db1: sequence holding the per-user database; element 0 is used when
        ``langchain_mode`` is 'MyData' and that element is not None
    :param langchain_mode: selected mode / collection name
    :param dbs: optional mapping of mode name to database object; ``None`` is
        treated as an empty mapping
    :return: tuple ``(sources_file, source_list)`` where ``sources_file`` is the
        name of the text file written in the current directory and
        ``source_list`` is ``['All']`` followed by the sorted unique sources
    """
    if dbs is None:
        # ``langchain_mode in None`` raises TypeError, so normalize the default
        dbs = {}

    def _sources_of(db):
        # unique 'source' entries of every stored metadata record, sorted
        return sorted({x['source'] for x in db.get()['metadatas']})

    if langchain_mode in ['ChatLLM', 'LLM']:
        source_files_added = "NA"
        source_list = []
    elif langchain_mode in ['wiki_full']:
        source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
                             "  Ask jon.mckinney@h2o.ai for file if required."
        source_list = []
    elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
        source_list = _sources_of(db1[0])
        source_files_added = '\n'.join(source_list)
    elif langchain_mode in dbs and dbs[langchain_mode] is not None:
        source_list = _sources_of(dbs[langchain_mode])
        source_files_added = '\n'.join(source_list)
    else:
        source_list = []
        source_files_added = "None"
    # unique file name so separate requests do not overwrite each other
    sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
    # explicit encoding so non-ASCII source names do not depend on the locale
    with open(sources_file, "wt", encoding="utf-8") as f:
        f.write(source_files_added)
    # 'All' is always offered as the first choice
    source_list = ['All'] + source_list
    return sources_file, source_list
|
1412 |
|
1413 |
|
1414 |
def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData', **kwargs):
|
|
|
1508 |
return x, y, source_files_added
|
1509 |
|
1510 |
|
1511 |
+
def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
    """Pick the database for ``langchain_mode`` and return ``get_source_files`` for it.

    :param db1: sequence holding the per-user database; element 0 is used when
        ``langchain_mode`` is 'MyData' and that element is not None
    :param langchain_mode: selected mode / collection name, also used to name the lock file
    :param dbs: optional mapping of mode name to database object; ``None`` is
        treated as an empty mapping
    :return: whatever ``get_source_files(db, exceptions=None)`` returns; ``db`` is
        None when no database applies to the mode
    """
    if dbs is None:
        # ``langchain_mode in None`` raises TypeError, so normalize the default
        dbs = {}
    with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
        if langchain_mode in ['wiki_full']:
            # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
            db = None
        elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
            db = db1[0]
        elif langchain_mode in dbs and dbs[langchain_mode] is not None:
            db = dbs[langchain_mode]
        else:
            db = None
        # NOTE(review): the lookup is kept under the per-mode file lock — confirm against the original indentation
        return get_source_files(db, exceptions=None)
|
1523 |
+
|
1524 |
+
|
1525 |
def get_source_files(db, exceptions=None):
|
1526 |
if exceptions is None:
|
1527 |
exceptions = []
|
prompter.py
CHANGED
@@ -56,6 +56,8 @@ prompt_type_to_model_name = {
|
|
56 |
'h2oai/h2ogpt-oasst1-512-20b',
|
57 |
'h2oai/h2ogpt-oig-oasst1-256-6_9b',
|
58 |
'h2oai/h2ogpt-oig-oasst1-512-6_9b',
|
|
|
|
|
59 |
'h2oai/h2ogpt-research-oasst1-512-30b', # private
|
60 |
],
|
61 |
'dai_faq': [],
|
|
|
56 |
'h2oai/h2ogpt-oasst1-512-20b',
|
57 |
'h2oai/h2ogpt-oig-oasst1-256-6_9b',
|
58 |
'h2oai/h2ogpt-oig-oasst1-512-6_9b',
|
59 |
+
'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
|
60 |
+
'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
|
61 |
'h2oai/h2ogpt-research-oasst1-512-30b', # private
|
62 |
],
|
63 |
'dai_faq': [],
|
utils.py
CHANGED
@@ -148,7 +148,7 @@ def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
|
|
148 |
host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
|
149 |
zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
|
150 |
assert root_dirs is not None
|
151 |
-
if not os.path.isdir(os.path.dirname(zip_file)):
|
152 |
os.makedirs(os.path.dirname(zip_file), exist_ok=True)
|
153 |
with zipfile.ZipFile(zip_file, "w") as expt_zip:
|
154 |
for root_dir in root_dirs:
|
|
|
148 |
host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
|
149 |
zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
|
150 |
assert root_dirs is not None
|
151 |
+
if not os.path.isdir(os.path.dirname(zip_file)) and os.path.dirname(zip_file):
|
152 |
os.makedirs(os.path.dirname(zip_file), exist_ok=True)
|
153 |
with zipfile.ZipFile(zip_file, "w") as expt_zip:
|
154 |
for root_dir in root_dirs:
|